diff --git a/requirements.txt b/requirements.txt index 19501538a4aca6678352cae9e2eb336084d6c1e2..d719f285a3765e2e183954f49f5284b3b184732c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,5 @@ -http://download.pytorch.org/whl/cpu/torch-2.2.1%2Bcpu-cp310-cp310-linux_x86_64.whl -stable_baselines3==2.2.1 \ No newline at end of file +## http://download.pytorch.org/whl/cpu/torch-2.2.1%2Bcpu-cp310-cp310-linux_x86_64.whl +stable_baselines3==2.2.1 +gymnasium==0.29.1 +gymnasium[classic-control] +gymnasium[box2d] diff --git a/workshop.ipynb b/workshop.ipynb index 29cd7437ec2b3a2850634677f085180ac916cf5c..c87fcf44af59326a55c81c8c7b7914a9f2cbda2e 100644 --- a/workshop.ipynb +++ b/workshop.ipynb @@ -3,31 +3,185 @@ { "cell_type": "code", "execution_count": null, - "id": "initial_id", + "id": "1c50ce3a-fb54-4338-b7eb-5dfb1b44b3ba", + "metadata": {}, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import torch\n", + "import numpy as np\n", + "import random\n", + "import copy" + ] + }, + { + "cell_type": "markdown", + "id": "822863d7-e030-4bcc-9ef3-4e17e39561cc", "metadata": { - "collapsed": true + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } }, + "source": [ + "## 1. Differentiable Decision Trees" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c39a908-18c5-4371-a3e8-b1cb94a16766", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir('Interpretable_DDTS_AISTATS2020')\n", + "import interpretable_ddts as tree\n", + "import torch.multiprocessing as mp\n", + "from interpretable_ddts.agents.ddt_agent import DDTAgent\n", + "from interpretable_ddts.opt_helpers.replay_buffer import discount_reward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fe1e8a7-5f7d-4dc7-ae49-f7cb309dca34", + "metadata": {}, + "outputs": [], + "source": [ + "init_env = gym.make('CartPole-v1')\n", + "dim_in = init_env.observation_space.shape[0]\n", + "dim_out = init_env.action_space.n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6a95ed3-9746-4f50-8afa-40583b0e9b33", + "metadata": {}, + "outputs": [], + "source": [ + "def run_episode(q, agent_in, ENV_NAME, seed=0):\n", + " agent = agent_in.duplicate()\n", + " if ENV_NAME == 'lunar':\n", + " env = gym.make('LunarLander-v2')\n", + " elif ENV_NAME == 'cart':\n", + " env = gym.make('CartPole-v1')\n", + " else:\n", + " raise Exception('No valid environment selected')\n", + " terminated = False\n", + " torch.manual_seed(seed)\n", + " np.random.seed(seed)\n", + " env.action_space.seed(seed)\n", + " random.seed(seed)\n", + " state, _ = env.reset(seed=seed) # Reset environment and record the starting state\n", + "\n", + " while not terminated:\n", + " action = agent.get_action(state)\n", + " # Step through environment using chosen action\n", + " state, reward, terminated, truncated, info = env.step(action)\n", + " # env.render()\n", + " # Save reward\n", + " agent.save_reward(reward)\n", + " if terminated:\n", + " break\n", + " reward_sum = np.sum(agent.replay_buffer.rewards_list)\n", + " rewards_list, advantage_list, deeper_advantage_list = discount_reward(agent.replay_buffer.rewards_list,\n", + " agent.replay_buffer.value_list,\n", + " agent.replay_buffer.deeper_value_list)\n", + " agent.replay_buffer.rewards_list = rewards_list\n", + " agent.replay_buffer.advantage_list = advantage_list\n", + " agent.replay_buffer.deeper_advantage_list = deeper_advantage_list\n", + "\n", + " to_return = [reward_sum, copy.deepcopy(agent.replay_buffer.__getstate__())]\n", + " if q is not None:\n", + " try:\n", + " q.put(to_return)\n", + " except RuntimeError as e:\n", + " print(e)\n", + " return to_return\n", + " return to_return" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c27bd47-80f4-42e4-824b-121cb1f9bd84", + "metadata": {}, + "outputs": [], + "source": [ + "def main(episodes, agent, ENV_NAME):\n", + " running_reward_array = []\n", + " for episode in range(episodes):\n", + " reward = 0\n", + " returned_object = run_episode(None, agent_in=agent, ENV_NAME=ENV_NAME)\n", + " reward += returned_object[0]\n", + " running_reward_array.append(returned_object[0])\n", + " agent.replay_buffer.extend(returned_object[1])\n", + " #if reward >= 499:\n", + " # agent.save('../models/'+str(episode)+'th')\n", + " agent.end_episode(reward)\n", + "\n", + " running_reward = sum(running_reward_array[-100:]) / float(min(100.0, len(running_reward_array)))\n", + " if episode % 50 == 0:\n", + " print(f'Episode {episode} Last Reward: {reward} Average Reward: {running_reward}')\n", + " if episode % 500 == 0:\n", + " pass\n", + " #agent.save('../models/'+str(episode)+'th')\n", + "\n", + " return running_reward_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46882d36-d9e0-4fee-83c7-3780c9a4f349", + "metadata": {}, + "outputs": [], + "source": [ + "mp.set_sharing_strategy('file_system')\n", + "policy_agent = DDTAgent(bot_name='cartpole_agent',\n", + " input_dim=dim_in,\n", + " output_dim=dim_out,\n", + " rule_list=False,\n", + " num_rules=8)\n", + "reward_array = main(550, policy_agent, 'cart')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b356c1f-2eef-424c-8eb5-1a886792f80d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f1af2ab-94e9-4726-8ff9-41d85230c597", + "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "carl-venv", "language": "python", - "name": "python3" + "name": "carl-venv" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.10.13" } }, "nbformat": 4,