Commit 3d659075 authored by Denis Steckelmacher

ProgramOptimizer that fits entire (large) actions

parent d4afcee5
@@ -49,7 +49,7 @@ class Args:
     """the user or org name of the model repository from the Hugging Face Hub"""
 
     # Algorithm specific arguments
-    env_id: str = "SimpleActionOnly-v0"
+    env_id: str = "SimpleLargeAction-v0"
     """the id of the environment"""
     total_timesteps: int = int(1e4)
     """total timesteps of the experiments"""
@@ -69,14 +69,14 @@ class Args:
     """the scale of exploration noise"""
     learning_starts: int = 256
     """timestep to start learning"""
-    policy_frequency: int = 3
+    policy_frequency: int = 100
     """the frequency of training policy (delayed)"""
     noise_clip: float = 0.5
     """noise clip parameter of the Target Policy Smoothing Regularization"""
 
     # Parameters for the program optimizer
     num_individuals: int = 100
-    num_genes: int = 5
+    num_genes: int = 4
    num_eval_runs: int = 10
    num_generations: int = 20
@@ -115,19 +115,22 @@ class QNetwork(nn.Module):
         return x
 
 
-def get_state_actions(program, obs, env, args, grad_required=False):
+def get_state_actions(program_optimizer, obs, env, args, grad_required=False):
     program_actions = []
     obs = obs.detach().numpy()
 
     for i, o in enumerate(obs):
         action = np.zeros(env.action_space.shape, dtype=np.float32)
 
-        for eval_run in range(args.num_eval_runs):
-            action += program(o, len_output=env.action_space.shape[0])
+        for eval_run in range(1):
+            action += program_optimizer.get_actions_from_solution(
+                program_optimizer.best_solution,
+                o
+            )
 
         program_actions.append(action / args.num_eval_runs)
 
-    program_actions = torch.tensor(program_actions, requires_grad=grad_required)
+    program_actions = torch.tensor(np.array(program_actions), requires_grad=grad_required)
     return program_actions
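
Illustration of the data flow in the new get_state_actions: each observation in the batch is mapped to one action vector by the optimizer's current best solution, and the per-observation vectors are stacked into a single tensor. This is a minimal, self-contained sketch; DummyOptimizer only stands in for ProgramOptimizer and is not part of the commit.

import numpy as np
import torch

class DummyOptimizer:
    """Stand-in for ProgramOptimizer: returns one value per action dimension."""
    best_solution = None

    def get_actions_from_solution(self, solution, state):
        return np.array([state[0], -state[0]], dtype=np.float32)

def get_state_actions_sketch(optimizer, obs_batch, grad_required=False):
    actions = [optimizer.get_actions_from_solution(optimizer.best_solution, o)
               for o in obs_batch]
    # Converting to one ndarray first avoids the slow list-of-arrays-to-tensor path
    return torch.tensor(np.array(actions), requires_grad=grad_required)

obs_batch = np.random.randn(4, 3).astype(np.float32)                 # (batch, state_dim)
print(get_state_actions_sketch(DummyOptimizer(), obs_batch).shape)   # torch.Size([4, 2])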
@@ -165,7 +168,7 @@ def run_synthesis(args: Args):
     assert isinstance(env.action_space, gym.spaces.Box), "only continuous action space is supported"
 
     # Actor is a learnable program
-    program_optimizer = ProgramOptimizer(args)
+    program_optimizer = ProgramOptimizer(args, env.action_space.shape)
 
     qf1 = QNetwork(env).to(device)
     qf2 = QNetwork(env).to(device)
@@ -187,23 +190,20 @@ def run_synthesis(args: Args):
     # TRY NOT TO MODIFY: start the game
     obs, _ = env.reset(seed=args.seed)
 
-    for global_step in range(args.total_timesteps):
-        # Get best program from optimizer
-        program = program_optimizer.get_best_program()
-        fitness = program_optimizer.best_fitness
-
-        # Print program
-        print(f'Best program: {program}, with fitness {fitness}')
-
+    for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
         if global_step < args.learning_starts:
             action = env.action_space.sample()
         else:
             with torch.no_grad():
-                action = program(torch.Tensor(obs).to(device).detach().numpy(), len_output=env.action_space.shape[0])
+                action = program_optimizer.get_actions_from_solution(
+                    program_optimizer.best_solution,
+                    obs
+                )
 
         # TRY NOT TO MODIFY: execute the game and log data.
-        print(f'Program {program} gives action {action}')
         next_obs, reward, termination, truncation, info = env.step(action)
 
         # TRY NOT TO MODIFY: record rewards for plotting purposes
@@ -227,7 +227,7 @@ def run_synthesis(args: Args):
                 )
 
                 # Go over all observations the buffer provides
-                next_state_actions = get_state_actions(program, data.next_observations, env, args)
+                next_state_actions = get_state_actions(program_optimizer, data.next_observations, env, args)
                 next_state_actions = (next_state_actions + clipped_noise).clamp(
                     env.action_space.low[0], env.action_space.high[0]).float()
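
The clamp above is the usual TD3 target-policy smoothing: clipped Gaussian noise is added to the next-state actions and the result is pulled back into the action bounds. A small stand-alone illustration; the bounds and noise scales here are made up, not taken from the environment.

import torch

action_low, action_high = -1.0, 1.0      # illustrative Box bounds
policy_noise, noise_clip = 0.2, 0.5      # illustrative noise scales

next_state_actions = torch.zeros(4, 2)   # pretend output of get_state_actions
clipped_noise = (torch.randn_like(next_state_actions) * policy_noise).clamp(-noise_clip, noise_clip)
smoothed = (next_state_actions + clipped_noise).clamp(action_low, action_high)
print(smoothed.min().item(), smoothed.max().item())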
@@ -251,18 +251,19 @@ def run_synthesis(args: Args):
             # Optimize the program
             if global_step % args.policy_frequency == 0:
-                program_actions = get_state_actions(program, data.observations, env, args, grad_required=True)
+                program_actions = get_state_actions(program_optimizer, data.observations, env, args, grad_required=True)
                 program_objective = qf1(data.observations, program_actions).mean()
                 program_objective.backward()
 
                 improved_actions = program_actions + 0.1 * program_actions.grad
                 print(program_actions, improved_actions)
                 RES.append(improved_actions[0].detach().numpy())
 
                 program_optimizer.fit(states=data.observations.detach().numpy(),
                                       actions=improved_actions.detach().numpy())
-                                      #actions=np.ones(shape=(args.batch_size, 1))*0.5)
+
+                # Print program
+                program_optimizer.print_best_solution()
 
                 # update the target network
                 for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
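
The policy update keeps the same structure it had before this commit: the actions proposed by the current best program are pushed one gradient-ascent step up the critic, and the optimizer is then fit to imitate those improved actions. A toy illustration of that step, using a made-up quadratic critic instead of the QNetwork:

import torch

def toy_q(states, actions):
    # Peaks when the action equals the first state component
    return -((actions - states[:, :1]) ** 2).sum(dim=1)

states = torch.randn(8, 3)
program_actions = torch.zeros(8, 1, requires_grad=True)

toy_q(states, program_actions).mean().backward()
improved_actions = program_actions + 0.1 * program_actions.grad   # one ascent step on the actions

# These become the regression targets passed to program_optimizer.fit(...)
print(improved_actions.detach().numpy()[:2])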
@@ -8,10 +8,11 @@ from dataclasses import dataclass
 from postfix_program import Program, NUM_OPERATORS
 
 
 class ProgramOptimizer:
-    def __init__(self, config):
+    def __init__(self, config, action_shape):
         # Create the initial population
-        self.initial_program = [-1.0] * (config.num_genes * 2)    # Mean and log_std for each gene
+        self.action_shape = action_shape
+        self.initial_program = [0.0] * (config.num_genes * 2 * action_shape[0])    # Mean and log_std for each gene, for each action dimension
 
         self.best_solution = self.initial_program
         self.best_fitness = None
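
With the new layout, one genome holds one program per action dimension, each encoded as num_genes genes with a mean and a log_std. A quick arithmetic check of the sizes; the 3-dimensional action space is only an example:

num_genes = 4        # from Args in this commit
action_dims = 3      # example value of env.action_space.shape[0]

program_length = num_genes * 2                   # mean and log_std per gene -> 8
genome_length = program_length * action_dims     # 24 values in the whole genome

# Slice owned by each action dimension, as in get_actions_from_solution below
slices = [(i * program_length, (i + 1) * program_length) for i in range(action_dims)]
print(genome_length, slices)                     # 24 [(0, 8), (8, 16), (16, 24)]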
@@ -19,39 +20,48 @@ class ProgramOptimizer:
         self.config = config
         self.initial_population = [np.array(self.initial_program) for i in range(config.num_individuals)]
 
-    def get_best_program(self):
-        return Program(genome=self.best_solution)
+    def get_actions_from_solution(self, solution, state):
+        # One program per action dimension
+        program_length = self.config.num_genes * 2
+        programs = [
+            Program(genome=solution[i*program_length : (i+1)*program_length])
+            for i in range(self.action_shape[0])
+        ]
 
-    def fit(self, states, actions):
-        """ states is a batch of states, shape (N, state_shape)
-            actions is a batch of actions, shape (N, action_shape), we assume continuous actions
-        """
+        return np.array([p(state) for p in programs], dtype=np.float32)
+
+    def print_best_solution(self):
+        program_length = self.config.num_genes * 2
 
-        def fitness_func(ga_instance, solution, solution_idx):
-            batch_size = states.shape[0]
-            action_size = actions.shape[1]
-            sum_error = 0.0
+        for i in range(self.action_shape[0]):
+            p = Program(genome=self.best_solution[i*program_length : (i+1)*program_length])
+            print(f'a[{i}] =', p.run_program([0.0], do_print=True))
 
-            program = Program(genome=solution)
+    def _fitness_func(self, ga_instance, solution, solution_idx):
+        batch_size = self.states.shape[0]
+        sum_error = 0.0
 
-            # Evaluate the program several times, because evaluations are stochastic
-            for eval_run in range(self.config.num_eval_runs):
-                for index in range(batch_size):
-                    action = program(states[index], len_output=action_size)
-                    desired_action = actions[index]
+        # Evaluate the program several times, because evaluations are stochastic
+        for eval_run in range(self.config.num_eval_runs):
+            for index in range(batch_size):
+                action = self.get_actions_from_solution(solution, self.states[index])
+                desired_action = self.actions[index]
 
-                    sum_error += np.mean((action - desired_action) ** 2)
+                sum_error += np.mean((action - desired_action) ** 2)
 
-            fitness = -(sum_error / (batch_size + self.config.num_eval_runs))
+        fitness = -(sum_error / (batch_size + self.config.num_eval_runs))
 
         if self.best_fitness is None or fitness > self.best_fitness:
             self.best_solution = solution
             self.best_fitness = fitness
 
-            return fitness
+        return fitness
+
+    def fit(self, states, actions):
+        """ states is a batch of states, shape (N, state_shape)
+            actions is a batch of actions, shape (N, action_shape), we assume continuous actions
+        """
+        self.states = states      # picklable self._fitness_func needs these instance variables
+        self.actions = actions
 
         self.ga_instance = pygad.GA(
-            fitness_func=fitness_func,
+            fitness_func=self._fitness_func,
             initial_population=self.initial_population,
             num_generations=self.config.num_generations,
             num_parents_mating=self.config.num_parents_mating,
@@ -68,6 +78,7 @@ class ProgramOptimizer:
             mutation_type="random",
             random_mutation_max_val=10,
             random_mutation_min_val=-10,
+            parallel_processing=["process", None]
         )
 
         self.ga_instance.run()
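
Moving the fitness function from a closure inside fit() to the bound method self._fitness_func, with states and actions stored on self, is what makes process-based parallel_processing workable: PyGAD has to pickle the fitness function for worker processes, and bound methods of a picklable object can be pickled while nested functions cannot. A reduced sketch of the same pattern, with an illustrative target-matching fitness rather than the program-regression fitness above; the population size, generation count and worker count are made up:

import numpy as np
import pygad

class TargetMatcher:
    def __init__(self, target):
        self.target = target          # plain attributes keep the object picklable

    def _fitness_func(self, ga_instance, solution, solution_idx):
        return -float(np.mean((solution - self.target) ** 2))

    def fit(self):
        ga = pygad.GA(
            num_generations=20,
            num_parents_mating=10,
            fitness_func=self._fitness_func,
            initial_population=np.zeros((50, 4)),
            mutation_type="random",
            mutation_percent_genes=25,
            random_mutation_min_val=-10,
            random_mutation_max_val=10,
            parallel_processing=["process", 2],   # 2 worker processes, for illustration
        )
        ga.run()
        return ga.best_solution()[0]

if __name__ == "__main__":            # guard needed for process-based parallelism
    print(TargetMatcher(np.array([1.0, 2.0, 3.0, 4.0])).fit())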
@@ -75,6 +86,9 @@ class ProgramOptimizer:
         # Allow the population to survive
         self.initial_population = self.ga_instance.population
 
+        # Best solution for now
+        self.best_solution = self.ga_instance.best_solution()[0]
+
 
 @dataclass
 class Config:
     num_individuals: int = 1000
@@ -25,11 +25,12 @@ class Operator:
 OPERATORS = [
+    Operator('<end>', 0, None),
     Operator('<', 2, lambda a, b: float(a < b)),
     Operator('>', 2, lambda a, b: float(a > b)),
     Operator('==', 2, lambda a, b: float(a == b)),
     Operator('!=', 2, lambda a, b: float(a != b)),
     Operator('abs', 1, lambda a: abs(a)),
     Operator('sin', 1, lambda a: math.sin(a)),
     Operator('cos', 1, lambda a: math.cos(a)),
     Operator('exp', 1, lambda a: math.exp(min(a, 10.0))),
     Operator('sqrt', 1, lambda a: math.sqrt(max(a, 0.0))),
     Operator('neg', 1, lambda a: -a),
     Operator('+', 2, lambda a, b: a + b),
     Operator('-', 2, lambda a, b: a - b),
     Operator('*', 2, lambda a, b: a * b),
@@ -38,13 +39,12 @@ OPERATORS = [
     Operator('max', 2, lambda a, b: max(a, b)),
     Operator('min', 2, lambda a, b: min(a, b)),
     Operator('trunc', 1, lambda a: float(int(a))),
     Operator('abs', 1, lambda a: abs(a)),
     Operator('neg', 1, lambda a: -a),
     Operator('sin', 1, lambda a: math.sin(a)),
     Operator('cos', 1, lambda a: math.cos(a)),
     Operator('exp', 1, lambda a: math.exp(min(a, 10.0))),
     Operator('sqrt', 1, lambda a: math.sqrt(max(a, 0.0))),
     Operator('<', 2, lambda a, b: float(a < b)),
     Operator('>', 2, lambda a, b: float(a > b)),
     Operator('==', 2, lambda a, b: float(a == b)),
     Operator('!=', 2, lambda a, b: float(a != b)),
     Operator('?', 3, lambda cond, a, b: a if cond > 0.5 else b),
-    Operator('<end>', 0, None),
 ]
 
 NUM_OPERATORS = len(OPERATORS)
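
The position of an entry in OPERATORS matters because genome values are decoded by index: values far enough below zero become input variables (the rule input_index = -value - NUM_OPERATORS - 1 appears further down in this file), so moving '<end>' to the front changes which token value terminates a program. A small decoding sketch; the literal/operator convention below is an assumption, only the input rule is taken from this file:

NUM_OPERATORS = 20                                          # illustrative; really len(OPERATORS) above
OPERATOR_NAMES = ['<end>', '<', '>', '==', '!=', 'abs']     # first few entries of the new ordering

def describe_token(value):
    # Assumed decoding: non-negative tokens are literal constants, the next
    # NUM_OPERATORS negative values select an operator, anything below that
    # is an input variable (matching input_index = -value - NUM_OPERATORS - 1).
    if value >= 0:
        return f'literal {value}'
    index = int(-value) - 1
    if index < NUM_OPERATORS:
        name = OPERATOR_NAMES[index] if index < len(OPERATOR_NAMES) else f'op#{index}'
        return f'operator {name}'
    return f'input x{int(-value) - NUM_OPERATORS - 1}'

for token in [0.5, -1.0, -2.0, -21.0, -23.0]:
    print(token, '->', describe_token(token))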
@@ -56,15 +56,13 @@ class Program:
     def __str__(self):
         return repr(self.run_program(inp=[1], do_print=True))
 
-    def __call__(self, inp, len_output=None):
+    def __call__(self, inp):
         res = self.run_program(inp, do_print=False)
 
-        # If the desired output length is given, pad the result with zeroes if needed
-        if len_output:
-            res = np.array(res + [0.0] * len_output, dtype=np.float32)
-            res = res[:len_output]
-
-        return res
+        if len(res) == 0:
+            return 0.0
+        else:
+            return res[-1]
 
     def run_program(self, inp, do_print=False):
         stack = []
@@ -96,10 +94,10 @@ class Program:
             input_index = -value - NUM_OPERATORS - 1
 
-            # Silently ignore input variables beyond the end of inp
-            if input_index < len(inp):
-                if do_print:
-                    stack.append(f'x{input_index}')
-                else:
+            if do_print:
+                stack.append(f'x{input_index}')
+            else:
+                # Silently ignore input variables beyond the end of inp
+                if input_index < len(inp):
                     stack.append(inp[input_index])
 
             continue
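
For reference, a postfix program of this kind is evaluated left to right against a stack: operators pop as many operands as their arity and push the result, and the value left on top of the stack is what the new __call__ returns (0.0 if the stack is empty). A self-contained toy evaluator over already-decoded tokens, not the code of this repository:

import math

OPS = {'+': (2, lambda a, b: a + b),
       '*': (2, lambda a, b: a * b),
       'sin': (1, lambda a: math.sin(a))}

def eval_postfix(tokens, inp):
    stack = []
    for tok in tokens:
        if isinstance(tok, float):
            stack.append(tok)                      # literal constant
        elif tok.startswith('x'):
            stack.append(inp[int(tok[1:])])        # input variable
        else:
            arity, fn = OPS[tok]
            args = [stack.pop() for _ in range(arity)][::-1]
            stack.append(fn(*args))
    return stack[-1] if stack else 0.0             # mirrors the new __call__

print(eval_postfix(['x0', 2.0, '*', 1.0, '+'], inp=[3.0]))   # 3*2 + 1 = 7.0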