Skip to content
Snippets Groups Projects
Commit a5724f80 authored by Denis Steckelmacher's avatar Denis Steckelmacher
Browse files

Stochastic program evaluation with mean and log_std of the genes

parent 6626b902
No related branches found
No related tags found
No related merge requests found
......@@ -61,13 +61,13 @@ class Args:
"""the discount factor gamma"""
tau: float = 0.005
"""target smoothing coefficient (default: 0.005)"""
batch_size: int = 1
batch_size: int = 32
"""the batch size of sample from the reply memory"""
policy_noise: float = 0.2
"""the scale of policy noise"""
exploration_noise: float = 0.1
"""the scale of exploration noise"""
learning_starts: int = 0.1 * total_timesteps
learning_starts: int = 256
"""timestep to start learning"""
policy_frequency: int = 3
"""the frequency of training policy (delayed)"""
......@@ -75,14 +75,14 @@ class Args:
"""noise clip parameter of the Target Policy Smoothing Regularization"""
# Parameters for the program optimizer
num_individuals: int = 10
num_genes: int = 2
num_individuals: int = 100
num_genes: int = 5
num_eval_runs: int = 10
num_generations: int = 10
num_parents_mating: int = 2
keep_parents: int = 1
num_generations: int = 20
num_parents_mating: int = 50
keep_parents: int = 5
mutation_percent_genes: int = 10
keep_elites: int = 1
def make_env(env_id, seed, idx, capture_video, run_name):
......@@ -115,14 +115,19 @@ class QNetwork(nn.Module):
return x
def get_state_actions(program, obs, env, grad_required=False):
def get_state_actions(program, obs, env, args, grad_required=False):
program_actions = []
obs = obs.detach().numpy()
for i, o in enumerate(obs):
program_actions.append(program(o, len_output=env.action_space.shape[0]))
action = np.zeros(env.action_space.shape, dtype=np.float32)
for eval_run in range(args.num_eval_runs):
action += program(o, len_output=env.action_space.shape[0])
program_actions.append(action / args.num_eval_runs)
program_actions = torch.tensor(program_actions, requires_grad=grad_required)
shp = (len(obs), 1)
program_actions.reshape(shp)
return program_actions
......@@ -222,7 +227,7 @@ def run_synthesis(args: Args):
)
# Go over all observations the buffer provides
next_state_actions = get_state_actions(program, data.next_observations, env)
next_state_actions = get_state_actions(program, data.next_observations, env, args)
next_state_actions = (next_state_actions + clipped_noise).clamp(
env.action_space.low[0], env.action_space.high[0]).float()
......@@ -246,12 +251,14 @@ def run_synthesis(args: Args):
# Optimize the program
if global_step % args.policy_frequency == 0:
program_actions = get_state_actions(program, data.observations, env, grad_required=True).float()
program_actions = get_state_actions(program, data.observations, env, args, grad_required=True)
program_objective = qf1(data.observations, program_actions).mean()
program_objective.backward()
improved_actions = program_actions + 0.1 * program_actions.grad
print(program_actions, improved_actions)
program_loss = -qf1(data.observations, program_actions).mean()
#program_loss.backward()
action_gradients = grad(program_loss, program_actions)
improved_actions = program_actions - (10e-2 * action_gradients[0])
RES.append(improved_actions[0].detach().numpy())
program_optimizer.fit(states=data.observations.detach().numpy(),
actions=improved_actions.detach().numpy())
......@@ -269,7 +276,7 @@ def run_synthesis(args: Args):
writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
writer.add_scalar("losses/programm_loss", program_loss.item(), global_step)
writer.add_scalar("losses/program_objective", program_objective.item(), global_step)
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
env.close()
......
......@@ -7,13 +7,11 @@ from dataclasses import dataclass
from postfix_program import Program, NUM_OPERATORS
N_INPUT_VARIABLES = 1
class ProgramOptimizer:
def __init__(self, config):
# Create the initial population
self.initial_program = [-1.0] * config.num_genes
self.initial_program = [-1.0] * (config.num_genes * 2) # Mean and log_std for each gene
self.best_solution = self.initial_program
self.best_fitness = None
......@@ -21,8 +19,6 @@ class ProgramOptimizer:
self.config = config
self.initial_population = [np.array(self.initial_program) for i in range(config.num_individuals)]
self.f = None
def get_best_program(self):
return Program(genome=self.best_solution)
......@@ -38,55 +34,46 @@ class ProgramOptimizer:
program = Program(genome=solution)
for index in range(batch_size):
action = program(states[index], len_output=action_size)
desired_action = actions[index]
# Evaluate the program several times, because evaluations are stochastic
for eval_run in range(self.config.num_eval_runs):
for index in range(batch_size):
action = program(states[index], len_output=action_size)
desired_action = actions[index]
sum_error += np.mean((action - desired_action) ** 2)
sum_error += np.mean((action - desired_action) ** 2)
fitness = -(sum_error / batch_size)
fitness = -(sum_error / (batch_size + self.config.num_eval_runs))
# !!!!
if self.best_fitness is None or fitness > self.best_fitness:
self.best_solution = solution
self.best_fitness = fitness
#print('F', fitness, file=sys.stderr)
return fitness
self.ga_instance = pygad.GA(num_generations=self.config.num_generations,
#parallel_processing=8,
save_solutions=True,
save_best_solutions=True,
num_parents_mating=self.config.num_parents_mating,
fitness_func=fitness_func,
initial_population=self.initial_population,
parent_selection_type="sss",
keep_parents=self.config.keep_parents,
crossover_type="single_point",
mutation_type="random",
mutation_percent_genes=self.config.mutation_percent_genes,
random_mutation_max_val=5,
random_mutation_min_val=-5,
gene_space={
'low': -NUM_OPERATORS - N_INPUT_VARIABLES,
'high': 1.0
},
keep_elitism=1,
)
self.ga_instance = pygad.GA(
fitness_func=fitness_func,
initial_population=self.initial_population,
num_generations=self.config.num_generations,
num_parents_mating=self.config.num_parents_mating,
keep_parents=self.config.keep_parents,
mutation_percent_genes=self.config.mutation_percent_genes,
# Work with non-deterministic objective functions
keep_elitism=0,
save_solutions=False,
save_best_solutions=False,
parent_selection_type="sss",
crossover_type="single_point",
mutation_type="random",
random_mutation_max_val=10,
random_mutation_min_val=-10,
)
self.ga_instance.run()
# Allow the population to survive
self.initial_population = self.ga_instance.population
self.f = self.ga_instance.population
# Print the best individual
#program = self.get_best_program()
#print(program(states[0], do_print=True))
#self.ga_instance.plot_fitness()
@dataclass
class Config:
......
......@@ -5,15 +5,15 @@
# Input variables # negative, we can have many of them
# <end> # OPERATOR_END
#
# 1. PyGAD produces numpy arrays (lists of floats), turn them into <see above>
# 1. PyGAD produces numpy arrays (lists of floats). Look at them in pairs of (mean, variance).
# sample a value from that normal distribution, and transform the sample to one
# of the tokens listed above
# 2. Run that
#
# Format: genes are floats. We
import math
import numpy as np
import torch as th
class Operator:
def __init__(self, name, num_operands, function):
self.name = name
......@@ -50,36 +50,36 @@ NUM_OPERATORS = len(OPERATORS)
class Program:
def __init__(self, genome=None, size=None):
self.size = size
if genome is not None:
self.genome = genome
self.size = len(genome)
else:
assert size is not None, "If genome is not specified, size must be given"
self.genome = np.ones(size)
def __init__(self, genome):
self.genome = genome
def __str__(self):
return f'{self.run_program(inp=[1], do_print=True)}'
def __call__(self, inp, len_output=None, do_print=False):
return repr(self.run_program(inp=[1], do_print=True))
res = self.run_program(inp, do_print=do_print)
def __call__(self, inp, len_output=None):
res = self.run_program(inp, do_print=False)
# If the desired output length is given, pad the result with zeroes if needed
if len_output:
res = np.array(res + [0.0] * len_output)
res = np.array(res + [0.0] * len_output, dtype=np.float32)
res = res[:len_output]
if do_print:
return res
else:
return np.array(res)
return res
def run_program(self, inp, do_print=False):
stack = []
for value in self.genome:
for pointer in range(0, len(self.genome), 2):
# Sample the actual token to execute
mean = self.genome[pointer + 0]
log_std = self.genome[pointer + 1]
if log_std > 10.0:
log_std = 10.0 # Prevent exp() from overflowing
value = np.random.normal(loc=mean, scale=math.exp(log_std))
# Execute the token
if value >= 0.0:
# Literal, push it
if do_print:
......@@ -139,7 +139,5 @@ class Program:
if __name__ == '__main__':
print(Program([2.0, -21.0, -6.0, -1.0, -1.0])([3.14, 6.28]))
print(Program([-21, -7.0, -6.0, -22.0, 0.0, 0.0, -1.0, -1.0])([1, 8]))
print(Program([5.0, -21.0, -6.0, -1.0, -1.0])([3.14, 6.28], do_print=True))
print(Program([-17.0])([0.0], len_output=1, do_print=False))
print(Program([5.0, 1.0, -21.0, -2.0]).run_program([3.14, 6.28], do_print=True))
print(Program([-17.0, 0.0]).run_program([0.0], do_print=False))
absl-py==2.1.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
contourpy==1.2.1
cycler==0.12.1
decorator==4.4.2
docker-pycreds==0.4.0
docstring_parser==0.16
etils==1.7.0
Farama-Notifications==0.0.4
filelock==3.15.4
fonttools==4.53.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
glfw==2.7.0
grpcio==1.64.1
gymnasium==0.29.1
idna==3.7
imageio==2.34.2
imageio-ffmpeg==0.5.1
importlib_resources==6.4.0
Jinja2==3.1.4
kiwisolver==1.4.5
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1
mdurl==0.1.2
moviepy==1.0.3
mpmath==1.3.0
mujoco==3.1.6
mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
packaging==24.1
pandas==2.2.2
pillow==10.4.0
platformdirs==4.2.2
proglog==0.1.10
protobuf==4.25.3
psutil==6.0.0
pygad==3.3.1
pygame==2.6.0
Pygments==2.18.0
PyOpenGL==3.1.7
pyparsing==3.1.2
pyrallis==0.3.1
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
requests==2.32.3
rich==13.7.1
sentry-sdk==2.9.0
setproctitle==1.3.3
shtab==1.7.1
six==1.16.0
smmap==5.0.1
stable_baselines3==2.3.2
sympy==1.13.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
torch==2.3.1
tqdm==4.66.4
typing-inspect==0.9.0
typing_extensions==4.12.2
tyro==0.8.5
tzdata==2024.1
urllib3==2.2.2
wandb==0.17.4
Werkzeug==3.0.3
zipp==3.19.2
torch
stable_baselines3
tensorboard
pyrallis
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment