diff --git a/.teamcity/build.sh b/.teamcity/build.sh index 35d295d01..125399945 100755 --- a/.teamcity/build.sh +++ b/.teamcity/build.sh @@ -49,12 +49,14 @@ function run_example_test { python -m pip uninstall -r ./examples/DQN_variant/requirements.txt -y python -m pip install -r ./examples/PPO/requirements_atari.txt - python examples/PPO/train.py --train_total_steps 5000 --env PongNoFrameskip-v4 + python examples/PPO/atari/train.py --train_total_steps 5000 --env PongNoFrameskip-v4 python -m pip uninstall -r ./examples/PPO/requirements_atari.txt -y + xparl start --port 8010 --cpu_num 8 python -m pip install -r ./examples/PPO/requirements_mujoco.txt - python examples/PPO/train.py --train_total_steps 5000 --env HalfCheetah-v4 --continuous_action + python examples/PPO/mujoco/train.py --env 'HalfCheetah-v2' --train_total_episodes 100 --env_num 5 python -m pip uninstall -r ./examples/PPO/requirements_mujoco.txt -y + xparl stop python -m pip install -r ./examples/SAC/requirements.txt python examples/SAC/train.py --train_total_steps 5000 --env HalfCheetah-v4 diff --git a/examples/PPO/README.md b/examples/PPO/README.md index dd8d88faa..131a35494 100644 --- a/examples/PPO/README.md +++ b/examples/PPO/README.md @@ -4,15 +4,17 @@ Based on PARL, the PPO algorithm of deep reinforcement learning has been reprodu > Paper: PPO in [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) ### Mujoco/Atari games introduction -PARL currently supports the open-source version of Mujoco provided by DeepMind, so users do not need to download binaries of Mujoco as well as install mujoco-py and get license. For more details, please visit [Mujoco](https://github.com/deepmind/mujoco). +PARL currently supports the open-source version of Mujoco provided by DeepMind, so users do not need to download binaries of Mujoco as well as install [mujoco-py](https://github.com/openai/mujoco-py#install-mujoco). For more details, please visit [Mujoco](https://github.com/deepmind/mujoco). ### Benchmark result #### 1. Mujoco games results +The horizontal axis represents the number of episodes.

[benchmark figure: mujoco-result]

#### 2. Atari games results +The horizontal axis represents the number of steps.

[benchmark figure: atari-result]
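The README hunk below replaces the old single-script `--continuous_action` workflow with separate Atari and MuJoCo entry points and trains MuJoCo through an xparl cluster with `env_num` remote actors. A minimal sketch of the xparl pattern that the new `mujoco/train.py` and `actor.py` rely on; the class, address, and numbers here are illustrative only, not part of this diff:

```python
import parl

@parl.remote_class(wait=False)           # wait=False: method calls return future objects
class ToyActor(object):
    def run_episode(self):
        return 1.0                       # stand-in for a real rollout dict

parl.connect('localhost:8010')           # cluster started via: xparl start --port 8010 --cpu_num 8
actors = [ToyActor() for _ in range(5)]  # analogous to env_num = 5 in the example
futures = [actor.run_episode() for actor in actors]
rewards = [future.get() for future in futures]  # block until every remote episode returns
```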

@@ -23,29 +25,21 @@ PARL currently supports the open-source version of Mujoco provided by DeepMind, ### Mujoco-Dependencies: + python3.7+ + [paddle>=2.3.1](https://github.com/PaddlePaddle/Paddle) -+ [parl>=2.1.1](https://github.com/PaddlePaddle/PARL) -+ gym>=0.26.0 ++ [parl>=2.2.2](https://github.com/PaddlePaddle/PARL) ++ gym==0.18.0 + mujoco>=2.2.2 ++ mujoco-py==2.1.2.14 ### Atari-Dependencies: + [paddle>=2.3.1](https://github.com/PaddlePaddle/Paddle) -+ [parl>=2.1.1](https://github.com/PaddlePaddle/PARL) ++ [parl>=2.2.2](https://github.com/PaddlePaddle/PARL) + gym==0.18.0 + atari-py==0.2.6 + opencv-python -### Training: - -``` -# To train an agent for discrete action game (Atari: PongNoFrameskip-v4 by default) -python train.py - -# To train an agent for continuous action game (Mujoco) -python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000 -``` -### Distributed Training -Accelerate training process by setting `xparl_addr` and `env_num > 1` when environment simulation running very slow. +### Training Mujoco Distributedly +Accelerate training process by setting `xparl_addr` and `env_num > 1` when environment simulation running very slowly. At first, we can start a local cluster with 8 CPUs: ``` @@ -56,14 +50,25 @@ Note that if you have started a master before, you don't have to run the above command. For more information about the cluster, please refer to our [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html). -Then we can start the distributed training by running: +Then we can start the distributed training for mujoco games by running: ``` -# To train an agent distributedly +cd mujoco -# for discrete action game (Atari games) -python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010' +python train.py --env 'HalfCheetah-v2' --train_total_episodes 100000 --env_num 5 +``` -# for continuous action game (Mujoco games) -python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000 --env_num 5 --xparl_addr 'localhost:8010' + +### Training Atari +To train an agent for discrete action game (Atari: PongNoFrameskip-v4 by default): + +``` +cd atari + +# Local training +python train.py + +# Distributed training +xparl start --port 8010 --cpu_num 8 +python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010' ``` diff --git a/examples/PPO/agent.py b/examples/PPO/atari/atari_agent.py similarity index 90% rename from examples/PPO/agent.py rename to examples/PPO/atari/atari_agent.py index 980b6a7d1..cb89e87d1 100644 --- a/examples/PPO/agent.py +++ b/examples/PPO/atari/atari_agent.py @@ -18,7 +18,7 @@ from parl.utils.scheduler import LinearDecayScheduler -class PPOAgent(parl.Agent): +class AtariAgent(parl.Agent): """ Agent of PPO env Args: @@ -27,12 +27,11 @@ class PPOAgent(parl.Agent): """ def __init__(self, algorithm, config): - super(PPOAgent, self).__init__(algorithm) + super(AtariAgent, self).__init__(algorithm) self.config = config if self.config['lr_decay']: - self.lr_scheduler = LinearDecayScheduler( - self.config['initial_lr'], self.config['num_updates']) + self.lr_scheduler = LinearDecayScheduler(self.config['initial_lr'], self.config['num_updates']) def predict(self, obs): """ Predict action from current policy given observation @@ -85,8 +84,7 @@ def learn(self, rollout): else: lr = None - minibatch_size = int( - self.config['batch_size'] // self.config['num_minibatches']) + minibatch_size = int(self.config['batch_size'] // self.config['num_minibatches']) 
indexes = np.arange(self.config['batch_size']) for epoch in range(self.config['update_epochs']): @@ -105,9 +103,8 @@ def learn(self, rollout): batch_return = paddle.to_tensor(batch_return) batch_value = paddle.to_tensor(batch_value) - value_loss, action_loss, entropy_loss = self.alg.learn( - batch_obs, batch_action, batch_value, batch_return, - batch_logprob, batch_adv, lr) + value_loss, action_loss, entropy_loss = self.alg.learn(batch_obs, batch_action, batch_value, + batch_return, batch_logprob, batch_adv, lr) value_loss_epoch += value_loss action_loss_epoch += action_loss diff --git a/examples/PPO/atari_config.py b/examples/PPO/atari/atari_config.py similarity index 100% rename from examples/PPO/atari_config.py rename to examples/PPO/atari/atari_config.py diff --git a/examples/PPO/atari_model.py b/examples/PPO/atari/atari_model.py similarity index 100% rename from examples/PPO/atari_model.py rename to examples/PPO/atari/atari_model.py diff --git a/examples/PPO/env_utils.py b/examples/PPO/atari/env_utils.py similarity index 59% rename from examples/PPO/env_utils.py rename to examples/PPO/atari/env_utils.py index 410a6f059..dd394e76d 100644 --- a/examples/PPO/env_utils.py +++ b/examples/PPO/atari/env_utils.py @@ -16,13 +16,11 @@ import gym import numpy as np from parl.utils import logger +from parl.env.atari_wrappers import wrap_deepmind -TEST_EPISODE = 3 # wrapper parameters for atari env ENV_DIM = 84 OBS_FORMAT = 'NCHW' -# wrapper parameters for mujoco env -GAMMA = 0.99 class ParallelEnv(object): @@ -39,14 +37,9 @@ def __init__(self, config=None): base_env = LocalEnv if config['seed']: - self.env_list = [ - base_env(config['env'], config['seed'] + i) - for i in range(self.env_num) - ] + self.env_list = [base_env(config['env'], config['seed'] + i) for i in range(self.env_num)] else: - self.env_list = [ - base_env(config['env']) for _ in range(self.env_num) - ] + self.env_list = [base_env(config['env']) for _ in range(self.env_num)] if hasattr(self.env_list[0], '_max_episode_steps'): self._max_episode_steps = self.env_list[0]._max_episode_steps else: @@ -68,10 +61,7 @@ def reset(self): def step(self, action_list): next_obs_list, reward_list, done_list, info_list = [], [], [], [] if self.use_xparl: - return_list = [ - self.env_list[i].step(action_list[i]) - for i in range(self.env_num) - ] + return_list = [self.env_list[i].step(action_list[i]) for i in range(self.env_num)] return_list = [return_.get() for return_ in return_list] return_list = np.array(return_list, dtype=object) @@ -89,8 +79,7 @@ def step(self, action_list): done = done_[i] info = info_[i] else: - next_obs, reward, done, info = self.env_list[i].step( - action_list[i]) + next_obs, reward, done, info = self.env_list[i].step(action_list[i]) self.episode_steps_list[i] += 1 self.episode_reward_list[i] += reward @@ -104,49 +93,26 @@ def step(self, action_list): next_obs = self.env_list[i].reset() self.episode_steps_list[i] = 0 self.episode_reward_list[i] = 0 - if self.env_list[i].continuous_action: - # get running mean and variance of obs - self.eval_ob_rms = self.env_list[i].env.get_ob_rms() next_obs_list.append(next_obs) reward_list.append(reward) done_list.append(done) info_list.append(info) - return np.array(next_obs_list), np.array(reward_list), np.array( - done_list), np.array(info_list) + return np.array(next_obs_list), np.array(reward_list), np.array(done_list), np.array(info_list) class LocalEnv(object): def __init__(self, env_name, env_seed=None, test=False, ob_rms=None): env = gym.make(env_name) - # is instance of 
gym.spaces.Box - if hasattr(env.action_space, 'high'): - from parl.env.mujoco_wrappers import wrap_rms - self._max_episode_steps = env._max_episode_steps - self.continuous_action = True - if test: - self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms) - else: - self.env = wrap_rms(env, gamma=GAMMA) # is instance of gym.spaces.Discrete - elif hasattr(env.action_space, 'n'): - from parl.env.atari_wrappers import wrap_deepmind - self.continuous_action = False + if hasattr(env.action_space, 'n'): if test: - self.env = wrap_deepmind( - env, - dim=ENV_DIM, - obs_format=OBS_FORMAT, - test=True, - test_episodes=1) + self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1) else: - self.env = wrap_deepmind( - env, dim=ENV_DIM, obs_format=OBS_FORMAT) + self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT) else: - raise AssertionError( - 'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete' - ) + raise AssertionError('act_space must be instance of gym.spaces.Discrete') self.obs_space = self.env.observation_space self.act_space = self.env.action_space @@ -166,31 +132,13 @@ class RemoteEnv(object): def __init__(self, env_name, env_seed=None, test=False, ob_rms=None): env = gym.make(env_name) - if hasattr(env.action_space, 'high'): - from parl.env.mujoco_wrappers import wrap_rms - self._max_episode_steps = env._max_episode_steps - self.continuous_action = True - if test: - self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms) - else: - self.env = wrap_rms(env, gamma=GAMMA) - elif hasattr(env.action_space, 'n'): - from parl.env.atari_wrappers import wrap_deepmind - self.continuous_action = False + if hasattr(env.action_space, 'n'): if test: - self.env = wrap_deepmind( - env, - dim=ENV_DIM, - obs_format=OBS_FORMAT, - test=True, - test_episodes=1) + self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1) else: - self.env = wrap_deepmind( - env, dim=ENV_DIM, obs_format=OBS_FORMAT) + self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT) else: - raise AssertionError( - 'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete' - ) + raise AssertionError('act_space must be instance of gym.spaces.Discrete') if env_seed: self.env.seed(env_seed) @@ -201,6 +149,4 @@ def step(self, action): return self.env.step(action) def render(self): - return logger.warning( - 'Can not render in remote environment, render() have been skipped.' 
- ) + return logger.warning('Can not render in remote environment, render() have been skipped.') diff --git a/examples/PPO/storage.py b/examples/PPO/atari/storage.py similarity index 84% rename from examples/PPO/storage.py rename to examples/PPO/atari/storage.py index 059beb1ec..99fd9bfe4 100644 --- a/examples/PPO/storage.py +++ b/examples/PPO/atari/storage.py @@ -17,10 +17,8 @@ class RolloutStorage(): def __init__(self, step_nums, env_num, obs_space, act_space): - self.obs = np.zeros( - (step_nums, env_num) + obs_space.shape, dtype='float32') - self.actions = np.zeros( - (step_nums, env_num) + act_space.shape, dtype='float32') + self.obs = np.zeros((step_nums, env_num) + obs_space.shape, dtype='float32') + self.actions = np.zeros((step_nums, env_num) + act_space.shape, dtype='float32') self.logprobs = np.zeros((step_nums, env_num), dtype='float32') self.rewards = np.zeros((step_nums, env_num), dtype='float32') self.dones = np.zeros((step_nums, env_num), dtype='float32') @@ -54,10 +52,8 @@ def compute_returns(self, value, done, gamma=0.99, gae_lambda=0.95): else: nextnonterminal = 1.0 - self.dones[t + 1] nextvalues = self.values[t + 1] - delta = self.rewards[ - t] + gamma * nextvalues * nextnonterminal - self.values[t] - advantages[ - t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam + delta = self.rewards[t] + gamma * nextvalues * nextnonterminal - self.values[t] + advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam returns = advantages + self.values self.returns = returns self.advantages = advantages @@ -72,5 +68,4 @@ def sample_batch(self, idx): b_returns = self.returns.reshape(-1) b_values = self.values.reshape(-1) - return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[ - idx], b_returns[idx], b_values[idx] + return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[idx], b_returns[idx], b_values[idx] diff --git a/examples/PPO/train.py b/examples/PPO/atari/train.py similarity index 60% rename from examples/PPO/train.py rename to examples/PPO/atari/train.py index 52732672b..0bcb9dc06 100644 --- a/examples/PPO/train.py +++ b/examples/PPO/atari/train.py @@ -16,14 +16,12 @@ import numpy as np from parl.utils import logger, summary -from mujoco_config import mujoco_config from atari_config import atari_config from env_utils import ParallelEnv, LocalEnv from storage import RolloutStorage from atari_model import AtariModel -from mujoco_model import MujocoModel -from parl.algorithms import PPO -from agent import PPOAgent +from parl.algorithms import PPO_Atari +from atari_agent import AtariAgent # Runs policy until 'real done' and returns episode reward @@ -43,7 +41,7 @@ def run_evaluate_episodes(agent, eval_env, eval_episodes): def main(): - config = mujoco_config if args.continuous_action else atari_config + config = atari_config if args.env_num: config['env_num'] = args.env_num config['env'] = args.env @@ -53,8 +51,7 @@ def main(): config['train_total_steps'] = args.train_total_steps config['batch_size'] = int(config['env_num'] * config['step_nums']) - config['num_updates'] = int( - config['train_total_steps'] // config['batch_size']) + config['num_updates'] = int(config['train_total_steps'] // config['batch_size']) logger.info("------------------- PPO ---------------------") logger.info('Env: {}, seed: {}'.format(config['env'], config['seed'])) @@ -67,20 +64,12 @@ def main(): obs_space = eval_env.obs_space act_space = eval_env.act_space - if config['continuous_action']: - model = MujocoModel(obs_space, 
act_space) - else: - model = AtariModel(obs_space, act_space) - ppo = PPO( - model, - clip_param=config['clip_param'], - entropy_coef=config['entropy_coef'], - initial_lr=config['initial_lr'], - continuous_action=config['continuous_action']) - agent = PPOAgent(ppo, config) - - rollout = RolloutStorage(config['step_nums'], config['env_num'], obs_space, - act_space) + model = AtariModel(obs_space, act_space) + ppo = PPO_Atari( + model, clip_param=config['clip_param'], entropy_coef=config['entropy_coef'], initial_lr=config['initial_lr']) + agent = AtariAgent(ppo, config) + + rollout = RolloutStorage(config['step_nums'], config['env_num'], obs_space, act_space) obs = envs.reset() done = np.zeros(config['env_num'], dtype='float32') @@ -98,11 +87,9 @@ def main(): for k in range(config['env_num']): if done[k] and "episode" in info[k].keys(): - logger.info( - "Training: total steps: {}, episode rewards: {}". - format(total_steps, info[k]['episode']['r'])) - summary.add_scalar("train/episode_reward", - info[k]["episode"]["r"], total_steps) + logger.info("Training: total steps: {}, episode rewards: {}".format( + total_steps, info[k]['episode']['r'])) + summary.add_scalar("train/episode_reward", info[k]["episode"]["r"], total_steps) # Bootstrap value if not done value = agent.value(obs) @@ -115,53 +102,22 @@ def main(): while (total_steps + 1) // config['test_every_steps'] >= test_flag: test_flag += 1 - if config['continuous_action']: - # set running mean and variance of obs - ob_rms = envs.eval_ob_rms - eval_env.env.set_ob_rms(ob_rms) - - avg_reward = run_evaluate_episodes(agent, eval_env, - config['eval_episode']) + avg_reward = run_evaluate_episodes(agent, eval_env, config['eval_episode']) summary.add_scalar('eval/episode_reward', avg_reward, total_steps) - logger.info('Evaluation over: {} episodes, Reward: {}'.format( - config['eval_episode'], avg_reward)) + logger.info('Evaluation over: {} episodes, Reward: {}'.format(config['eval_episode'], avg_reward)) if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--env", type=str, default="PongNoFrameskip-v4", help="OpenAI gym environment name") + parser.add_argument("--seed", type=int, default=None, help="seed of the experiment") parser.add_argument( - "--env", - type=str, - default="PongNoFrameskip-v4", - help="OpenAI gym environment name") - parser.add_argument( - "--seed", type=int, default=None, help="seed of the experiment") - parser.add_argument( - "--env_num", - type=int, - default=None, - help= - "number of the environment. Note: if greater than 1, xparl is needed") - parser.add_argument( - '--continuous_action', - action='store_true', - default=False, - help='action type of the environment') - parser.add_argument( - "--xparl_addr", - type=str, - default=None, - help="xparl address for distributed training ") + "--env_num", type=int, default=None, help="number of the environment. 
Note: if greater than 1, xparl is needed") + parser.add_argument("--xparl_addr", type=str, default=None, help="xparl address for distributed training ") parser.add_argument( - '--train_total_steps', - type=int, - default=10e6, - help='number of total time steps to train (default: 10e6)') + '--train_total_steps', type=int, default=10e6, help='number of total time steps to train (default: 10e6)') parser.add_argument( - '--test_every_steps', - type=int, - default=int(5e3), - help='the step interval between two consecutive evaluations') + '--test_every_steps', type=int, default=int(5e3), help='the step interval between two consecutive evaluations') args = parser.parse_args() main() diff --git a/examples/PPO/mujoco/actor.py b/examples/PPO/mujoco/actor.py new file mode 100644 index 000000000..1224bd398 --- /dev/null +++ b/examples/PPO/mujoco/actor.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +import numpy as np +import parl +from mujoco_model import MujocoModel +from mujoco_agent import MujocoAgent +from parl.algorithms import PPO_Mujoco +from parl.env.compat_wrappers import CompatWrapper + + +@parl.remote_class(wait=False) +class Actor(object): + def __init__(self, config, seed=None): + env = gym.make(config['env']) + self.env = CompatWrapper(env) + self.env.seed(seed) + + obs_dim = self.env.observation_space.shape[0] + act_dim = self.env.action_space.shape[0] + obs_dim += 1 # add 1 to obs dim for time step feature + + model = MujocoModel(obs_dim, act_dim) + alg = PPO_Mujoco(model, act_dim=act_dim) + self.agent = MujocoAgent(alg, config) + + def run_episode(self, scaler): + obs = self.env.reset() + observes, actions, rewards, unscaled_obs = [], [], [], [] + dones = [] + step = 0.0 + scale, offset = scaler.get() + scale[-1] = 1.0 # don't scale time step feature + offset[-1] = 0.0 # don't offset time step feature + while True: + obs = obs.reshape((1, -1)) + obs = np.append(obs, [[step]], axis=1) # add time step feature + unscaled_obs.append(obs) + obs = (obs - offset) * scale # center and scale observations + obs = obs.astype('float32') + observes.append(obs) + + action = self.agent.sample(obs) + + action = action.reshape((1, -1)).astype('float32') + actions.append(action) + + obs, reward, done, _ = self.env.step(np.squeeze(action)) + dones.append(done) + rewards.append(reward) + step += 1e-3 # increment time step feature + + if done: + break + return { + 'obs': np.concatenate(observes), + 'actions': np.concatenate(actions), + 'rewards': np.array(rewards, dtype='float32'), + 'dones': np.array(dones, dtype='float32'), + 'unscaled_obs': np.concatenate(unscaled_obs) + } + + def set_weights(self, params): + self.agent.set_weights(params) diff --git a/examples/PPO/mujoco/mujoco_agent.py b/examples/PPO/mujoco/mujoco_agent.py new file mode 100644 index 000000000..8f24106d3 --- /dev/null +++ b/examples/PPO/mujoco/mujoco_agent.py @@ -0,0 +1,146 @@ +# Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import paddle +import numpy as np + + +class MujocoAgent(parl.Agent): + """ Agent of PPO env + + Args: + algorithm (`parl.Algorithm`): algorithm to be used in this agent. + config (dict): configs that used in this agent + """ + def __init__(self, algorithm, config): + super(MujocoAgent, self).__init__(algorithm) + + self.config = config + self.kl_targ = self.config['kl_targ'] + # Adaptive kl penalty coefficient + self.beta = 1.0 # dynamically adjusted D_KL loss multiplier + self.lr_multiplier = 1.0 # dynamically adjust lr when D_KL out of control + + self.value_learn_buffer = None + + def sample(self, obs): + """ Sample action from current policy given observation + + Args: + obs (np.array): observation, shape([batch_size] + obs_shape) + """ + obs = paddle.to_tensor(obs, dtype='float32') + action = self.alg.sample(obs) + action_numpy = action.detach().numpy()[0] + return action_numpy + + def predict(self, obs): + """ Predict action from current policy given observation + + Args: + obs (np.array): observation, shape([batch_size] + obs_shape) + """ + obs = paddle.to_tensor(obs, dtype='float32') + action = self.alg.predict(obs) + action_numpy = action.detach().numpy()[0] + return action_numpy + + def value(self, obs): + """ use the model to predict obs values + + Args: + obs (torch tensor): observation, shape([batch_size] + obs_shape) + """ + obs = paddle.to_tensor(obs, dtype='float32') + value = self.alg.value(obs) + value = value.detach().numpy() + return value + + def _batch_policy_learn(self, obs, actions, advantages): + obs = paddle.to_tensor(obs) + actions = paddle.to_tensor(actions) + advantages = paddle.to_tensor(advantages) + + loss, kl, entropy = self.alg.policy_learn( + obs, actions, advantages, beta=self.beta, lr_multiplier=self.lr_multiplier) + return loss, kl, entropy + + def _batch_value_learn(self, obs, discount_sum_rewards): + obs = paddle.to_tensor(obs) + discount_sum_rewards = paddle.to_tensor(discount_sum_rewards) + loss = self.alg.value_learn(obs, discount_sum_rewards) + return loss + + def policy_learn(self, obs, actions, advantages): + """ policy learn + """ + self.alg.sync_old_policy() + + all_loss, all_kl = [], [] + for _ in range(self.config['policy_learn_times']): + loss, kl, entropy = self._batch_policy_learn(obs, actions, advantages) + loss, kl, entropy = loss.numpy()[0], kl.numpy()[0], entropy.numpy()[0] + if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly + break + all_loss.append(loss) + all_kl.append(kl) + + if kl > self.kl_targ * 2: # servo beta to reach D_KL target + self.beta = np.minimum(35, 1.5 * self.beta) # max clip beta + if self.beta > 30 and self.lr_multiplier > 0.1: + self.lr_multiplier /= 1.5 + elif kl < self.kl_targ / 2: + self.beta = np.maximum(1 / 35, self.beta / 1.5) # min clip beta + if self.beta < (1 / 30) and self.lr_multiplier < 10: + self.lr_multiplier *= 1.5 + return loss, kl, self.beta, self.lr_multiplier, entropy + + def value_learn(self, 
obs, discount_sum_rewards): + """ value learn + """ + data_size = obs.shape[0] + num_batches = max(data_size // self.config['value_batch_size'], 1) + batch_size = data_size // num_batches + + if self.value_learn_buffer is None: + obs_train, discount_sum_rewards_train = obs, discount_sum_rewards + else: + obs_train = np.concatenate([obs, self.value_learn_buffer[0]]) + discount_sum_rewards_train = np.concatenate([discount_sum_rewards, self.value_learn_buffer[1]]) + self.value_learn_buffer = (obs, discount_sum_rewards) + + all_loss = [] + y_hat = self.alg.model.value(paddle.to_tensor(obs)).numpy().reshape( + [-1]) # check explained variance prior to update + old_exp_var = 1 - np.var(discount_sum_rewards - y_hat) / np.var(discount_sum_rewards) + + for _ in range(self.config['value_learn_times']): + random_ids = np.arange(obs_train.shape[0]) + np.random.shuffle(random_ids) + shuffle_obs_train = obs_train[random_ids] + shuffle_discount_sum_rewards_train = discount_sum_rewards_train[random_ids] + start = 0 + while start < data_size: + end = start + batch_size + loss = self._batch_value_learn(shuffle_obs_train[start:end, :], + shuffle_discount_sum_rewards_train[start:end]) + loss = loss.numpy()[0] + all_loss.append(loss) + start += batch_size + y_hat = self.alg.model.value(paddle.to_tensor(obs)).numpy().reshape( + [-1]) # check explained variance prior to update + value_loss = np.mean(np.square(y_hat - discount_sum_rewards)) # explained variance after update + exp_var = 1 - np.var(discount_sum_rewards - y_hat) / np.var(discount_sum_rewards) + return value_loss, exp_var, old_exp_var diff --git a/examples/PPO/mujoco/mujoco_config.py b/examples/PPO/mujoco/mujoco_config.py new file mode 100644 index 000000000..af5310f9c --- /dev/null +++ b/examples/PPO/mujoco/mujoco_config.py @@ -0,0 +1,40 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
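`MujocoAgent.policy_learn` above implements the adaptive KL penalty: it servos the penalty coefficient beta (and a learning-rate multiplier) toward the `kl_targ` value defined in the `mujoco_config.py` that follows. A condensed, self-contained restatement of that update rule, offered as an illustration only (the function name is made up):

```python
def adapt_kl_penalty(kl, kl_targ, beta, lr_multiplier):
    """Mirrors the beta / lr_multiplier servo in MujocoAgent.policy_learn."""
    if kl > kl_targ * 2:                          # policy moved too far: penalize KL harder
        beta = min(35.0, 1.5 * beta)
        if beta > 30.0 and lr_multiplier > 0.1:
            lr_multiplier /= 1.5                  # and slow the policy optimizer down
    elif kl < kl_targ / 2:                        # policy barely moved: relax the penalty
        beta = max(1.0 / 35.0, beta / 1.5)
        if beta < 1.0 / 30.0 and lr_multiplier < 10.0:
            lr_multiplier *= 1.5
    return beta, lr_multiplier

# With the config default kl_targ = 0.003, an observed kl of 0.01 (> 0.006)
# multiplies beta by 1.5 (up to the cap of 35), so the next update is penalized harder.
```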
+ +mujoco_config = { + ## Commented parameters are set to default values in ppo + + #========== env config ========== + 'env': 'HalfCheetah-v2', # environment name + 'env_num': 5, # number of the environment + 'seed': 120, # seed of the experiment + 'xparl_addr': "localhost:8010", # xparl address for distributed training + + #========== training config ========== + 'train_total_episodes': int(1e6), # max training steps + 'episodes_per_batch': 5, + 'policy_learn_times': 20, # number of epochs for updating (ie K in the paper) + 'value_learn_times': 10, + 'value_batch_size': 256, + 'eval_episode': 3, + 'test_every_episodes': int(5e3), # interval between evaluations + + #========== coefficient of ppo ========== + 'kl_targ': 0.003, # D_KL target value + 'loss_type': 'KLPEN', # Choose loss type of PPO algorithm, 'CLIP' or 'KLPEN' + 'eps': 1e-5, # Adam optimizer epsilon (default: 1e-5) + 'clip_param': 0.2, # epsilon in clipping loss + 'gamma': 0.995, # discounting factor + 'gae_lambda': 0.98, # Lambda parameter for calculating N-step advantage +} diff --git a/examples/PPO/mujoco/mujoco_model.py b/examples/PPO/mujoco/mujoco_model.py new file mode 100644 index 000000000..0a310469e --- /dev/null +++ b/examples/PPO/mujoco/mujoco_model.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import paddle +import paddle.nn as nn +import numpy as np + + +class MujocoModel(parl.Model): + """ The Model for Mujoco env + Args: + obs_dim (int): observation dimension. + act_dim (int): action dimension. + """ + + def __init__(self, obs_dim, act_dim, init_logvar=0.0): + super(MujocoModel, self).__init__() + super(MujocoModel, self).__init__() + self.policy_model = PolicyModel(obs_dim, act_dim, init_logvar) + self.value_model = ValueModel(obs_dim) + self.policy_lr = self.policy_model.lr + self.value_lr = self.value_model.lr + + def value(self, obs): + """ Get value network prediction + Args: + obs (np.array): current observation + """ + return self.value_model.value(obs) + + def policy(self, obs): + """ Get policy network prediction + Args: + obs (np.array): current observation + """ + return self.policy_model.policy(obs) + + +class PolicyModel(parl.Model): + def __init__(self, obs_dim, act_dim, init_logvar): + super(PolicyModel, self).__init__() + self.policy_logvar = -1.0 + self.obs_dim = obs_dim + self.act_dim = act_dim + hid1_size = obs_dim * 10 + hid3_size = act_dim * 10 + hid2_size = int(np.sqrt(hid1_size * hid3_size)) + + self.lr = 9e-4 / np.sqrt(hid2_size) + + self.fc1 = nn.Linear(obs_dim, hid1_size) + self.fc2 = nn.Linear(hid1_size, hid2_size) + self.fc3 = nn.Linear(hid2_size, hid3_size) + self.fc_policy = nn.Linear(hid3_size, act_dim) + + # logvar_speed is used to 'fool' gradient descent into making faster updates to log-variances. + # heuristic sets logvar_speed based on network size. 
+ logvar_speed = (10 * hid3_size) // 48 # default setting + # logvar_speed = (10 * hid3_size) // 8 # finetuned for Humanoid-v2 to achieve fast convergence + self.fc_pi_std = paddle.create_parameter([logvar_speed, act_dim], + dtype='float32', + default_initializer=nn.initializer.Constant(value=init_logvar)) + + def policy(self, obs): + hid1 = paddle.tanh(self.fc1(obs)) + hid2 = paddle.tanh(self.fc2(hid1)) + hid3 = self.fc3(hid2) + means = self.fc_policy(hid3) + logvars = paddle.sum(self.fc_pi_std, axis=0) + self.policy_logvar + return means, logvars + + +class ValueModel(parl.Model): + def __init__(self, obs_dim): + super(ValueModel, self).__init__() + hid1_size = obs_dim * 10 + hid3_size = 5 + hid2_size = int(np.sqrt(hid1_size * hid3_size)) + + self.lr = 1e-2 / np.sqrt(hid2_size) # 1e-3 empirically determined + + self.fc1 = nn.Linear(obs_dim, hid1_size) + self.fc2 = nn.Linear(hid1_size, hid2_size) + self.fc3 = nn.Linear(hid2_size, hid3_size) + self.fc_value = nn.Linear(hid3_size, 1) + + def value(self, obs): + hid1 = paddle.tanh(self.fc1(obs)) + hid2 = paddle.tanh(self.fc2(hid1)) + hid3 = paddle.tanh(self.fc3(hid2)) + value = self.fc_value(hid3) + return value diff --git a/examples/PPO/mujoco/train.py b/examples/PPO/mujoco/train.py new file mode 100644 index 000000000..e7b33f7e2 --- /dev/null +++ b/examples/PPO/mujoco/train.py @@ -0,0 +1,178 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
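The layer sizes and learning rates in `PolicyModel` / `ValueModel` above are derived from the observation and action dimensions rather than fixed. For a feel of the numbers, here is the arithmetic for HalfCheetah-v2 (17-dimensional observation plus the appended time-step feature, 6-dimensional action); the concrete values are an illustration, not part of the diff:

```python
import numpy as np

obs_dim, act_dim = 17 + 1, 6                 # HalfCheetah-v2 obs + time-step feature, action
hid1, hid3 = obs_dim * 10, act_dim * 10      # policy hidden sizes: 180 and 60
hid2 = int(np.sqrt(hid1 * hid3))             # geometric mean: 103
policy_lr = 9e-4 / np.sqrt(hid2)             # ~8.9e-5
logvar_speed = (10 * hid3) // 48             # 12 rows in fc_pi_std
# PolicyModel.policy() sums those 12 rows and adds policy_logvar (-1.0) to get the
# per-action log-variance, so each gradient step moves the log-variance ~12x faster.
```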
+ +import argparse +import gym +import numpy as np + +import parl +from parl.utils import logger, summary +from parl.utils.rl_utils import calc_gae, calc_discount_sum_rewards, Scaler +from parl.env.compat_wrappers import CompatWrapper +from parl.algorithms import PPO_Mujoco +from mujoco_model import MujocoModel +from mujoco_agent import MujocoAgent +from actor import Actor +from mujoco_config import mujoco_config + + +def run_evaluate_episodes(env, agent, scaler, eval_episodes): + eval_episode_rewards = [] + while len(eval_episode_rewards) < eval_episodes: + obs = env.reset() + rewards = 0 + step = 0.0 + scale, offset = scaler.get() + scale[-1] = 1.0 # don't scale time step feature + offset[-1] = 0.0 # don't offset time step feature + while True: + obs = obs.reshape((1, -1)) + obs = np.append(obs, [[step]], axis=1) # add time step feature + obs = (obs - offset) * scale # center and scale observations + obs = obs.astype('float32') + + action = agent.predict(obs) + obs, reward, done, _ = env.step(np.squeeze(action)) + rewards += reward + step += 1e-3 # increment time step feature + + if done: + break + eval_episode_rewards.append(rewards) + return np.mean(eval_episode_rewards) + + +def get_remote_trajectories(actors, scaler): + remote_ids = [actor.run_episode(scaler) for actor in actors] + return_list = [return_.get() for return_ in remote_ids] + + trajectories, all_unscaled_obs = [], [] + for res in return_list: + obs, actions, rewards, dones, unscaled_obs = res['obs'], res['actions'], res['rewards'], res['dones'], res[ + 'unscaled_obs'] + trajectories.append({'obs': obs, 'actions': actions, 'rewards': rewards, 'dones': dones}) + all_unscaled_obs.append(unscaled_obs) + # update running statistics for scaling observations + scaler.update(np.concatenate(all_unscaled_obs)) + return trajectories + + +def build_train_data(config, trajectories, agent): + train_obs, train_actions, train_advantages, train_discount_sum_rewards = [], [], [], [] + for trajectory in trajectories: + pred_values = agent.value(trajectory['obs']).squeeze() + + # scale rewards + scale_rewards = trajectory['rewards'] * (1 - config['gamma']) + if len(scale_rewards) <= 1: + continue + discount_sum_rewards = calc_discount_sum_rewards(scale_rewards, config['gamma']).astype('float32') + advantages = calc_gae(scale_rewards, pred_values, 0, config['gamma'], config['gae_lambda']) + advantages = advantages.astype('float32') + # normalize advantages + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + train_obs.append(trajectory['obs']) + train_actions.append(trajectory['actions']) + train_advantages.append(advantages) + train_discount_sum_rewards.append(discount_sum_rewards) + + train_obs = np.concatenate(train_obs) + train_actions = np.concatenate(train_actions) + train_advantages = np.concatenate(train_advantages) + train_discount_sum_rewards = np.concatenate(train_discount_sum_rewards) + return train_obs, train_actions, train_advantages, train_discount_sum_rewards + + +def main(): + config = mujoco_config + config['env'] = args.env + config['seed'] = args.seed + config['env_num'] = args.env_num + config['test_every_episodes'] = args.test_every_episodes + config['train_total_episodes'] = args.train_total_episodes + config['episodes_per_batch'] = args.episodes_per_batch + + logger.info("------------------- PPO ---------------------") + logger.info('Env: {}, seed: {}'.format(config['env'], config['seed'])) + logger.info("---------------------------------------------") + 
logger.set_dir('./train_logs/{}_{}'.format(config['env'], config['seed'])) + + env = gym.make(args.env) + env = CompatWrapper(env) + env.seed(args.seed) + + obs_dim = env.observation_space.shape[0] + act_dim = env.action_space.shape[0] + obs_dim += 1 # add 1 to obs dim for time step feature + + scaler = Scaler(obs_dim) + model = MujocoModel(obs_dim, act_dim) + alg = PPO_Mujoco(model, act_dim=act_dim) + agent = MujocoAgent(alg, config) + + parl.connect(config['xparl_addr']) + actors = [Actor(config) for _ in range(config["env_num"])] + # run a few episodes to initialize scaler + get_remote_trajectories(actors, scaler) + + test_flag = 0 + episode = 0 + while episode < config['train_total_episodes']: + latest_params = agent.get_weights() + # setting the actor to the latest_params + for remote_actor in actors: + remote_actor.set_weights(latest_params) + + trajectories = [] + while len(trajectories) < config['episodes_per_batch']: + trajectories.extend(get_remote_trajectories(actors, scaler)) + episode += len(trajectories) + + train_obs, train_actions, train_advantages, train_discount_sum_rewards = build_train_data( + config, trajectories, agent) + + policy_loss, kl, beta, lr_multiplier, entropy = agent.policy_learn(train_obs, train_actions, train_advantages) + value_loss, exp_var, old_exp_var = agent.value_learn(train_obs, train_discount_sum_rewards) + + total_train_rewards = sum([np.sum(t['rewards']) for t in trajectories]) + logger.info('Training: Episode {}, Avg train reward: {}, Policy loss: {}, KL: {}, Value loss: {}'.format( + episode, total_train_rewards / len(trajectories), policy_loss, kl, value_loss)) + summary.add_scalar("train/episode_mean_reward", total_train_rewards / len(trajectories), episode) + + if episode // config['test_every_episodes'] >= test_flag: + while episode // config['test_every_episodes'] >= test_flag: + test_flag += 1 + + avg_reward = run_evaluate_episodes(env, agent, scaler, config['eval_episode']) + summary.add_scalar('eval/episode_reward', avg_reward, episode) + logger.info('Evaluation over: {} episodes, Reward: {}'.format(config['eval_episode'], avg_reward)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--env', type=str, help='Mujoco environment name', default='Swimmer-v2') + parser.add_argument( + "--env_num", type=int, default=5, help="number of the environment, xparl is needed") + parser.add_argument('--episodes_per_batch', type=int, default=5, help='Number of episodes per training batch') + parser.add_argument('--train_total_episodes', type=int, default=int(100), help='maximum training steps') + parser.add_argument( + '--test_every_episodes', + type=int, + default=int(50), + help='the step interval between two consecutive evaluations') + parser.add_argument('--seed', type=int, default=1, help='the step interval between two consecutive evaluations') + args = parser.parse_args() + + main() diff --git a/examples/PPO/mujoco_config.py b/examples/PPO/mujoco_config.py deleted file mode 100644 index 89eec5ab1..000000000 --- a/examples/PPO/mujoco_config.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -mujoco_config = { - ## Commented parameters are set to default values in ppo - - #========== env config ========== - 'env': 'HalfCheetah-v2', # environment name - 'continuous_action': True, # action type of the environment - 'env_num': 1, # number of the environment - 'seed': None, # seed of the experiment - 'xparl_addr': None, # xparl address for distributed training - - #========== training config ========== - 'train_total_steps': int(1e6), # max training steps - 'step_nums': 2048, # data collecting time steps (ie. T in the paper) - 'num_minibatches': 32, # number of training minibatches per update. - 'update_epochs': 10, # number of epochs for updating (ie K in the paper) - 'eval_episode': 3, - 'test_every_steps': int(5e3), # interval between evaluations - - #========== coefficient of ppo ========== - 'initial_lr': 3e-4, # start learning rate - 'lr_decay': True, # whether or not to use linear decay rl - # 'eps': 1e-5, # Adam optimizer epsilon (default: 1e-5) - 'clip_param': 0.2, # epsilon in clipping loss - 'entropy_coef': 0.0, # Entropy coefficient (ie. c_2 in the paper) - # 'value_loss_coef': 0.5, # Value loss coefficient (ie. c_1 in the paper) - # 'max_grad_norm': 0.5, # Max gradient norm for gradient clipping - # 'use_clipped_value_loss': True, # advantages normalization - # 'clip_vloss': True, # whether or not to use a clipped loss for the value function - # 'gamma': 0.99, # discounting factor - # 'gae': True, # whether or not to use GAE - # 'gae_lambda': 0.95, # Lambda parameter for calculating N-step advantage -} diff --git a/examples/PPO/mujoco_model.py b/examples/PPO/mujoco_model.py deleted file mode 100644 index a45f086e8..000000000 --- a/examples/PPO/mujoco_model.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import parl -import paddle -import paddle.nn as nn -import numpy as np - - -class MujocoModel(parl.Model): - """ The Model for Mujoco env - Args: - obs_space (Box): observation space. - act_space (Box): action space. 
- """ - - def __init__(self, obs_space, act_space): - super(MujocoModel, self).__init__() - - self.fc1 = nn.Linear(obs_space.shape[0], 64) - self.fc2 = nn.Linear(64, 64) - - self.fc_value = nn.Linear(64, 1) - - self.fc_policy = nn.Linear(64, np.prod(act_space.shape)) - self.fc_pi_std = paddle.static.create_parameter( - [1, np.prod(act_space.shape)], - dtype='float32', - default_initializer=nn.initializer.Constant(value=0)) - - def value(self, obs): - """ Get value network prediction - Args: - obs (np.array): current observation - """ - out = paddle.tanh(self.fc1(obs)) - out = paddle.tanh(self.fc2(out)) - value = self.fc_value(out) - return value - - def policy(self, obs): - """ Get policy network prediction - Args: - obs (np.array): current observation - """ - out = paddle.tanh(self.fc1(obs)) - out = paddle.tanh(self.fc2(out)) - action_mean = self.fc_policy(out) - - action_logstd = self.fc_pi_std - action_std = paddle.exp(action_logstd) - return action_mean, action_std diff --git a/examples/PPO/requirements_atari.txt b/examples/PPO/requirements_atari.txt index 20d2b18c1..164608db5 100644 --- a/examples/PPO/requirements_atari.txt +++ b/examples/PPO/requirements_atari.txt @@ -1,5 +1,5 @@ gym==0.18.0 paddlepaddle>=2.0.0 -parl>=2.1.1 +parl>=2.2.2 atari-py==0.2.6 opencv-python diff --git a/examples/PPO/requirements_mujoco.txt b/examples/PPO/requirements_mujoco.txt index 5ec7789bb..c25354382 100644 --- a/examples/PPO/requirements_mujoco.txt +++ b/examples/PPO/requirements_mujoco.txt @@ -1,4 +1,5 @@ -gym>=0.26.0 +gym==0.18.0 mujoco==2.2.2 +mujoco-py==2.1.2.14 paddlepaddle>=2.0.0 -parl>=2.1.1 +parl>=2.2.2 diff --git a/parl/algorithms/paddle/__init__.py b/parl/algorithms/paddle/__init__.py index 7a60022fb..457ff5dcc 100644 --- a/parl/algorithms/paddle/__init__.py +++ b/parl/algorithms/paddle/__init__.py @@ -22,6 +22,7 @@ from parl.algorithms.paddle.a2c import * from parl.algorithms.paddle.ddqn import * from parl.algorithms.paddle.maddpg import * -from parl.algorithms.paddle.ppo import * +from parl.algorithms.paddle.ppo_atari import * +from parl.algorithms.paddle.ppo_mujoco import * from parl.algorithms.paddle.cql import * from parl.algorithms.paddle.impala.impala import * diff --git a/parl/algorithms/paddle/ppo.py b/parl/algorithms/paddle/ppo_atari.py similarity index 67% rename from parl/algorithms/paddle/ppo.py rename to parl/algorithms/paddle/ppo_atari.py index 1b54a294b..9146853a1 100644 --- a/parl/algorithms/paddle/ppo.py +++ b/parl/algorithms/paddle/ppo_atari.py @@ -17,13 +17,13 @@ import paddle.nn as nn import paddle.nn.functional as F import paddle.optimizer as optim -from paddle.distribution import Normal, Categorical +from paddle.distribution import Categorical from parl.utils.utils import check_model_method -__all__ = ['PPO'] +__all__ = ['PPO_Atari'] -class PPO(parl.Algorithm): +class PPO_Atari(parl.Algorithm): def __init__(self, model, clip_param=0.1, @@ -33,9 +33,8 @@ def __init__(self, eps=1e-5, max_grad_norm=0.5, use_clipped_value_loss=True, - norm_adv=True, - continuous_action=False): - """ PPO algorithm + norm_adv=True): + """ PPO algorithm for Atari Args: model (parl.Model): forward network of actor and critic. @@ -47,7 +46,6 @@ def __init__(self, max_grad_norm (float): max gradient norm for gradient clipping. use_clipped_value_loss (bool): whether or not to use a clipped loss for the value function. norm_adv (bool): whether or not to use advantages normalization. - continuous_action (bool): whether or not is continuous action environment. 
""" # check model methods check_model_method(model, 'value', self.__class__.__name__) @@ -61,7 +59,6 @@ def __init__(self, assert isinstance(max_grad_norm, float) assert isinstance(use_clipped_value_loss, bool) assert isinstance(norm_adv, bool) - assert isinstance(continuous_action, bool) self.clip_param = clip_param self.value_loss_coef = value_loss_coef @@ -69,24 +66,13 @@ def __init__(self, self.max_grad_norm = max_grad_norm self.use_clipped_value_loss = use_clipped_value_loss self.norm_adv = norm_adv - self.continuous_action = continuous_action self.model = model clip = nn.ClipGradByNorm(self.max_grad_norm) self.optimizer = optim.Adam( - parameters=self.model.parameters(), - learning_rate=initial_lr, - epsilon=eps, - grad_clip=clip) - - def learn(self, - batch_obs, - batch_action, - batch_value, - batch_return, - batch_logprob, - batch_adv, - lr=None): + parameters=self.model.parameters(), learning_rate=initial_lr, epsilon=eps, grad_clip=clip) + + def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logprob, batch_adv, lr=None): """ update model with PPO algorithm Args: @@ -103,44 +89,34 @@ def learn(self, entropy_loss (float): entropy loss """ values = self.model.value(batch_obs) - if self.continuous_action: - mean, std = self.model.policy(batch_obs) - dist = Normal(mean, std) - action_log_probs = dist.log_prob(batch_action).sum(1) - dist_entropy = dist.entropy().sum(1) - else: - logits = self.model.policy(batch_obs) - dist = Categorical(logits=logits) - act_dim = logits.shape[-1] - batch_action = paddle.to_tensor(batch_action, dtype='int64') - actions_onehot = F.one_hot(batch_action, act_dim) + logits = self.model.policy(batch_obs) + dist = Categorical(logits=logits) + + act_dim = logits.shape[-1] + batch_action = paddle.to_tensor(batch_action, dtype='int64') + actions_onehot = F.one_hot(batch_action, act_dim) - action_log_probs = paddle.sum( - F.log_softmax(logits) * actions_onehot, axis=-1) - dist_entropy = dist.entropy() + action_log_probs = paddle.sum(F.log_softmax(logits) * actions_onehot, axis=-1) + dist_entropy = dist.entropy() entropy_loss = dist_entropy.mean() batch_adv = batch_adv if self.norm_adv: - batch_adv = (batch_adv - batch_adv.mean()) / ( - batch_adv.std() + 1e-8) + batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8) ratio = paddle.exp(action_log_probs - batch_logprob) surr1 = ratio * batch_adv - surr2 = paddle.clip(ratio, 1.0 - self.clip_param, - 1.0 + self.clip_param) * batch_adv + surr2 = paddle.clip(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch_adv action_loss = -paddle.minimum(surr1, surr2).mean() values = values.reshape([-1]) # calculate value loss using semi gradient TD if self.use_clipped_value_loss: - value_pred_clipped = batch_value + paddle.clip( - values - batch_value, -self.clip_param, self.clip_param) + value_pred_clipped = batch_value + paddle.clip(values - batch_value, -self.clip_param, self.clip_param) value_losses = (values - batch_return).pow(2) value_losses_clipped = (value_pred_clipped - batch_return).pow(2) - value_loss = 0.5 * paddle.maximum(value_losses, - value_losses_clipped).mean() + value_loss = 0.5 * paddle.maximum(value_losses, value_losses_clipped).mean() else: value_loss = 0.5 * (values - batch_return).pow(2).mean() @@ -168,23 +144,14 @@ def sample(self, obs): """ value = self.model.value(obs) - if self.continuous_action: - mean, std = self.model.policy(obs) - dist = Normal(mean, std) - action = dist.sample([1]) + logits = self.model.policy(obs) + dist = Categorical(logits=logits) 
+ action = dist.sample([1]) - action_log_probs = dist.log_prob(action).sum(-1) - action_entropy = dist.entropy().sum(-1).mean() - else: - logits = self.model.policy(obs) - dist = Categorical(logits=logits) - action = dist.sample([1]) - - act_dim = logits.shape[-1] - actions_onehot = F.one_hot(action, act_dim) - action_log_probs = paddle.sum( - F.log_softmax(logits) * actions_onehot, axis=-1) - action_entropy = dist.entropy() + act_dim = logits.shape[-1] + actions_onehot = F.one_hot(action, act_dim) + action_log_probs = paddle.sum(F.log_softmax(logits) * actions_onehot, axis=-1) + action_entropy = dist.entropy() return value, action, action_log_probs, action_entropy @@ -197,12 +164,9 @@ def predict(self, obs): action (torch tensor): action, shape([batch_size] + action_shape), noted that in the discrete case we take the argmax along the last axis as action """ - if self.continuous_action: - action, _ = self.model.policy(obs) - else: - logits = self.model.policy(obs) - probs = F.softmax(logits) - action = paddle.argmax(probs, 1) + logits = self.model.policy(obs) + probs = F.softmax(logits) + action = paddle.argmax(probs, 1) return action def value(self, obs): diff --git a/parl/algorithms/paddle/ppo_mujoco.py b/parl/algorithms/paddle/ppo_mujoco.py new file mode 100644 index 000000000..f4b004010 --- /dev/null +++ b/parl/algorithms/paddle/ppo_mujoco.py @@ -0,0 +1,191 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import paddle +import numpy as np +from copy import deepcopy +from parl.utils.utils import check_model_method + +__all__ = ['PPO_Mujoco'] + + +class PPO_Mujoco(parl.Algorithm): + def __init__(self, model, act_dim=None, loss_type='KLPEN', kl_targ=0.003, eta=50, clip_param=0.2, eps=1e-5): + """ PPO algorithm for Mujoco + + Args: + model (parl.Model): model defining forward network of policy and value. + act_dim (float): dimension of the action space. + loss_type (string): loss type of PPO algorithm, 'CLIP' or 'KLPEN'". + kl_targ (float): D_KL target value. + eta (float): multiplier for D_KL-kl_targ hinge-squared loss. + clip_param (float): epsilon used in the CLIP loss. + eps (float): A small float value for numerical stability. 
+ """ + # check model methods and member variables + check_model_method(model, 'value', self.__class__.__name__) + check_model_method(model, 'policy', self.__class__.__name__) + assert hasattr(model, 'policy_model') + assert hasattr(model, 'value_model') + assert hasattr(model, 'policy_lr') + assert hasattr(model, 'value_lr') + + assert isinstance(act_dim, int) + assert isinstance(kl_targ, float) + assert isinstance(eta, float) + assert isinstance(eps, float) + assert isinstance(clip_param, float) + assert loss_type == 'CLIP' or loss_type == 'KLPEN' + self.loss_type = loss_type + self.act_dim = act_dim + self.clip_param = clip_param + self.eta = eta + self.kl_targ = kl_targ + + self.model = model + # Used to calculate probability of action in old policy + self.old_policy_model = deepcopy(model.policy_model) + + self.policy_lr = self.model.policy_lr + self.value_lr = self.model.value_lr + self.policy_optimizer = paddle.optimizer.Adam( + parameters=self.model.policy_model.parameters(), learning_rate=self.policy_lr, epsilon=eps) + self.value_optimizer = paddle.optimizer.Adam( + parameters=self.model.value_model.parameters(), learning_rate=self.value_lr, epsilon=eps) + + def _calc_logprob(self, actions, means, logvars): + """ Calculate log probabilities of actions, when given means and logvars + of normal distribution. + The constant sqrt(2 * pi) is omitted, which will be eliminated in later. + + Args: + actions: shape (batch_size, act_dim) + means: shape (batch_size, act_dim) + logvars: shape (act_dim) + + Returns: + logprob: shape (batch_size) + """ + logp = -0.5 * paddle.sum(logvars) + logp += -0.5 * paddle.sum((paddle.square(actions - means) / paddle.exp(logvars)), axis=1) + logprob = logp + return logprob + + def _calc_kl(self, means, logvars, old_means, old_logvars): + """ Calculate KL divergence between old and new distributions + See: https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback.E2.80.93Leibler_divergence + + Args: + means: shape (batch_size, act_dim) + logvars: shape (act_dim) + old_means: shape (batch_size, act_dim) + old_logvars: shape (act_dim) + + Returns: + kl: shape (batch_size) + entropy + """ + log_det_cov_old = paddle.sum(old_logvars) + log_det_cov_new = paddle.sum(logvars) + tr_old_new = paddle.sum(paddle.exp(old_logvars - logvars)) + kl = 0.5 * paddle.mean( + paddle.sum(paddle.square(means - old_means) / paddle.exp(logvars), axis=1) + + (log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim) + + entropy = 0.5 * (self.act_dim * (np.log(2 * np.pi) + 1) + paddle.sum(logvars)) + + return kl, entropy + + def value(self, obs): + """ Use value model of self.model to predict value of obs + """ + return self.model.value(obs) + + def predict(self, obs): + """ Use the policy model of self.model to predict means and logvars of actions + """ + means, logvars = self.model.policy(obs) + return means + + def sample(self, obs): + """ Use the policy model of self.model to sample actions + """ + means, logvars = self.model.policy(obs) + sampled_act = means + ( + paddle.exp(logvars / 2.0) * # stddev + paddle.standard_normal(shape=(self.act_dim, ), dtype='float32')) + return sampled_act + + def policy_learn(self, batch_obs, batch_action, batch_adv, beta, lr_multiplier): + """ Learn policy model with: + 1. CLIP loss: Clipped Surrogate Objective + 2. 
KLPEN loss: Adaptive KL Penalty Objective + See: https://arxiv.org/pdf/1707.02286.pdf + + Args: + batch_obs: Tensor, (batch_size, obs_dim) + batch_action: Tensor, (batch_size, act_dim) + batch_adv: Tensor (batch_size, ) + beta: Tensor (1) or None. If None, use CLIP Loss; else, use KLPEN loss. + lr_multiplier: Tensor (1) + """ + old_means, old_logvars = self.old_policy_model.policy(batch_obs) + old_means.stop_gradient = True + old_logvars.stop_gradient = True + + old_logprob = self._calc_logprob(batch_action, old_means, old_logvars) + old_logprob.stop_gradient = True + + means, logvars = self.model.policy(batch_obs) + logprob = self._calc_logprob(batch_action, means, logvars) + kl, entropy = self._calc_kl(means, logvars, old_means, old_logvars) + + if self.loss_type == "KLPEN": + loss1 = -(batch_adv * paddle.exp(logprob - old_logprob)).mean() + loss2 = (kl * beta).mean() + loss3 = self.eta * paddle.square(paddle.maximum(paddle.to_tensor(0.0), kl - 2.0 * self.kl_targ)) + loss = loss1 + loss2 + loss3 + elif self.loss_type == "CLIP": + ratio = paddle.exp(logprob - old_logprob) + surr1 = ratio * batch_adv + surr2 = paddle.clip(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch_adv + loss = -paddle.minimum(surr1, surr2).mean() + else: + raise ValueError("Policy loss type error, 'CLIP' or 'KLPEN'") + + self.policy_optimizer.set_lr(self.policy_lr * lr_multiplier) + + self.policy_optimizer.clear_grad() + loss.backward() + self.policy_optimizer.step() + return loss, kl, entropy + + def value_learn(self, batch_obs, batch_return): + """ Learn the value model with square error cost + """ + predict_val = self.model.value(batch_obs) + predict_val = predict_val.reshape([-1]) + + loss = (predict_val - batch_return).pow(2).mean() + + self.value_optimizer.clear_grad() + loss.backward() + self.value_optimizer.step() + return loss + + def sync_old_policy(self): + """ Synchronize weights of self.model.policy_model to self.old_policy_model + """ + self.model.policy_model.sync_weights_to(self.old_policy_model) diff --git a/parl/utils/rl_utils.py b/parl/utils/rl_utils.py index 602d12c8e..e015fd577 100644 --- a/parl/utils/rl_utils.py +++ b/parl/utils/rl_utils.py @@ -15,7 +15,7 @@ import numpy as np import scipy.signal -__all__ = ['calc_discount_sum_rewards', 'calc_gae'] +__all__ = ['calc_discount_sum_rewards', 'calc_gae', "Scaler"] def calc_discount_sum_rewards(rewards, gamma): @@ -49,3 +49,50 @@ def calc_gae(rewards, values, next_value, gamma, lam): tds = rewards + gamma * np.append(values[1:], next_value) - values advantages = calc_discount_sum_rewards(tds, gamma * lam) return advantages + + +class Scaler(object): + """ Generate scale and offset based on running mean and stddev along axis=0 + + offset = running mean + scale = 1 / (stddev + 0.1) / 3 (i.e. 
3x stddev = +/- 1.0) + """ + + def __init__(self, obs_dim): + """ + Args: + obs_dim: dimension of axis=1 + """ + self.vars = np.zeros(obs_dim) + self.means = np.zeros(obs_dim) + self.cnt = 0 + self.first_pass = True + + def update(self, x): + """ Update running mean and variance (this is an exact method) + Args: + x: NumPy array, shape = (N, obs_dim) + + see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled- + variance-of-two-groups-given-known-group-variances-mean + """ + if self.first_pass: + self.means = np.mean(x, axis=0) + self.vars = np.var(x, axis=0) + self.cnt = x.shape[0] + self.first_pass = False + else: + n = x.shape[0] + new_data_var = np.var(x, axis=0) + new_data_mean = np.mean(x, axis=0) + new_data_mean_sq = np.square(new_data_mean) + new_means = ((self.means * self.cnt) + (new_data_mean * n)) / (self.cnt + n) + self.vars = (((self.cnt * (self.vars + np.square(self.means))) + + (n * (new_data_var + new_data_mean_sq))) / (self.cnt + n) - np.square(new_means)) + self.vars = np.maximum(0.0, self.vars) # occasionally goes negative, clip + self.means = new_means + self.cnt += n + + def get(self): + """ returns 2-tuple: (scale, offset) """ + return 1 / (np.sqrt(self.vars) + 0.1) / 3, self.means
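A tiny usage sketch of the `Scaler` added above, mirroring how `actor.py` and `run_evaluate_episodes` apply it (the random data is a placeholder only):

```python
import numpy as np
from parl.utils.rl_utils import Scaler

scaler = Scaler(obs_dim=3)
scaler.update(np.random.randn(100, 3))        # first pass: plain batch mean / variance
scaler.update(np.random.randn(50, 3) + 1.0)   # later passes: pooled mean / variance over all 150 rows
scale, offset = scaler.get()                  # offset = running mean, scale = 1 / (stddev + 0.1) / 3
obs = np.random.randn(1, 3)
normalized_obs = (obs - offset) * scale       # same centering / scaling done before agent.sample()
```

In the MuJoCo example the last element of `scale` and `offset` is then pinned to 1.0 and 0.0 so that the appended time-step feature is never rescaled.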