Commit 6626b902 authored by Senne Deproost

Improved synthesis on SimpleActionOnly-v0

parent bbd32d0a
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy
import os
import random
import time
@@ -21,6 +21,8 @@ from optim import ProgramOptimizer
import envs
+ RES = []
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
@@ -49,7 +51,7 @@ class Args:
# Algorithm specific arguments
env_id: str = "SimpleActionOnly-v0"
"""the id of the environment"""
- total_timesteps: int = int(1e3)
+ total_timesteps: int = int(1e4)
"""total timesteps of the experiments"""
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
@@ -65,21 +67,21 @@ class Args:
"""the scale of policy noise"""
exploration_noise: float = 0.1
"""the scale of exploration noise"""
- learning_starts: int = total_timesteps / 2
+ learning_starts: int = 0.1 * total_timesteps
"""timestep to start learning"""
- policy_frequency: int = 2
+ policy_frequency: int = 3
"""the frequency of training policy (delayed)"""
noise_clip: float = 0.5
"""noise clip parameter of the Target Policy Smoothing Regularization"""
# Parameters for the program optimizer
- num_individuals: int = 5
+ num_individuals: int = 10
num_genes: int = 2
- num_generations: int = 5
- num_parents_mating: int = 3
- keep_parents: int = 2
- mutation_percent_genes: int = 50
+ num_generations: int = 10
+ num_parents_mating: int = 2
+ keep_parents: int = 1
+ mutation_percent_genes: int = 10
keep_elites: int = 1
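
For orientation, a minimal sketch of how learning_starts and policy_frequency are typically consumed in a CleanRL-style TD3 training loop (an assumption about the surrounding file, which is derived from td3_continuous_action.py; the loop body below is placeholder code, not code from this commit):

from dataclasses import dataclass

@dataclass
class LoopArgs:  # hypothetical trimmed-down stand-in for the Args fields above
    total_timesteps: int = int(1e4)
    learning_starts: int = int(0.1 * 1e4)  # 10% warm-up before any updates
    policy_frequency: int = 3              # delayed (TD3-style) policy updates

args = LoopArgs()
for global_step in range(args.total_timesteps):
    # an environment step would happen here
    if global_step > args.learning_starts:
        pass  # critic (Q-network) update on every step
        if global_step % args.policy_frequency == 0:
            pass  # delayed program/policy update and target-network update
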
@@ -115,6 +117,7 @@ class QNetwork(nn.Module):
def get_state_actions(program, obs, env, grad_required=False):
program_actions = []
+ obs = obs.detach().numpy()
for i, o in enumerate(obs):
program_actions.append(program(o, len_output=env.action_space.shape[0]))
program_actions = torch.tensor(program_actions, requires_grad=grad_required)
@@ -195,7 +198,7 @@ def run_synthesis(args: Args):
action = program(torch.Tensor(obs).to(device).detach().numpy(), len_output=env.action_space.shape[0])
# TRY NOT TO MODIFY: execute the game and log data.
- print(action)
+ print(f'Program {program} gives action {action}')
next_obs, reward, termination, truncation, info = env.step(action)
# TRY NOT TO MODIFY: record rewards for plotting purposes
@@ -234,6 +237,7 @@ def run_synthesis(args: Args):
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = qf1_loss + qf2_loss
+ #print(f'Loss critic: {qf1_loss}')
# optimize the model
q_optimizer.zero_grad()
@@ -247,10 +251,11 @@ def run_synthesis(args: Args):
program_loss = -qf1(data.observations, program_actions).mean()
#program_loss.backward()
action_gradients = grad(program_loss, program_actions)
- optimal_actions = program_actions + action_gradients[0]
+ improved_actions = program_actions - (10e-2 * action_gradients[0])
+ RES.append(improved_actions[0].detach().numpy())
program_optimizer.fit(states=data.observations.detach().numpy(),
- actions=optimal_actions.detach().numpy())
- #actions=np.array([[0.5]]))
+ actions=improved_actions.detach().numpy())
+ #actions=np.ones(shape=(args.batch_size, 1))*0.5)
# update the target network
for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
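
The sign change above is the key fix in this hunk: program_loss = -qf1(data.observations, program_actions).mean() falls as the critic's value rises, so stepping the actions against its gradient (rather than adding the raw gradient, as before) nudges them toward higher estimated value before they are passed to program_optimizer.fit as imitation targets. A minimal, self-contained sketch of that step, with a toy critic standing in for qf1:

import torch
from torch.autograd import grad

def toy_critic(obs, act):
    # Stand-in for qf1(obs, act); it peaks at act = 0.5.
    return -((act - 0.5) ** 2).sum(dim=1, keepdim=True)

obs = torch.randn(4, 3)                                  # hypothetical observation batch
program_actions = torch.zeros(4, 1, requires_grad=True)  # actions proposed by the program

program_loss = -toy_critic(obs, program_actions).mean()  # lower loss <=> higher Q
action_gradients = grad(program_loss, program_actions)
improved_actions = program_actions - 10e-2 * action_gradients[0]  # one descent step on the actions
print(improved_actions.squeeze())  # values move from 0.0 toward 0.5
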
@@ -270,6 +275,10 @@ def run_synthesis(args: Args):
env.close()
writer.close()
+ import matplotlib.pyplot as plt
+ plt.plot(RES)
+ plt.show()
if __name__ == "__main__":
run_synthesis()
@@ -67,8 +67,8 @@ class ProgramOptimizer:
crossover_type="single_point",
mutation_type="random",
mutation_percent_genes=self.config.mutation_percent_genes,
- random_mutation_max_val=10,
- random_mutation_min_val=-10,
+ random_mutation_max_val=5,
+ random_mutation_min_val=-5,
gene_space={
'low': -NUM_OPERATORS - N_INPUT_VARIABLES,
'high': 1.0
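
These keyword arguments are standard pygad.GA constructor parameters. A self-contained sketch using the same crossover and mutation settings, assuming PyGAD 3.x (where the fitness function receives the GA instance as its first argument); the fitness function and the numeric gene_space bounds are toy stand-ins for the repository's operator/variable encoding:

import numpy as np
import pygad

def fitness_func(ga_instance, solution, solution_idx):
    # Toy fitness: prefer gene vectors close to 0.5 (stand-in for evaluating a synthesized program).
    return -float(np.sum((solution - 0.5) ** 2))

ga = pygad.GA(
    num_generations=10,
    num_parents_mating=2,
    keep_parents=1,
    sol_per_pop=10,
    num_genes=2,
    fitness_func=fitness_func,
    crossover_type="single_point",
    mutation_type="random",
    mutation_percent_genes=10,
    random_mutation_max_val=5,
    random_mutation_min_val=-5,
    gene_space={'low': -4.0, 'high': 1.0},  # placeholder bounds, not the real encoding
)
ga.run()
print(ga.best_solution()[0])  # best gene vector found
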
@@ -85,6 +85,7 @@ class ProgramOptimizer:
# Print the best individual
#program = self.get_best_program()
#print(program(states[0], do_print=True))
+ #self.ga_instance.plot_fitness()
@dataclass
@@ -126,7 +127,6 @@ def main(config: Config):
# Plot
optim.ga_instance.plot_fitness()
print('done')
if __name__ == '__main__':