Commit c7bb3867 authored by Denis Steckelmacher

Better constant propagation

parent dba96ab7
@@ -49,7 +49,7 @@ class Args:
     """the user or org name of the model repository from the Hugging Face Hub"""

     # Algorithm specific arguments
-    env_id: str = "SimpleLargeAction-v0"
+    env_id: str = "SimpleTwoStates-v0"
     """the id of the environment"""
     total_timesteps: int = int(1e4)
     """total timesteps of the experiments"""
@@ -61,7 +61,7 @@ class Args:
     """the discount factor gamma"""
     tau: float = 0.005
     """target smoothing coefficient (default: 0.005)"""
-    batch_size: int = 32
+    batch_size: int = 16
     """the batch size of samples from the replay memory"""
     policy_noise: float = 0.2
     """the scale of policy noise"""
@@ -86,17 +86,14 @@ class Args:
 def make_env(env_id, seed, idx, capture_video, run_name):
-    def thunk():
-        if capture_video and idx == 0:
-            env = gym.make(env_id, render_mode="rgb_array")
-            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
-        else:
-            env = gym.make(env_id)
-        env = gym.wrappers.RecordEpisodeStatistics(env)
-        env.action_space.seed(seed)
-        return env
-
-    return thunk
+    if capture_video and idx == 0:
+        env = gym.make(env_id, render_mode="rgb_array")
+        env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
+    else:
+        env = gym.make(env_id)
+    env = gym.wrappers.RecordEpisodeStatistics(env)
+    env.action_space.seed(seed)
+    return env


 # ALGO LOGIC: initialize agent here:
@@ -159,11 +156,11 @@ def run_synthesis(args: Args):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

     # env setup
-    env = make_env(args.env_id, args.seed, 0, args.capture_video, run_name)()
+    env = make_env(args.env_id, args.seed, 0, args.capture_video, run_name)
     assert isinstance(env.action_space, gym.spaces.Box), "only continuous action space is supported"

     # Actor is a learnable program
-    program_optimizers = [ProgramOptimizer(args) for i in range(env.action_space.shape[0])]
+    program_optimizers = [ProgramOptimizer(args, env.observation_space.shape[0]) for i in range(env.action_space.shape[0])]

     qf1 = QNetwork(env).to(device)
     qf2 = QNetwork(env).to(device)
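
With the thunk removed, make_env returns the environment itself, so the extra call at the use site goes away. A minimal sketch of the new calling convention, using Pendulum-v1 as a stand-in for the project-specific SimpleTwoStates-v0:

    import gymnasium as gym

    def make_env(env_id, seed, idx, capture_video, run_name):
        # Build the environment eagerly instead of returning a closure
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    env = make_env("Pendulum-v1", seed=1, idx=0, capture_video=False, run_name="demo")
    assert isinstance(env.action_space, gym.spaces.Box)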
@@ -208,6 +205,10 @@ def run_synthesis(args: Args):
         real_next_obs = next_obs.copy()
         rb.add(obs, real_next_obs, action, reward, termination, info)

+        # RESET
+        if termination or truncation:
+            next_obs, _ = env.reset()
+
         # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
         obs = next_obs
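
The added reset matters because, in the Gymnasium API, step() must not be called again once an episode has terminated or been truncated. A minimal sketch of the resulting interaction loop, with a random policy as a placeholder for the learned program:

    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    obs, _ = env.reset(seed=1)
    for global_step in range(1000):
        action = env.action_space.sample()   # placeholder policy
        next_obs, reward, termination, truncation, info = env.step(action)
        # RESET: start a fresh episode before the next step() call
        if termination or truncation:
            next_obs, _ = env.reset()
        obs = next_obs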
@@ -251,7 +252,7 @@ def run_synthesis(args: Args):
                 program_objective = qf1(data.observations, program_actions).mean()
                 program_objective.backward()

-                improved_actions = program_actions + 0.1 * program_actions.grad
+                improved_actions = program_actions + program_actions.grad

                 RES.append(improved_actions[0].detach().numpy())
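
Dropping the 0.1 factor turns this into an unscaled gradient-ascent step on the critic: each action is nudged in the direction that increases Q(s, a). A toy sketch of the step, using a single Linear layer as a stand-in critic (the real QNetwork takes observations and actions as separate arguments):

    import torch

    critic = torch.nn.Linear(3 + 1, 1)        # toy Q over concatenated [obs; action]
    observations = torch.randn(8, 3)
    program_actions = torch.randn(8, 1, requires_grad=True)

    objective = critic(torch.cat([observations, program_actions], dim=1)).mean()
    objective.backward()

    # One unscaled ascent step, as in the updated code
    improved_actions = program_actions + program_actions.grad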
@@ -8,16 +8,18 @@ from dataclasses import dataclass
 from postfix_program import Program, NUM_OPERATORS


 class ProgramOptimizer:
-    def __init__(self, config):
+    def __init__(self, config, state_dim):
         # Create the initial population
-        self.initial_program = [0.0] * (config.num_genes * 2)  # Mean and log_std for each gene
+        # We create it so these random programs try all the operators and read all the state variables
+        self.initial_population = np.random.random((config.num_individuals, config.num_genes * 2))  # Random numbers between 0 and 1
+        self.initial_population *= NUM_OPERATORS + state_dim   # Between 0 and NUM_OPERATORS + state_dim
+        self.initial_population *= -1.0                        # Between -(NUM_OPERATORS + state_dim) and 0

-        self.best_solution = self.initial_program
+        self.best_solution = self.initial_population[0]
         self.best_fitness = None

         self.config = config

-        self.initial_population = [np.array(self.initial_program) for i in range(config.num_individuals)]

     def get_action(self, state):
         program = Program(genome=self.best_solution)
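
Instead of an all-zero initial program duplicated across the population, each individual is now drawn uniformly from the full negative token range, so the first generation already covers every operator and every state variable. A small numeric sketch of the scaling, with NUM_OPERATORS = 10 and state_dim = 2 as assumed values:

    import numpy as np

    NUM_OPERATORS, state_dim = 10, 2      # assumed for illustration
    num_individuals, num_genes = 4, 3

    pop = np.random.random((num_individuals, num_genes * 2))  # uniform in [0, 1)
    pop *= NUM_OPERATORS + state_dim                          # uniform in [0, 12)
    pop *= -1.0                                               # uniform in (-12, 0]
    assert (pop <= 0).all() and (pop > -(NUM_OPERATORS + state_dim)).all()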
@@ -60,19 +60,11 @@ class Program:
         return self.run_program(inp, do_print=False)

     def __str__(self):
-        expression = self.run_program([1.0], do_print=True)
-
-        # Simple constant propagation: if the resulting expression can be eval'd,
-        # it means that it only uses operators and constants, so we can simply
-        # show the program as the constant
-        try:
-            functions = {operator.name: operator.function for operator in OPERATORS}
-            return str(eval(expression, functions))
-        except:
-            return expression
+        return self.run_program([1.0], do_print=True)

     def run_program(self, inp, do_print=False):
         stack = []
+        functions = {operator.name: operator.function for operator in OPERATORS}

         for pointer in range(0, len(self.genome), 2):
             # Sample the actual token to execute
@@ -143,6 +135,14 @@ class Program:
                 elif len(operands) == 3:
                     result = f"({operands[0]} ? {operands[1]} : {operands[2]})"

+                # Simple constant propagation: if the resulting expression can be eval'd,
+                # it means that it only uses operators and constants, so we can simply
+                # show the program as the constant
+                try:
+                    result = str(eval(result, functions))
+                except:
+                    pass
+
                 stack.append(result)
             else:
                 # Run the operator and get the result back
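
This is the change the commit title refers to: constant propagation now runs on each subexpression inside run_program rather than once on the final string in __str__, so constant subtrees fold even when the overall program still reads state variables. A self-contained sketch of the folding step, with a one-entry operator table standing in for OPERATORS from postfix_program.py:

    import math
    from collections import namedtuple

    Op = namedtuple("Op", ["name", "function"])
    OPERATORS = [Op("sin", math.sin)]                 # assumed subset
    functions = {op.name: op.function for op in OPERATORS}

    for result in ["sin(1.0 + 2.0)", "sin(x[0])"]:
        try:
            result = str(eval(result, functions))     # folds constants-only exprs
        except Exception:
            pass                                      # symbolic parts stay as-is
        print(result)                                 # "0.1411...", then "sin(x[0])"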
@@ -154,7 +154,6 @@
         else:
             return stack[-1]

-
 if __name__ == '__main__':
-    print(Program([5.0, 1.0, -2.0, -5.0, 18.0, 0.0, -8.0, -2.0]))
+    print(Program([-17.0, 0.0]).run_program([0.0], do_print=False))