
Commit 059fa3ce authored by Mathieu Reymond

per weights bugfix, simplified experience_replay

parent 03670bd8
@@ -6,17 +6,30 @@ from sum_tree import SumTree
 import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
-from collections import namedtuple
 import copy
 from pathlib import Path
+from dataclasses import dataclass, astuple
+import random
 plt.switch_backend('agg')
-Transition = namedtuple('Transition',
-                        ['observation',
-                         'action',
-                         'reward',
-                         'next_observation',
-                         'terminal'])
+@dataclass
+class Transition(object):
+    observation: np.ndarray
+    action: int
+    reward: float
+    next_observation: np.ndarray
+    terminal: bool
+
+
+@dataclass
+class BatchTransition(object):
+    observation: np.ndarray
+    action: np.ndarray
+    reward: np.ndarray
+    next_observation: np.ndarray
+    terminal: np.ndarray
 def unreachable(s):
     y, x = np.unravel_index(s, (11, 10))
@@ -59,62 +72,21 @@ def dst_non_dominated(env, normalize):
 class Memory(object):
-    def __init__(self, observation_shape, observation_type='float16', size=1000000, nO=1):
-        self.current = 0
-        # we will only save next_states,
-        # as current state is simply the previous next state.
-        # We thus need an extra slot to prevent overlap between the first and
-        # last sample
-        size += 1
-        self.size = size
+    def __init__(self, size=100000):
-        self.actions = np.empty((size,), dtype='uint8')
-        if observation_shape == (1,):
-            self.next_observations = np.empty((size,), dtype=observation_type)
-        else:
-            self.next_observations = np.empty((size,) + observation_shape, dtype=observation_type)
-        self.rewards = np.empty((size, nO), dtype='float16')
-        self.terminals = np.empty((size,), dtype=bool)
+        self.size = size
+        self.memory = []
+        self.current = 0

     def add(self, transition):
-        # first sample, need to save current state
-        if self.current == 0:
-            self.next_observations[0] = transition.observation
-            self.current += 1
-        current = self.current % self.size
-        self.actions[current] = transition.action
-        self.next_observations[current] = transition.next_observation
-        self.rewards[current] = transition.reward
-        self.terminals[current] = transition.terminal
+        if len(self.memory) < self.size:
+            self.memory.append(None)
+        self.memory[self.current] = np.array(astuple(transition))
+        self.current = (self.current + 1) % self.size

     def sample(self, batch_size):
-        assert self.current > 0, 'need at least one sample in memory'
-        high = self.current % self.size
-        # did not fill memory
-        if self.current < self.size:
-            # start at 1, as 0 contains only current state
-            low = 1
-        else:
-            # do not include oldest sample, as it's state (situated in previous sample)
-            # has been overwritten by newest sample
-            low = high - self.size + 2
-        indexes = np.empty((batch_size,), dtype='int32')
-        i = 0
-        while i < batch_size:
-            # include high
-            s = np.random.randint(low, high+1)
-            # cannot include first step of episode, as it does not have a previous state
-            if not self.terminals[s-1]:
-                indexes[i] = s
-                i += 1
-        batch = Transition(
-            self.next_observations[indexes-1],
-            self.actions[indexes],
-            self.rewards[indexes],
-            self.next_observations[indexes],
-            self.terminals[indexes]
-        )
+        batch = random.sample(self.memory, batch_size)
+        batch = BatchTransition(*[np.array(i) for i in zip(*batch)])
         return batch
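For orientation: the hunk above replaces the packed per-field arrays, which stored only next_observations and reconstructed each previous state from the preceding slot, with a plain Python list of whole transitions, overwritten ring-buffer style and sampled uniformly with random.sample. Below is a self-contained sketch of that pattern; ReplayBuffer and its field stacking are illustrative names and choices, not the repository's exact API.

# Sketch of the simplified uniform replay pattern: a list with
# ring-buffer overwrite plus random.sample (illustrative names only).
import random
from dataclasses import dataclass
import numpy as np

@dataclass
class Transition:
    observation: np.ndarray
    action: int
    reward: float
    next_observation: np.ndarray
    terminal: bool

class ReplayBuffer:
    def __init__(self, size=100000):
        self.size = size
        self.memory = []
        self.current = 0

    def add(self, transition):
        # grow until full, then overwrite the oldest slot
        if len(self.memory) < self.size:
            self.memory.append(None)
        self.memory[self.current] = transition
        self.current = (self.current + 1) % self.size

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        # stack each field into one array per attribute
        return {
            'observation': np.stack([t.observation for t in batch]),
            'action': np.array([t.action for t in batch]),
            'reward': np.array([t.reward for t in batch]),
            'next_observation': np.stack([t.next_observation for t in batch]),
            'terminal': np.array([t.terminal for t in batch]),
        }

The old layout traded code complexity for memory savings (one observation stored per step); the simplified version stores full transitions and leaves deduplication to the caller.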
@@ -136,44 +108,37 @@ class PrioritizedMemory(Memory):
         self.last_sampled = None

     def add(self, transition):
-        super(PrioritizedMemory, self).add(transition)
         # new items are added with max priority, initially 1
-        if self.current == 1:
+        if self.current == 0:
             p = 1
         else:
             _, p, _ = self.tree.get(self.tree.total())
-        self.tree.add(p, int(self.current % self.size))
+        self.tree.add(p, int(self.current))
+        super(PrioritizedMemory, self).add(transition)

     def importance_sampling(self):
         # last sampled contains tree-indexes, get corresponding priorities
-        priorities = self.tree.tree[self.last_sampled]
+        priorities = self.tree.tree[self.last_sampled] + 1e-8
         w = (self.tree.total()/(self.tree.n_entries*priorities))**self.beta
         # shift weights to avoid majority of 0's
         # w += 1
         # normalize w
-        w = w/np.max(w)
+        w = w/(np.max(w) + 1e-8)
         assert np.all(w >= 0), f'negative normalized weights \n {priorities} \n {w}'
         return w

     def sample(self, batch_size):
         buckets = np.linspace(0, self.tree.total(), batch_size+1)
-        indexes = []
+        batch = []
         self.last_sampled = []
         for i in range(batch_size):
             sampled_priority = np.random.uniform(buckets[i], buckets[i+1])
             tree_idx, _, trans_idx = self.tree.get(sampled_priority)
-            # only add transition if not first of episode, as it does not have a previous state
-            if not self.terminals[trans_idx-1]:
-                indexes.append(trans_idx)
-                self.last_sampled.append(tree_idx)
-        indexes = np.array(indexes)
-        self.last_sampled = np.array(self.last_sampled)
-        batch = Transition(
-            self.next_observations[indexes - 1],
-            self.actions[indexes],
-            self.rewards[indexes],
-            self.next_observations[indexes],
-            self.terminals[indexes]
-        )
+            batch.append(self.memory[trans_idx])
+            self.last_sampled.append(tree_idx)
+        batch = BatchTransition(*[np.array(i) for i in zip(*batch)])
         # beta annealing after every sampling step
         self.beta += self.beta_annealing
         return batch
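The two 1e-8 terms added in importance_sampling() guard the weights w = (total / (n_entries * p_i))**beta against zero-priority leaves, which would otherwise produce inf and then NaN after the division by max(w). A small numeric sketch with made-up priorities:

# Sketch of the importance-sampling weight computation above,
# showing what the 1e-8 stabilizers protect against (values invented).
import numpy as np

beta = 0.5
raw = np.array([0.0, 0.1, 0.4, 0.5])    # one zero-priority leaf
priorities = raw + 1e-8                  # as in importance_sampling()
total = raw.sum()
n_entries = len(raw)

w = (total / (n_entries * priorities)) ** beta
w = w / (np.max(w) + 1e-8)               # largest weight normalized to ~1
print(w)                                  # finite everywhere, no inf/NaN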
@@ -223,7 +188,7 @@ class Estimator(object):
         l = self.loss(preds, torch.from_numpy(targets).to(self.device))
         l_report = l.detach().cpu().numpy()
         if weights is not None:
-            weights = torch.from_numpy(targets).to(self.device)
+            weights = torch.from_numpy(weights).to(self.device).float().unsqueeze(1)
             l = l*weights
         if self.clamp is not None:
             l = torch.clamp(l, min=-self.clamp, max=self.clamp)
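This is the "per weights" bugfix from the commit message: the sampled importance-sampling weights themselves, not the targets, are converted to a tensor, and unsqueeze(1) shapes them as one weight per batch entry so they broadcast over the element-wise loss. A minimal sketch of that weighting, with illustrative shapes rather than the repository's Estimator:

# Per-sample loss weighting: an element-wise loss multiplied by a
# (batch, 1) tensor of IS weights before reduction (shapes illustrative).
import numpy as np
import torch
import torch.nn as nn

loss_fn = nn.MSELoss(reduction='none')             # keep per-element losses
preds = torch.randn(4, 2)                           # batch of 4 predictions
targets = np.random.randn(4, 2).astype('float32')
is_weights = np.array([1.0, 0.5, 0.25, 1.0])        # one weight per sample

l = loss_fn(preds, torch.from_numpy(targets))       # shape (4, 2)
w = torch.from_numpy(is_weights).float().unsqueeze(1)   # (4,) -> (4, 1)
weighted_loss = (l * w).mean()                      # broadcast, then reduce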
@@ -394,7 +359,7 @@ class PDQN(Agent):
                            next_observation=next_obs,
                            terminal=terminal)
             self.memory.add(t)
-            if log.total_steps >= self.batch_size: # self.learn_start:
+            if log.total_steps > self.batch_size: # self.learn_start:
                 batch = self.memory.sample(self.batch_size)
                 # normalize reward for pareto_estimator
@@ -826,9 +791,9 @@ if __name__ == '__main__':
     # rew_est = DSTReward(env)
     if not args.per:
-        memory = Memory((env.nS,), size=args.mem_size, nO=nO)
+        memory = Memory(size=args.mem_size)
     else:
-        memory = PrioritizedMemory((env.nS,), n_steps=1e5, size=args.mem_size, nO=nO)
+        memory = PrioritizedMemory(n_steps=1e5, size=args.mem_size)
     ref_point = np.array([-2, -2])
     normalize = {'min': np.array([0,0]), 'scale': np.array([124, 19])} if args.normalize else None
     epsilon_decrease = args.epsilon_decrease
@@ -847,7 +812,7 @@ if __name__ == '__main__':
                  gamma=1.,
                  n_samples=args.n_samples)
-    logdir = 'runs/pdqn/per_{}/lr_reward_{:.2E}/copy_reward_{}/lr_pareto_{:.2E}/copy_pareto_{}/epsilon_dec_{}/samples_{}/'.format(
+    logdir = '/tmp/runs/pdqn/per_{}/lr_reward_{:.2E}/copy_reward_{}/lr_pareto_{:.2E}/copy_pareto_{}/epsilon_dec_{}/samples_{}/'.format(
         int(args.per), args.lr_reward, args.copy_reward, args.lr_pareto, args.copy_pareto, args.epsilon_decrease, args.n_samples
     )
     # evaluate_agent(agent, env, logdir, true_non_dominated)