
Commit 258067be authored by Mathieu Reymond

change normalization, rename to PDQN, pessimistic q-network initialization

parent 55e28cdd
@@ -38,7 +38,8 @@ def dst_non_dominated(env, normalize):
# if t_pos[1] >= pos_a[1]:
# manhattan distance = fuel consumption
fuel = -np.sum(np.absolute(np.array(t_pos)-pos_a))
- non_dominated[s][a_i].append(np.array([t, fuel])/normalize)
+ non_dominated[s][a_i].append((np.array([t, fuel]) - normalize['min'])/normalize['scale'])
# if position is treasure position, no other points are on the pareto front (end of episode)
if np.all(t_pos == pos_a):
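For reference, a minimal sketch of the min/scale normalization this hunk switches to, using the values set further down in `__main__` (`min=[0, 0]`, `scale=[124, 19]` for the treasure and fuel objectives); the helper names below are hypothetical and not part of this commit.

```python
import numpy as np

# Hypothetical helpers mirroring the (reward - min) / scale normalization above
# and its inverse (used later when plotting the estimated fronts).
def normalize_reward(reward, normalize):
    """Map a raw [treasure, fuel] reward into normalized units."""
    return (np.asarray(reward, dtype=float) - normalize['min']) / normalize['scale']

def unnormalize_reward(reward, normalize):
    """Undo the normalization, recovering raw reward units."""
    return np.asarray(reward, dtype=float) * normalize['scale'] + normalize['min']

normalize = {'min': np.array([0., 0.]), 'scale': np.array([124., 19.])}
r = np.array([124., -19.])  # largest treasure, worst-case fuel cost
print(normalize_reward(r, normalize))                                 # [ 1. -1.]
print(unnormalize_reward(normalize_reward(r, normalize), normalize))  # [124. -19.]
```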
@@ -171,7 +172,7 @@ class Estimator(object):
return l
- class DPQN(Agent):
+ class PDQN(Agent):
def __init__(self, env,
policy=None,
@@ -315,7 +316,7 @@ class DPQN(Agent):
next_observation=next_obs,
terminal=terminal)
self.memory.add(t)
- if log.total_steps >= self.learn_start:
+ if log.total_steps >= self.batch_size: # self.learn_start:
batch = self.memory.sample(self.batch_size)
# normalize reward for pareto_estimator
@@ -340,8 +341,12 @@ class DPQN(Agent):
batch_actions = []
# add pareto front for each sample
for i, b_i in enumerate(batch_q_front_next):
+ # TODO TEST before learn start, initialize net pessimistically
+ if log.total_steps < self.learn_start:
+     non_dominated = b_i
+     non_dominated[:, -1] = -1.
# if state-action leads to terminal next_state, next_state has no pareto front, only immediate reward
- if batch.terminal[i]:
+ elif batch.terminal[i]:
min_ = np.abs(b_i[:, 0] - batch_rew_norm[i][0]).argmin()
non_dominated = b_i
non_dominated[:min_, -1] = batch_rew_norm[i][-1]
@@ -351,7 +356,8 @@ class DPQN(Agent):
else:
non_dominated = b_i
# non_dominated = get_non_dominated(b_i) # + batch.reward[i].reshape(1, -1)
- non_dominated += np.array([0, 1]).reshape(1, 2)
+ # TODO why all the time (<- indent left)
+ non_dominated += np.array([0, 1]).reshape(1, 2)
batch_non_dominated.append(non_dominated)
# batch_non_dominated = np.concatenate((batch_non_dominated, non_dominated))
observation = np.tile(batch_observation[i], (len(non_dominated), *([1]*len(batch_observation[0].shape))))
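A minimal, self-contained sketch of the pessimistic warm-up targets introduced in the two hunks above: before `learn_start`, the last objective of every predicted front is clamped to a worst-case value so that unexplored state-actions are not over-estimated. The helper name and the example numbers are hypothetical.

```python
import numpy as np

# Hypothetical illustration of pessimistic initialization: clamp the last
# objective (fuel) of each predicted front to a worst-case value during warm-up.
def pessimistic_targets(predicted_fronts, pessimistic_value=-1.0):
    """predicted_fronts: list of (n_points, n_objectives) arrays."""
    targets = []
    for front in predicted_fronts:
        target = front.copy()
        target[:, -1] = pessimistic_value  # worst case for the last objective
        targets.append(target)
    return targets

fronts = [np.array([[0.2, 0.5], [0.8, 0.1]])]
print(pessimistic_targets(fronts)[0])
# [[ 0.2 -1. ]
#  [ 0.8 -1. ]]
```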
@@ -389,6 +395,7 @@ class DPQN(Agent):
use_target_network=False)
for oi in range(self.nO):
r_diff = np.mean((r_pred[:, oi] - batch.reward[:, oi])**2)
+ self.writer.add_scalar(f'rew_{oi}_pred', np.mean(r_pred[:, oi]), log.total_steps)
self.writer.add_scalar('rew_{}_pred_diff'.format(oi), r_diff, log.total_steps)
return {'observation': next_obs,
@@ -412,14 +419,21 @@ class DPQN(Agent):
list(range(78, 80))+ \
list(range(88, 90))+ \
list(range(99, 100))
- # estimate pareto front for states 0 to 5
+ # estimate pareto front for all states
obs = np.array([]).reshape(0, self.env.nS)
for s in plot_states:
obs = np.concatenate((obs, np.expand_dims(self.observe(s), 0)))
q_fronts = self.q_front(obs, self.n_samples)
+ # undo pessimistic bias
+ q_fronts += np.array([0, 1]).reshape(1, 1, 1, 2)
+ # unnormalize reward
+ q_fronts = q_fronts * self.normalize_reward['scale'].reshape(1, 1, 1, 2) + self.normalize_reward['min'].reshape(1, 1, 1, 2)
- ref_point = np.array([-1, -2])
- hypervolume = compute_hypervolume(q_fronts[0], self.env.nA, ref_point)
+ try:
+     ref_point = np.array([-2, -2])
+     hypervolume = compute_hypervolume(q_fronts[0], self.env.nA, ref_point)
+ except ValueError:
+     hypervolume = 0
act = 2
fig, axes = plt.subplots(11, 10, sharex='col', sharey='row',
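The try/except above guards against fronts whose points do not dominate the reference point. As a rough, self-contained illustration (not the repository's `compute_hypervolume`), a two-objective hypervolume for maximization could be sketched as follows:

```python
import numpy as np

# Hypothetical 2-objective hypervolume (maximization): area dominated by the
# given points with respect to a reference point. Raises ValueError when no
# point strictly dominates the reference point, mirroring the guarded call above.
def hypervolume_2d(points, ref_point):
    pts = np.asarray(points, dtype=float)
    pts = pts[np.all(pts > ref_point, axis=1)]  # keep points dominating ref_point
    if len(pts) == 0:
        raise ValueError('no point dominates the reference point')
    pts = pts[np.argsort(pts[:, 0])]            # sweep along the first objective
    hv, prev_x = 0.0, ref_point[0]
    for i, (x, _) in enumerate(pts):
        best_y = pts[i:, 1].max()               # best second objective still reachable
        hv += (x - prev_x) * (best_y - ref_point[1])
        prev_x = x
    return hv

front = np.array([[0.5, -0.3], [0.9, -1.5]])    # e.g. normalized [treasure, fuel] points
ref_point = np.array([-2., -2.])
try:
    hypervolume = hypervolume_2d(front, ref_point)
except ValueError:
    hypervolume = 0.0
```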
@@ -428,13 +442,15 @@ class DPQN(Agent):
fig.subplots_adjust(wspace=0, hspace=0)
for s in range(len(plot_states)):
ax = axes[np.unravel_index(plot_states[s], (11, 10))]
- x = q_fronts[s, act, :, 0]*self.normalize_reward[0]
- y = q_fronts[s, act, :, 1]*self.normalize_reward[1]
+ x = q_fronts[s, act, :, 0]
+ y = q_fronts[s, act, :, 1]
ax.plot(x, y)
# true pareto front
true_xy = true_non_dominated[plot_states[s]][act]
- ax.plot(true_xy[:, 0]*self.normalize_reward[0], true_xy[:, 1]*self.normalize_reward[1], '+')
+ true_xy = true_xy * self.normalize_reward['scale'].reshape(1, 2) + self.normalize_reward['min'].reshape(1, 2)
+ ax.plot(true_xy[:, 0], true_xy[:, 1], '+')
for s in range(self.env.nS):
if unreachable(s):
ax = axes[np.unravel_index(s, (11, 10))]
@@ -667,10 +683,10 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='pareto dqn')
parser.add_argument('--lr-reward', default=1e-3, type=float)
- parser.add_argument('--lr-pareto', default=1e-3, type=float)
+ parser.add_argument('--lr-pareto', default=3e-4, type=float)
parser.add_argument('--batch-size', default=32, type=int)
- parser.add_argument('--copy-reward', default=2000, type=int)
- parser.add_argument('--copy-pareto', default=2000, type=int)
+ parser.add_argument('--copy-reward', default=1, type=int)
+ parser.add_argument('--copy-pareto', default=100, type=int)
parser.add_argument('--mem-size', default=250000, type=int)
parser.add_argument('--normalize', action='store_true')
parser.add_argument('--epsilon-decrease', default=0.999, type=float)
@@ -700,13 +716,13 @@ if __name__ == '__main__':
par_est = Estimator(par_model, lr=args.lr_pareto, copy_every=args.copy_pareto)
memory = Memory((env.nS,), size=args.mem_size, nO=nO)
- ref_point = np.array([-1, -2])
- normalize = np.array([124., 19.]) if args.normalize else None
+ ref_point = np.array([-2, -2])
+ normalize = {'min': np.array([0,0]), 'scale': np.array([124, 19])} if args.normalize else None
epsilon_decrease = args.epsilon_decrease
true_non_dominated = dst_non_dominated(env, normalize)
- agent = DPQN(env, policy=lambda s, q, e: action_selection(s, q, e, ref_point),
+ agent = PDQN(env, policy=lambda s, q, e: action_selection(s, q, e, ref_point),
memory=memory,
observe=lambda s: one_hot(env, s),
estimate_reward=rew_est,
@@ -718,7 +734,7 @@
gamma=1.,
n_samples=args.n_samples)
- logdir = 'runs/dpqn/lr_reward_{:.2E}/copy_reward_{}/lr_pareto_{:.2E}/copy_pareto_{}/epsilon_dec_{}/samples_{}/all_actions'.format(
+ logdir = 'runs/pdqn/lr_reward_{:.2E}/copy_reward_{}/lr_pareto_{:.2E}/copy_pareto_{}/epsilon_dec_{}/samples_{}/all_actions'.format(
args.lr_reward, args.copy_reward, args.lr_pareto, args.copy_pareto, args.epsilon_decrease, args.n_samples
)
# evaluate_agent(agent, env, logdir, true_non_dominated)
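`action_selection` itself is not part of this diff; purely as an illustration of how the `ref_point` passed into the policy is typically used, a hypervolume-based ε-greedy selection in the spirit of Pareto Q-learning could look like the sketch below. All names are hypothetical, and it assumes a 2-objective hypervolume helper such as the `hypervolume_2d` from the earlier sketch is in scope.

```python
import numpy as np

# Hypothetical hypervolume-based epsilon-greedy policy: explore with probability
# epsilon, otherwise pick the action whose estimated Pareto front covers the
# largest hypervolume with respect to the reference point.
def action_selection(state, q_fronts, epsilon, ref_point):
    """q_fronts: (n_actions, n_points, n_objectives) estimated fronts for `state`."""
    n_actions = q_fronts.shape[0]          # `state` kept to match the policy signature above
    if np.random.rand() < epsilon:
        return np.random.randint(n_actions)
    hypervolumes = np.zeros(n_actions)
    for a in range(n_actions):
        try:
            hypervolumes[a] = hypervolume_2d(q_fronts[a], ref_point)
        except ValueError:
            hypervolumes[a] = 0.0          # front does not dominate the reference point
    return int(np.argmax(hypervolumes))
```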