Mathieu Reymond / deep-sea-treasure / Commits / 059fa3ce

Commit 059fa3ce, authored Jul 16, 2019 by Mathieu Reymond

per weights bugfix, simplified experience_replay

parent 03670bd8
Changes 1
pdqn.py @ 059fa3ce
...
...
@@ -6,17 +6,30 @@ from sum_tree import SumTree
 import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
-from collections import namedtuple
 import copy
 from pathlib import Path
+from dataclasses import dataclass, astuple
+import random

 plt.switch_backend('agg')

-Transition = namedtuple('Transition', ['observation', 'action', 'reward', 'next_observation', 'terminal'])
+@dataclass
+class Transition(object):
+    observation: np.ndarray
+    action: int
+    reward: float
+    next_observation: np.ndarray
+    terminal: bool
+
+
+@dataclass
+class BatchTransition(object):
+    observation: np.ndarray
+    action: np.ndarray
+    reward: np.ndarray
+    next_observation: np.ndarray
+    terminal: np.ndarray


 def unreachable(s):
     y, x = np.unravel_index(s, (11, 10))
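Note (not part of the commit): the namedtuple Transition is replaced by @dataclass equivalents, and the new buffer stores each transition as np.array(astuple(transition)). A minimal standalone sketch of how the dataclass and astuple() fit together; the values below are illustrative only.

from dataclasses import dataclass, astuple
import numpy as np

@dataclass
class Transition:
    observation: np.ndarray
    action: int
    reward: float
    next_observation: np.ndarray
    terminal: bool

# astuple() flattens a dataclass instance into an ordinary tuple of its field values,
# which is what the new Memory.add() relies on when it stores np.array(astuple(transition)).
t = Transition(np.zeros(2), 1, 0.5, np.ones(2), False)
fields = astuple(t)
print(fields[1], fields[4])  # -> 1 False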
...
...
@@ -59,62 +72,21 @@ def dst_non_dominated(env, normalize):

 class Memory(object):

-    def __init__(self, observation_shape, observation_type='float16', size=1000000, nO=1):
-        self.current = 0
-        # we will only save next_states,
-        # as current state is simply the previous next state.
-        # We thus need an extra slot to prevent overlap between the first and
-        # last sample
-        size += 1
-        self.size = size
-        self.actions = np.empty((size,), dtype='uint8')
-        if observation_shape == (1,):
-            self.next_observations = np.empty((size,), dtype=observation_type)
-        else:
-            self.next_observations = np.empty((size,) + observation_shape, dtype=observation_type)
-        self.rewards = np.empty((size, nO), dtype='float16')
-        self.terminals = np.empty((size,), dtype=bool)
+    def __init__(self, size=100000):
+        self.size = size
+        self.memory = []
+        self.current = 0

     def add(self, transition):
-        # first sample, need to save current state
-        if self.current == 0:
-            self.next_observations[0] = transition.observation
-            self.current += 1
-        current = self.current % self.size
-        self.actions[current] = transition.action
-        self.next_observations[current] = transition.next_observation
-        self.rewards[current] = transition.reward
-        self.terminals[current] = transition.terminal
+        if len(self.memory) < self.size:
+            self.memory.append(None)
+        self.memory[self.current] = np.array(astuple(transition))
+        self.current = (self.current + 1) % self.size

     def sample(self, batch_size):
-        assert self.current > 0, 'need at least one sample in memory'
-        high = self.current % self.size
-        # did not fill memory
-        if self.current < self.size:
-            # start at 1, as 0 contains only current state
-            low = 1
-        else:
-            # do not include oldest sample, as it's state (situated in previous sample)
-            # has been overwritten by newest sample
-            low = high - self.size + 2
-        indexes = np.empty((batch_size,), dtype='int32')
-        i = 0
-        while i < batch_size:
-            # include high
-            s = np.random.randint(low, high + 1)
-            # cannot include first step of episode, as it does not have a previous state
-            if not self.terminals[s - 1]:
-                indexes[i] = s
-                i += 1
-        batch = Transition(self.next_observations[indexes - 1],
-                           self.actions[indexes],
-                           self.rewards[indexes],
-                           self.next_observations[indexes],
-                           self.terminals[indexes])
+        batch = random.sample(self.memory, batch_size)
+        batch = BatchTransition(*[np.array(i) for i in zip(*batch)])
        return batch
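Note (not part of the commit): the simplified sample() draws rows with random.sample and regroups them per field via zip(*batch). A small self-contained sketch of that collation step; the literal rows are made up for illustration.

import numpy as np

# each stored row is (observation, action, reward, next_observation, terminal)
rows = [
    (0, 1, -1.0, 2, False),
    (2, 3,  0.5, 4, True),
]
# zip(*rows) turns the list of rows into per-field columns, and np.array stacks each column,
# which is exactly what BatchTransition(*[np.array(i) for i in zip(*batch)]) receives.
observations, actions, rewards, next_observations, terminals = [np.array(col) for col in zip(*rows)]
print(actions)    # -> [1 3]
print(terminals)  # -> [False  True]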
...
...
@@ -136,44 +108,37 @@ class PrioritizedMemory(Memory):
         self.last_sampled = None

     def add(self, transition):
-        super(PrioritizedMemory, self).add(transition)
         # new items are added with max priority, initially 1
-        if self.current == 1:
+        if self.current == 0:
             p = 1
         else:
             _, p, _ = self.tree.get(self.tree.total())
-        self.tree.add(p, int(self.current % self.size))
+        self.tree.add(p, int(self.current))
+        super(PrioritizedMemory, self).add(transition)

     def importance_sampling(self):
         # last sampled contains tree-indexes, get corresponding priorities
-        priorities = self.tree.tree[self.last_sampled]
+        priorities = self.tree.tree[self.last_sampled] + 1e-8
         w = (self.tree.total() / (self.tree.n_entries * priorities)) ** self.beta
         # shift weights to avoid majority of 0's
         # w += 1
         # normalize w
-        w = w / np.max(w)
+        w = w / (np.max(w) + 1e-8)
         assert np.all(w >= 0), f'negative normalized weights\n{priorities}\n{w}'
         return w

     def sample(self, batch_size):
         buckets = np.linspace(0, self.tree.total(), batch_size + 1)
-        indexes = []
+        batch = []
         self.last_sampled = []
         for i in range(batch_size):
             sampled_priority = np.random.uniform(buckets[i], buckets[i + 1])
             tree_idx, _, trans_idx = self.tree.get(sampled_priority)
-            # only add transition if not first of episode, as it does not have a previous state
-            if not self.terminals[trans_idx - 1]:
-                indexes.append(trans_idx)
-                self.last_sampled.append(tree_idx)
-        indexes = np.array(indexes)
-        self.last_sampled = np.array(self.last_sampled)
-        batch = Transition(self.next_observations[indexes - 1],
-                           self.actions[indexes],
-                           self.rewards[indexes],
-                           self.next_observations[indexes],
-                           self.terminals[indexes])
+            batch.append(self.memory[trans_idx])
+            self.last_sampled.append(tree_idx)
+        batch = BatchTransition(*[np.array(i) for i in zip(*batch)])
         # beta annealing after every sampling step
         self.beta += self.beta_annealing
         return batch
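Note (not part of the commit): importance_sampling() computes the standard prioritized-replay weights w_i = (1 / (N * P(i)))^beta, and this change adds 1e-8 guards against zero priorities and zero maxima. A standalone numeric sketch of that computation, with plain arrays standing in for the sum tree; the values are illustrative.

import numpy as np

priorities = np.array([0.0, 0.5, 2.0]) + 1e-8  # leaf priorities; +1e-8 avoids division by zero
total = priorities.sum()                       # stands in for tree.total()
n_entries = len(priorities)                    # stands in for tree.n_entries
beta = 0.4

# w_i = (1 / (N * P(i)))**beta with P(i) = priorities[i] / total
w = (total / (n_entries * priorities)) ** beta
w = w / (np.max(w) + 1e-8)                     # normalize so the largest weight is ~1
print(w)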
...
...
@@ -223,7 +188,7 @@ class Estimator(object):
l
=
self
.
loss
(
preds
,
torch
.
from_numpy
(
targets
).
to
(
self
.
device
))
l_report
=
l
.
detach
().
cpu
().
numpy
()
if
weights
is
not
None
:
weights
=
torch
.
from_numpy
(
targe
ts
).
to
(
self
.
device
)
weights
=
torch
.
from_numpy
(
weigh
ts
).
to
(
self
.
device
)
.
float
().
unsqueeze
(
1
)
l
=
l
*
weights
if
self
.
clamp
is
not
None
:
l
=
torch
.
clamp
(
l
,
min
=-
self
.
clamp
,
max
=
self
.
clamp
)
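Note (not part of the commit): the "per weights" fix converts the sampled importance weights, rather than the targets, to a tensor before scaling the per-sample loss. A rough sketch of the intended effect, under the assumption that self.loss is an element-wise loss such as nn.MSELoss(reduction='none'):

import torch
import torch.nn as nn

loss_fn = nn.MSELoss(reduction='none')   # element-wise loss, one value per sample
preds = torch.tensor([[1.0], [2.0], [3.0]])
targets = torch.tensor([[1.5], [2.0], [0.0]])
weights = torch.tensor([0.2, 1.0, 0.7])  # importance-sampling weights, one per sample

l = loss_fn(preds, targets)              # shape (3, 1)
l = l * weights.float().unsqueeze(1)     # broadcast the weights over the loss, as in the fixed line
print(l.mean())                          # reduce afterwards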
...
...
@@ -394,7 +359,7 @@ class PDQN(Agent):
next_observation
=
next_obs
,
terminal
=
terminal
)
self
.
memory
.
add
(
t
)
if
log
.
total_steps
>
=
self
.
batch_size
:
# self.learn_start:
if
log
.
total_steps
>
self
.
batch_size
:
# self.learn_start:
batch
=
self
.
memory
.
sample
(
self
.
batch_size
)
# normalize reward for pareto_estimator
...
...
@@ -826,9 +791,9 @@ if __name__ == '__main__':
     # rew_est = DSTReward(env)
     if not args.per:
-        memory = Memory((env.nS,), size=args.mem_size, nO=nO)
+        memory = Memory(size=args.mem_size)
     else:
-        memory = PrioritizedMemory((env.nS,), n_steps=1e5, size=args.mem_size, nO=nO)
+        memory = PrioritizedMemory(n_steps=1e5, size=args.mem_size)
     ref_point = np.array([-2, -2])
     normalize = {'min': np.array([0, 0]), 'scale': np.array([124, 19])} if args.normalize else None
     epsilon_decrease = args.epsilon_decrease
...
...
@@ -847,7 +812,7 @@ if __name__ == '__main__':
                  gamma=1.,
                  n_samples=args.n_samples)

-    logdir = 'runs/pdqn/per_{}/lr_reward_{:.2E}/copy_reward_{}/lr_pareto_{:.2E}/copy_pareto_{}/epsilon_dec_{}/samples_{}/'.format(
+    logdir = '/tmp/runs/pdqn/per_{}/lr_reward_{:.2E}/copy_reward_{}/lr_pareto_{:.2E}/copy_pareto_{}/epsilon_dec_{}/samples_{}/'.format(
         int(args.per), args.lr_reward, args.copy_reward, args.lr_pareto,
         args.copy_pareto, args.epsilon_decrease, args.n_samples)
     # evaluate_agent(agent, env, logdir, true_non_dominated)
...
...