diff --git a/.teamcity/build.sh b/.teamcity/build.sh
index 35d295d01..125399945 100755
--- a/.teamcity/build.sh
+++ b/.teamcity/build.sh
@@ -49,12 +49,14 @@ function run_example_test {
python -m pip uninstall -r ./examples/DQN_variant/requirements.txt -y
python -m pip install -r ./examples/PPO/requirements_atari.txt
- python examples/PPO/train.py --train_total_steps 5000 --env PongNoFrameskip-v4
+ python examples/PPO/atari/train.py --train_total_steps 5000 --env PongNoFrameskip-v4
python -m pip uninstall -r ./examples/PPO/requirements_atari.txt -y
+ xparl start --port 8010 --cpu_num 8
python -m pip install -r ./examples/PPO/requirements_mujoco.txt
- python examples/PPO/train.py --train_total_steps 5000 --env HalfCheetah-v4 --continuous_action
+ python examples/PPO/mujoco/train.py --env 'HalfCheetah-v2' --train_total_episodes 100 --env_num 5
python -m pip uninstall -r ./examples/PPO/requirements_mujoco.txt -y
+ xparl stop
python -m pip install -r ./examples/SAC/requirements.txt
python examples/SAC/train.py --train_total_steps 5000 --env HalfCheetah-v4
diff --git a/examples/PPO/README.md b/examples/PPO/README.md
index dd8d88faa..131a35494 100644
--- a/examples/PPO/README.md
+++ b/examples/PPO/README.md
@@ -4,15 +4,17 @@ Based on PARL, the PPO algorithm of deep reinforcement learning has been reprodu
> Paper: PPO in [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)
### Mujoco/Atari games introduction
-PARL currently supports the open-source version of Mujoco provided by DeepMind, so users do not need to download binaries of Mujoco as well as install mujoco-py and get license. For more details, please visit [Mujoco](https://github.com/deepmind/mujoco).
+PARL currently supports the open-source version of Mujoco provided by DeepMind, so users do not need to download binaries of Mujoco or install [mujoco-py](https://github.com/openai/mujoco-py#install-mujoco). For more details, please visit [Mujoco](https://github.com/deepmind/mujoco).
### Benchmark result
#### 1. Mujoco games results
+The horizontal axis represents the number of episodes.
#### 2. Atari games results
+The horizontal axis represents the number of steps.
@@ -23,29 +25,21 @@ PARL currently supports the open-source version of Mujoco provided by DeepMind,
### Mujoco-Dependencies:
+ python3.7+
+ [paddle>=2.3.1](https://github.com/PaddlePaddle/Paddle)
-+ [parl>=2.1.1](https://github.com/PaddlePaddle/PARL)
-+ gym>=0.26.0
++ [parl>=2.2.2](https://github.com/PaddlePaddle/PARL)
++ gym==0.18.0
+ mujoco>=2.2.2
++ mujoco-py==2.1.2.14
### Atari-Dependencies:
+ [paddle>=2.3.1](https://github.com/PaddlePaddle/Paddle)
-+ [parl>=2.1.1](https://github.com/PaddlePaddle/PARL)
++ [parl>=2.2.2](https://github.com/PaddlePaddle/PARL)
+ gym==0.18.0
+ atari-py==0.2.6
+ opencv-python
-### Training:
-
-```
-# To train an agent for discrete action game (Atari: PongNoFrameskip-v4 by default)
-python train.py
-
-# To train an agent for continuous action game (Mujoco)
-python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000
-```
-### Distributed Training
-Accelerate training process by setting `xparl_addr` and `env_num > 1` when environment simulation running very slow.
+### Distributed Training for Mujoco
+Accelerate the training process by setting `xparl_addr` and `env_num > 1` when the environment simulation runs very slowly.
At first, we can start a local cluster with 8 CPUs:
```
@@ -56,14 +50,25 @@ Note that if you have started a master before, you don't have to run the above
command. For more information about the cluster, please refer to our
[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html).
-Then we can start the distributed training by running:
+Then we can start the distributed training for Mujoco games by running:
```
-# To train an agent distributedly
+cd mujoco
-# for discrete action game (Atari games)
-python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010'
+python train.py --env 'HalfCheetah-v2' --train_total_episodes 100000 --env_num 5
+```
-# for continuous action game (Mujoco games)
-python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000 --env_num 5 --xparl_addr 'localhost:8010'
+
+### Training Atari
+To train an agent for a discrete-action game (Atari: PongNoFrameskip-v4 by default):
+
+```
+cd atari
+
+# Local training
+python train.py
+
+# Distributed training
+xparl start --port 8010 --cpu_num 8
+python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010'
```
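
For readers unfamiliar with xparl, below is a minimal sketch of the remote-actor pattern that the distributed training above relies on. It is illustrative only: `ToyActor` is a hypothetical stand-in for the `Actor` class added under `examples/PPO/mujoco/actor.py`, and it assumes a cluster is already running at `localhost:8010`.

```
# Minimal sketch of the xparl remote-actor pattern (assumes `xparl start --port 8010 --cpu_num 8`).
# `ToyActor` is a hypothetical stand-in for the Actor class in examples/PPO/mujoco/actor.py.
import parl


@parl.remote_class(wait=False)
class ToyActor(object):
    def run_episode(self):
        # The real Actor rolls out one episode in a gym env and returns the trajectory.
        return {'rewards': [1.0, 0.5, 2.0]}


if __name__ == '__main__':
    parl.connect('localhost:8010')
    actors = [ToyActor() for _ in range(5)]                # env_num = 5
    futures = [actor.run_episode() for actor in actors]    # non-blocking remote calls
    trajectories = [future.get() for future in futures]    # wait for and gather results
    print(len(trajectories))
```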
diff --git a/examples/PPO/agent.py b/examples/PPO/atari/atari_agent.py
similarity index 90%
rename from examples/PPO/agent.py
rename to examples/PPO/atari/atari_agent.py
index 980b6a7d1..cb89e87d1 100644
--- a/examples/PPO/agent.py
+++ b/examples/PPO/atari/atari_agent.py
@@ -18,7 +18,7 @@
from parl.utils.scheduler import LinearDecayScheduler
-class PPOAgent(parl.Agent):
+class AtariAgent(parl.Agent):
""" Agent of PPO env
Args:
@@ -27,12 +27,11 @@ class PPOAgent(parl.Agent):
"""
def __init__(self, algorithm, config):
- super(PPOAgent, self).__init__(algorithm)
+ super(AtariAgent, self).__init__(algorithm)
self.config = config
if self.config['lr_decay']:
- self.lr_scheduler = LinearDecayScheduler(
- self.config['initial_lr'], self.config['num_updates'])
+ self.lr_scheduler = LinearDecayScheduler(self.config['initial_lr'], self.config['num_updates'])
def predict(self, obs):
""" Predict action from current policy given observation
@@ -85,8 +84,7 @@ def learn(self, rollout):
else:
lr = None
- minibatch_size = int(
- self.config['batch_size'] // self.config['num_minibatches'])
+ minibatch_size = int(self.config['batch_size'] // self.config['num_minibatches'])
indexes = np.arange(self.config['batch_size'])
for epoch in range(self.config['update_epochs']):
@@ -105,9 +103,8 @@ def learn(self, rollout):
batch_return = paddle.to_tensor(batch_return)
batch_value = paddle.to_tensor(batch_value)
- value_loss, action_loss, entropy_loss = self.alg.learn(
- batch_obs, batch_action, batch_value, batch_return,
- batch_logprob, batch_adv, lr)
+ value_loss, action_loss, entropy_loss = self.alg.learn(batch_obs, batch_action, batch_value,
+ batch_return, batch_logprob, batch_adv, lr)
value_loss_epoch += value_loss
action_loss_epoch += action_loss
diff --git a/examples/PPO/atari_config.py b/examples/PPO/atari/atari_config.py
similarity index 100%
rename from examples/PPO/atari_config.py
rename to examples/PPO/atari/atari_config.py
diff --git a/examples/PPO/atari_model.py b/examples/PPO/atari/atari_model.py
similarity index 100%
rename from examples/PPO/atari_model.py
rename to examples/PPO/atari/atari_model.py
diff --git a/examples/PPO/env_utils.py b/examples/PPO/atari/env_utils.py
similarity index 59%
rename from examples/PPO/env_utils.py
rename to examples/PPO/atari/env_utils.py
index 410a6f059..dd394e76d 100644
--- a/examples/PPO/env_utils.py
+++ b/examples/PPO/atari/env_utils.py
@@ -16,13 +16,11 @@
import gym
import numpy as np
from parl.utils import logger
+from parl.env.atari_wrappers import wrap_deepmind
-TEST_EPISODE = 3
# wrapper parameters for atari env
ENV_DIM = 84
OBS_FORMAT = 'NCHW'
-# wrapper parameters for mujoco env
-GAMMA = 0.99
class ParallelEnv(object):
@@ -39,14 +37,9 @@ def __init__(self, config=None):
base_env = LocalEnv
if config['seed']:
- self.env_list = [
- base_env(config['env'], config['seed'] + i)
- for i in range(self.env_num)
- ]
+ self.env_list = [base_env(config['env'], config['seed'] + i) for i in range(self.env_num)]
else:
- self.env_list = [
- base_env(config['env']) for _ in range(self.env_num)
- ]
+ self.env_list = [base_env(config['env']) for _ in range(self.env_num)]
if hasattr(self.env_list[0], '_max_episode_steps'):
self._max_episode_steps = self.env_list[0]._max_episode_steps
else:
@@ -68,10 +61,7 @@ def reset(self):
def step(self, action_list):
next_obs_list, reward_list, done_list, info_list = [], [], [], []
if self.use_xparl:
- return_list = [
- self.env_list[i].step(action_list[i])
- for i in range(self.env_num)
- ]
+ return_list = [self.env_list[i].step(action_list[i]) for i in range(self.env_num)]
return_list = [return_.get() for return_ in return_list]
return_list = np.array(return_list, dtype=object)
@@ -89,8 +79,7 @@ def step(self, action_list):
done = done_[i]
info = info_[i]
else:
- next_obs, reward, done, info = self.env_list[i].step(
- action_list[i])
+ next_obs, reward, done, info = self.env_list[i].step(action_list[i])
self.episode_steps_list[i] += 1
self.episode_reward_list[i] += reward
@@ -104,49 +93,26 @@ def step(self, action_list):
next_obs = self.env_list[i].reset()
self.episode_steps_list[i] = 0
self.episode_reward_list[i] = 0
- if self.env_list[i].continuous_action:
- # get running mean and variance of obs
- self.eval_ob_rms = self.env_list[i].env.get_ob_rms()
next_obs_list.append(next_obs)
reward_list.append(reward)
done_list.append(done)
info_list.append(info)
- return np.array(next_obs_list), np.array(reward_list), np.array(
- done_list), np.array(info_list)
+ return np.array(next_obs_list), np.array(reward_list), np.array(done_list), np.array(info_list)
class LocalEnv(object):
def __init__(self, env_name, env_seed=None, test=False, ob_rms=None):
env = gym.make(env_name)
- # is instance of gym.spaces.Box
- if hasattr(env.action_space, 'high'):
- from parl.env.mujoco_wrappers import wrap_rms
- self._max_episode_steps = env._max_episode_steps
- self.continuous_action = True
- if test:
- self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms)
- else:
- self.env = wrap_rms(env, gamma=GAMMA)
# is instance of gym.spaces.Discrete
- elif hasattr(env.action_space, 'n'):
- from parl.env.atari_wrappers import wrap_deepmind
- self.continuous_action = False
+ if hasattr(env.action_space, 'n'):
if test:
- self.env = wrap_deepmind(
- env,
- dim=ENV_DIM,
- obs_format=OBS_FORMAT,
- test=True,
- test_episodes=1)
+ self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1)
else:
- self.env = wrap_deepmind(
- env, dim=ENV_DIM, obs_format=OBS_FORMAT)
+ self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT)
else:
- raise AssertionError(
- 'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete'
- )
+ raise AssertionError('act_space must be instance of gym.spaces.Discrete')
self.obs_space = self.env.observation_space
self.act_space = self.env.action_space
@@ -166,31 +132,13 @@ class RemoteEnv(object):
def __init__(self, env_name, env_seed=None, test=False, ob_rms=None):
env = gym.make(env_name)
- if hasattr(env.action_space, 'high'):
- from parl.env.mujoco_wrappers import wrap_rms
- self._max_episode_steps = env._max_episode_steps
- self.continuous_action = True
- if test:
- self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms)
- else:
- self.env = wrap_rms(env, gamma=GAMMA)
- elif hasattr(env.action_space, 'n'):
- from parl.env.atari_wrappers import wrap_deepmind
- self.continuous_action = False
+ if hasattr(env.action_space, 'n'):
if test:
- self.env = wrap_deepmind(
- env,
- dim=ENV_DIM,
- obs_format=OBS_FORMAT,
- test=True,
- test_episodes=1)
+ self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1)
else:
- self.env = wrap_deepmind(
- env, dim=ENV_DIM, obs_format=OBS_FORMAT)
+ self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT)
else:
- raise AssertionError(
- 'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete'
- )
+ raise AssertionError('act_space must be instance of gym.spaces.Discrete')
if env_seed:
self.env.seed(env_seed)
@@ -201,6 +149,4 @@ def step(self, action):
return self.env.step(action)
def render(self):
- return logger.warning(
- 'Can not render in remote environment, render() have been skipped.'
- )
+ return logger.warning('Cannot render in remote environment, render() has been skipped.')
diff --git a/examples/PPO/storage.py b/examples/PPO/atari/storage.py
similarity index 84%
rename from examples/PPO/storage.py
rename to examples/PPO/atari/storage.py
index 059beb1ec..99fd9bfe4 100644
--- a/examples/PPO/storage.py
+++ b/examples/PPO/atari/storage.py
@@ -17,10 +17,8 @@
class RolloutStorage():
def __init__(self, step_nums, env_num, obs_space, act_space):
- self.obs = np.zeros(
- (step_nums, env_num) + obs_space.shape, dtype='float32')
- self.actions = np.zeros(
- (step_nums, env_num) + act_space.shape, dtype='float32')
+ self.obs = np.zeros((step_nums, env_num) + obs_space.shape, dtype='float32')
+ self.actions = np.zeros((step_nums, env_num) + act_space.shape, dtype='float32')
self.logprobs = np.zeros((step_nums, env_num), dtype='float32')
self.rewards = np.zeros((step_nums, env_num), dtype='float32')
self.dones = np.zeros((step_nums, env_num), dtype='float32')
@@ -54,10 +52,8 @@ def compute_returns(self, value, done, gamma=0.99, gae_lambda=0.95):
else:
nextnonterminal = 1.0 - self.dones[t + 1]
nextvalues = self.values[t + 1]
- delta = self.rewards[
- t] + gamma * nextvalues * nextnonterminal - self.values[t]
- advantages[
- t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
+ delta = self.rewards[t] + gamma * nextvalues * nextnonterminal - self.values[t]
+ advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
returns = advantages + self.values
self.returns = returns
self.advantages = advantages
@@ -72,5 +68,4 @@ def sample_batch(self, idx):
b_returns = self.returns.reshape(-1)
b_values = self.values.reshape(-1)
- return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[
- idx], b_returns[idx], b_values[idx]
+ return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[idx], b_returns[idx], b_values[idx]
diff --git a/examples/PPO/train.py b/examples/PPO/atari/train.py
similarity index 60%
rename from examples/PPO/train.py
rename to examples/PPO/atari/train.py
index 52732672b..0bcb9dc06 100644
--- a/examples/PPO/train.py
+++ b/examples/PPO/atari/train.py
@@ -16,14 +16,12 @@
import numpy as np
from parl.utils import logger, summary
-from mujoco_config import mujoco_config
from atari_config import atari_config
from env_utils import ParallelEnv, LocalEnv
from storage import RolloutStorage
from atari_model import AtariModel
-from mujoco_model import MujocoModel
-from parl.algorithms import PPO
-from agent import PPOAgent
+from parl.algorithms import PPO_Atari
+from atari_agent import AtariAgent
# Runs policy until 'real done' and returns episode reward
@@ -43,7 +41,7 @@ def run_evaluate_episodes(agent, eval_env, eval_episodes):
def main():
- config = mujoco_config if args.continuous_action else atari_config
+ config = atari_config
if args.env_num:
config['env_num'] = args.env_num
config['env'] = args.env
@@ -53,8 +51,7 @@ def main():
config['train_total_steps'] = args.train_total_steps
config['batch_size'] = int(config['env_num'] * config['step_nums'])
- config['num_updates'] = int(
- config['train_total_steps'] // config['batch_size'])
+ config['num_updates'] = int(config['train_total_steps'] // config['batch_size'])
logger.info("------------------- PPO ---------------------")
logger.info('Env: {}, seed: {}'.format(config['env'], config['seed']))
@@ -67,20 +64,12 @@ def main():
obs_space = eval_env.obs_space
act_space = eval_env.act_space
- if config['continuous_action']:
- model = MujocoModel(obs_space, act_space)
- else:
- model = AtariModel(obs_space, act_space)
- ppo = PPO(
- model,
- clip_param=config['clip_param'],
- entropy_coef=config['entropy_coef'],
- initial_lr=config['initial_lr'],
- continuous_action=config['continuous_action'])
- agent = PPOAgent(ppo, config)
-
- rollout = RolloutStorage(config['step_nums'], config['env_num'], obs_space,
- act_space)
+ model = AtariModel(obs_space, act_space)
+ ppo = PPO_Atari(
+ model, clip_param=config['clip_param'], entropy_coef=config['entropy_coef'], initial_lr=config['initial_lr'])
+ agent = AtariAgent(ppo, config)
+
+ rollout = RolloutStorage(config['step_nums'], config['env_num'], obs_space, act_space)
obs = envs.reset()
done = np.zeros(config['env_num'], dtype='float32')
@@ -98,11 +87,9 @@ def main():
for k in range(config['env_num']):
if done[k] and "episode" in info[k].keys():
- logger.info(
- "Training: total steps: {}, episode rewards: {}".
- format(total_steps, info[k]['episode']['r']))
- summary.add_scalar("train/episode_reward",
- info[k]["episode"]["r"], total_steps)
+ logger.info("Training: total steps: {}, episode rewards: {}".format(
+ total_steps, info[k]['episode']['r']))
+ summary.add_scalar("train/episode_reward", info[k]["episode"]["r"], total_steps)
# Bootstrap value if not done
value = agent.value(obs)
@@ -115,53 +102,22 @@ def main():
while (total_steps + 1) // config['test_every_steps'] >= test_flag:
test_flag += 1
- if config['continuous_action']:
- # set running mean and variance of obs
- ob_rms = envs.eval_ob_rms
- eval_env.env.set_ob_rms(ob_rms)
-
- avg_reward = run_evaluate_episodes(agent, eval_env,
- config['eval_episode'])
+ avg_reward = run_evaluate_episodes(agent, eval_env, config['eval_episode'])
summary.add_scalar('eval/episode_reward', avg_reward, total_steps)
- logger.info('Evaluation over: {} episodes, Reward: {}'.format(
- config['eval_episode'], avg_reward))
+ logger.info('Evaluation over: {} episodes, Reward: {}'.format(config['eval_episode'], avg_reward))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
+ parser.add_argument("--env", type=str, default="PongNoFrameskip-v4", help="OpenAI gym environment name")
+ parser.add_argument("--seed", type=int, default=None, help="seed of the experiment")
parser.add_argument(
- "--env",
- type=str,
- default="PongNoFrameskip-v4",
- help="OpenAI gym environment name")
- parser.add_argument(
- "--seed", type=int, default=None, help="seed of the experiment")
- parser.add_argument(
- "--env_num",
- type=int,
- default=None,
- help=
- "number of the environment. Note: if greater than 1, xparl is needed")
- parser.add_argument(
- '--continuous_action',
- action='store_true',
- default=False,
- help='action type of the environment')
- parser.add_argument(
- "--xparl_addr",
- type=str,
- default=None,
- help="xparl address for distributed training ")
+ "--env_num", type=int, default=None, help="number of the environment. Note: if greater than 1, xparl is needed")
+ parser.add_argument("--xparl_addr", type=str, default=None, help="xparl address for distributed training ")
parser.add_argument(
- '--train_total_steps',
- type=int,
- default=10e6,
- help='number of total time steps to train (default: 10e6)')
+ '--train_total_steps', type=int, default=10e6, help='number of total time steps to train (default: 10e6)')
parser.add_argument(
- '--test_every_steps',
- type=int,
- default=int(5e3),
- help='the step interval between two consecutive evaluations')
+ '--test_every_steps', type=int, default=int(5e3), help='the step interval between two consecutive evaluations')
args = parser.parse_args()
main()
diff --git a/examples/PPO/mujoco/actor.py b/examples/PPO/mujoco/actor.py
new file mode 100644
index 000000000..1224bd398
--- /dev/null
+++ b/examples/PPO/mujoco/actor.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gym
+import numpy as np
+import parl
+from mujoco_model import MujocoModel
+from mujoco_agent import MujocoAgent
+from parl.algorithms import PPO_Mujoco
+from parl.env.compat_wrappers import CompatWrapper
+
+
+@parl.remote_class(wait=False)
+class Actor(object):
+ def __init__(self, config, seed=None):
+ env = gym.make(config['env'])
+ self.env = CompatWrapper(env)
+ self.env.seed(seed)
+
+ obs_dim = self.env.observation_space.shape[0]
+ act_dim = self.env.action_space.shape[0]
+ obs_dim += 1 # add 1 to obs dim for time step feature
+
+ model = MujocoModel(obs_dim, act_dim)
+ alg = PPO_Mujoco(model, act_dim=act_dim)
+ self.agent = MujocoAgent(alg, config)
+
+ def run_episode(self, scaler):
+ obs = self.env.reset()
+ observes, actions, rewards, unscaled_obs = [], [], [], []
+ dones = []
+ step = 0.0
+ scale, offset = scaler.get()
+ scale[-1] = 1.0 # don't scale time step feature
+ offset[-1] = 0.0 # don't offset time step feature
+ while True:
+ obs = obs.reshape((1, -1))
+ obs = np.append(obs, [[step]], axis=1) # add time step feature
+ unscaled_obs.append(obs)
+ obs = (obs - offset) * scale # center and scale observations
+ obs = obs.astype('float32')
+ observes.append(obs)
+
+ action = self.agent.sample(obs)
+
+ action = action.reshape((1, -1)).astype('float32')
+ actions.append(action)
+
+ obs, reward, done, _ = self.env.step(np.squeeze(action))
+ dones.append(done)
+ rewards.append(reward)
+ step += 1e-3 # increment time step feature
+
+ if done:
+ break
+ return {
+ 'obs': np.concatenate(observes),
+ 'actions': np.concatenate(actions),
+ 'rewards': np.array(rewards, dtype='float32'),
+ 'dones': np.array(dones, dtype='float32'),
+ 'unscaled_obs': np.concatenate(unscaled_obs)
+ }
+
+ def set_weights(self, params):
+ self.agent.set_weights(params)
diff --git a/examples/PPO/mujoco/mujoco_agent.py b/examples/PPO/mujoco/mujoco_agent.py
new file mode 100644
index 000000000..8f24106d3
--- /dev/null
+++ b/examples/PPO/mujoco/mujoco_agent.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import paddle
+import numpy as np
+
+
+class MujocoAgent(parl.Agent):
+ """ Agent of PPO env
+
+ Args:
+ algorithm (`parl.Algorithm`): algorithm to be used in this agent.
+ config (dict): configs used in this agent
+ """
+ def __init__(self, algorithm, config):
+ super(MujocoAgent, self).__init__(algorithm)
+
+ self.config = config
+ self.kl_targ = self.config['kl_targ']
+ # Adaptive kl penalty coefficient
+ self.beta = 1.0 # dynamically adjusted D_KL loss multiplier
+ self.lr_multiplier = 1.0 # dynamically adjust lr when D_KL out of control
+
+ self.value_learn_buffer = None
+
+ def sample(self, obs):
+ """ Sample action from current policy given observation
+
+ Args:
+ obs (np.array): observation, shape([batch_size] + obs_shape)
+ """
+ obs = paddle.to_tensor(obs, dtype='float32')
+ action = self.alg.sample(obs)
+ action_numpy = action.detach().numpy()[0]
+ return action_numpy
+
+ def predict(self, obs):
+ """ Predict action from current policy given observation
+
+ Args:
+ obs (np.array): observation, shape([batch_size] + obs_shape)
+ """
+ obs = paddle.to_tensor(obs, dtype='float32')
+ action = self.alg.predict(obs)
+ action_numpy = action.detach().numpy()[0]
+ return action_numpy
+
+ def value(self, obs):
+ """ use the model to predict obs values
+
+ Args:
+ obs (np.array): observation, shape([batch_size] + obs_shape)
+ """
+ obs = paddle.to_tensor(obs, dtype='float32')
+ value = self.alg.value(obs)
+ value = value.detach().numpy()
+ return value
+
+ def _batch_policy_learn(self, obs, actions, advantages):
+ obs = paddle.to_tensor(obs)
+ actions = paddle.to_tensor(actions)
+ advantages = paddle.to_tensor(advantages)
+
+ loss, kl, entropy = self.alg.policy_learn(
+ obs, actions, advantages, beta=self.beta, lr_multiplier=self.lr_multiplier)
+ return loss, kl, entropy
+
+ def _batch_value_learn(self, obs, discount_sum_rewards):
+ obs = paddle.to_tensor(obs)
+ discount_sum_rewards = paddle.to_tensor(discount_sum_rewards)
+ loss = self.alg.value_learn(obs, discount_sum_rewards)
+ return loss
+
+ def policy_learn(self, obs, actions, advantages):
+ """ policy learn
+ """
+ self.alg.sync_old_policy()
+
+ all_loss, all_kl = [], []
+ for _ in range(self.config['policy_learn_times']):
+ loss, kl, entropy = self._batch_policy_learn(obs, actions, advantages)
+ loss, kl, entropy = loss.numpy()[0], kl.numpy()[0], entropy.numpy()[0]
+ if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly
+ break
+ all_loss.append(loss)
+ all_kl.append(kl)
+
+ if kl > self.kl_targ * 2: # servo beta to reach D_KL target
+ self.beta = np.minimum(35, 1.5 * self.beta) # max clip beta
+ if self.beta > 30 and self.lr_multiplier > 0.1:
+ self.lr_multiplier /= 1.5
+ elif kl < self.kl_targ / 2:
+ self.beta = np.maximum(1 / 35, self.beta / 1.5) # min clip beta
+ if self.beta < (1 / 30) and self.lr_multiplier < 10:
+ self.lr_multiplier *= 1.5
+ return loss, kl, self.beta, self.lr_multiplier, entropy
+
+ def value_learn(self, obs, discount_sum_rewards):
+ """ value learn
+ """
+ data_size = obs.shape[0]
+ num_batches = max(data_size // self.config['value_batch_size'], 1)
+ batch_size = data_size // num_batches
+
+ if self.value_learn_buffer is None:
+ obs_train, discount_sum_rewards_train = obs, discount_sum_rewards
+ else:
+ obs_train = np.concatenate([obs, self.value_learn_buffer[0]])
+ discount_sum_rewards_train = np.concatenate([discount_sum_rewards, self.value_learn_buffer[1]])
+ self.value_learn_buffer = (obs, discount_sum_rewards)
+
+ all_loss = []
+ y_hat = self.alg.model.value(paddle.to_tensor(obs)).numpy().reshape(
+ [-1]) # check explained variance prior to update
+ old_exp_var = 1 - np.var(discount_sum_rewards - y_hat) / np.var(discount_sum_rewards)
+
+ for _ in range(self.config['value_learn_times']):
+ random_ids = np.arange(obs_train.shape[0])
+ np.random.shuffle(random_ids)
+ shuffle_obs_train = obs_train[random_ids]
+ shuffle_discount_sum_rewards_train = discount_sum_rewards_train[random_ids]
+ start = 0
+ while start < data_size:
+ end = start + batch_size
+ loss = self._batch_value_learn(shuffle_obs_train[start:end, :],
+ shuffle_discount_sum_rewards_train[start:end])
+ loss = loss.numpy()[0]
+ all_loss.append(loss)
+ start += batch_size
+ y_hat = self.alg.model.value(paddle.to_tensor(obs)).numpy().reshape(
+ [-1]) # recompute value predictions after the update
+ value_loss = np.mean(np.square(y_hat - discount_sum_rewards)) # value loss after the update
+ exp_var = 1 - np.var(discount_sum_rewards - y_hat) / np.var(discount_sum_rewards)
+ return value_loss, exp_var, old_exp_var
diff --git a/examples/PPO/mujoco/mujoco_config.py b/examples/PPO/mujoco/mujoco_config.py
new file mode 100644
index 000000000..af5310f9c
--- /dev/null
+++ b/examples/PPO/mujoco/mujoco_config.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mujoco_config = {
+ ## Commented parameters are set to default values in ppo
+
+ #========== env config ==========
+ 'env': 'HalfCheetah-v2', # environment name
+ 'env_num': 5, # number of the environment
+ 'seed': 120, # seed of the experiment
+ 'xparl_addr': "localhost:8010", # xparl address for distributed training
+
+ #========== training config ==========
+ 'train_total_episodes': int(1e6), # max training episodes
+ 'episodes_per_batch': 5,
+ 'policy_learn_times': 20, # number of epochs for updating (ie K in the paper)
+ 'value_learn_times': 10,
+ 'value_batch_size': 256,
+ 'eval_episode': 3,
+ 'test_every_episodes': int(5e3), # interval between evaluations
+
+ #========== coefficient of ppo ==========
+ 'kl_targ': 0.003, # D_KL target value
+ 'loss_type': 'KLPEN', # Choose loss type of PPO algorithm, 'CLIP' or 'KLPEN'
+ 'eps': 1e-5, # Adam optimizer epsilon (default: 1e-5)
+ 'clip_param': 0.2, # epsilon in clipping loss
+ 'gamma': 0.995, # discounting factor
+ 'gae_lambda': 0.98, # Lambda parameter for calculating N-step advantage
+}
diff --git a/examples/PPO/mujoco/mujoco_model.py b/examples/PPO/mujoco/mujoco_model.py
new file mode 100644
index 000000000..0a310469e
--- /dev/null
+++ b/examples/PPO/mujoco/mujoco_model.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import paddle
+import paddle.nn as nn
+import numpy as np
+
+
+class MujocoModel(parl.Model):
+ """ The Model for Mujoco env
+ Args:
+ obs_dim (int): observation dimension.
+ act_dim (int): action dimension.
+ """
+
+ def __init__(self, obs_dim, act_dim, init_logvar=0.0):
+ super(MujocoModel, self).__init__()
+ self.policy_model = PolicyModel(obs_dim, act_dim, init_logvar)
+ self.value_model = ValueModel(obs_dim)
+ self.policy_lr = self.policy_model.lr
+ self.value_lr = self.value_model.lr
+
+ def value(self, obs):
+ """ Get value network prediction
+ Args:
+ obs (np.array): current observation
+ """
+ return self.value_model.value(obs)
+
+ def policy(self, obs):
+ """ Get policy network prediction
+ Args:
+ obs (np.array): current observation
+ """
+ return self.policy_model.policy(obs)
+
+
+class PolicyModel(parl.Model):
+ def __init__(self, obs_dim, act_dim, init_logvar):
+ super(PolicyModel, self).__init__()
+ self.policy_logvar = -1.0
+ self.obs_dim = obs_dim
+ self.act_dim = act_dim
+ hid1_size = obs_dim * 10
+ hid3_size = act_dim * 10
+ hid2_size = int(np.sqrt(hid1_size * hid3_size))
+
+ self.lr = 9e-4 / np.sqrt(hid2_size)
+
+ self.fc1 = nn.Linear(obs_dim, hid1_size)
+ self.fc2 = nn.Linear(hid1_size, hid2_size)
+ self.fc3 = nn.Linear(hid2_size, hid3_size)
+ self.fc_policy = nn.Linear(hid3_size, act_dim)
+
+ # logvar_speed is used to 'fool' gradient descent into making faster updates to log-variances.
+ # heuristic sets logvar_speed based on network size.
+ logvar_speed = (10 * hid3_size) // 48 # default setting
+ # logvar_speed = (10 * hid3_size) // 8 # finetuned for Humanoid-v2 to achieve fast convergence
+ self.fc_pi_std = paddle.create_parameter([logvar_speed, act_dim],
+ dtype='float32',
+ default_initializer=nn.initializer.Constant(value=init_logvar))
+
+ def policy(self, obs):
+ hid1 = paddle.tanh(self.fc1(obs))
+ hid2 = paddle.tanh(self.fc2(hid1))
+ hid3 = self.fc3(hid2)
+ means = self.fc_policy(hid3)
+ logvars = paddle.sum(self.fc_pi_std, axis=0) + self.policy_logvar
+ return means, logvars
+
+
+class ValueModel(parl.Model):
+ def __init__(self, obs_dim):
+ super(ValueModel, self).__init__()
+ hid1_size = obs_dim * 10
+ hid3_size = 5
+ hid2_size = int(np.sqrt(hid1_size * hid3_size))
+
+ self.lr = 1e-2 / np.sqrt(hid2_size) # empirically determined
+
+ self.fc1 = nn.Linear(obs_dim, hid1_size)
+ self.fc2 = nn.Linear(hid1_size, hid2_size)
+ self.fc3 = nn.Linear(hid2_size, hid3_size)
+ self.fc_value = nn.Linear(hid3_size, 1)
+
+ def value(self, obs):
+ hid1 = paddle.tanh(self.fc1(obs))
+ hid2 = paddle.tanh(self.fc2(hid1))
+ hid3 = paddle.tanh(self.fc3(hid2))
+ value = self.fc_value(hid3)
+ return value
diff --git a/examples/PPO/mujoco/train.py b/examples/PPO/mujoco/train.py
new file mode 100644
index 000000000..e7b33f7e2
--- /dev/null
+++ b/examples/PPO/mujoco/train.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gym
+import numpy as np
+
+import parl
+from parl.utils import logger, summary
+from parl.utils.rl_utils import calc_gae, calc_discount_sum_rewards, Scaler
+from parl.env.compat_wrappers import CompatWrapper
+from parl.algorithms import PPO_Mujoco
+from mujoco_model import MujocoModel
+from mujoco_agent import MujocoAgent
+from actor import Actor
+from mujoco_config import mujoco_config
+
+
+def run_evaluate_episodes(env, agent, scaler, eval_episodes):
+ eval_episode_rewards = []
+ while len(eval_episode_rewards) < eval_episodes:
+ obs = env.reset()
+ rewards = 0
+ step = 0.0
+ scale, offset = scaler.get()
+ scale[-1] = 1.0 # don't scale time step feature
+ offset[-1] = 0.0 # don't offset time step feature
+ while True:
+ obs = obs.reshape((1, -1))
+ obs = np.append(obs, [[step]], axis=1) # add time step feature
+ obs = (obs - offset) * scale # center and scale observations
+ obs = obs.astype('float32')
+
+ action = agent.predict(obs)
+ obs, reward, done, _ = env.step(np.squeeze(action))
+ rewards += reward
+ step += 1e-3 # increment time step feature
+
+ if done:
+ break
+ eval_episode_rewards.append(rewards)
+ return np.mean(eval_episode_rewards)
+
+
+def get_remote_trajectories(actors, scaler):
+ remote_ids = [actor.run_episode(scaler) for actor in actors]
+ return_list = [return_.get() for return_ in remote_ids]
+
+ trajectories, all_unscaled_obs = [], []
+ for res in return_list:
+ obs, actions, rewards, dones, unscaled_obs = res['obs'], res['actions'], res['rewards'], res['dones'], res[
+ 'unscaled_obs']
+ trajectories.append({'obs': obs, 'actions': actions, 'rewards': rewards, 'dones': dones})
+ all_unscaled_obs.append(unscaled_obs)
+ # update running statistics for scaling observations
+ scaler.update(np.concatenate(all_unscaled_obs))
+ return trajectories
+
+
+def build_train_data(config, trajectories, agent):
+ train_obs, train_actions, train_advantages, train_discount_sum_rewards = [], [], [], []
+ for trajectory in trajectories:
+ pred_values = agent.value(trajectory['obs']).squeeze()
+
+ # scale rewards
+ scale_rewards = trajectory['rewards'] * (1 - config['gamma'])
+ if len(scale_rewards) <= 1:
+ continue
+ discount_sum_rewards = calc_discount_sum_rewards(scale_rewards, config['gamma']).astype('float32')
+ advantages = calc_gae(scale_rewards, pred_values, 0, config['gamma'], config['gae_lambda'])
+ advantages = advantages.astype('float32')
+ # normalize advantages
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
+ train_obs.append(trajectory['obs'])
+ train_actions.append(trajectory['actions'])
+ train_advantages.append(advantages)
+ train_discount_sum_rewards.append(discount_sum_rewards)
+
+ train_obs = np.concatenate(train_obs)
+ train_actions = np.concatenate(train_actions)
+ train_advantages = np.concatenate(train_advantages)
+ train_discount_sum_rewards = np.concatenate(train_discount_sum_rewards)
+ return train_obs, train_actions, train_advantages, train_discount_sum_rewards
+
+
+def main():
+ config = mujoco_config
+ config['env'] = args.env
+ config['seed'] = args.seed
+ config['env_num'] = args.env_num
+ config['test_every_episodes'] = args.test_every_episodes
+ config['train_total_episodes'] = args.train_total_episodes
+ config['episodes_per_batch'] = args.episodes_per_batch
+
+ logger.info("------------------- PPO ---------------------")
+ logger.info('Env: {}, seed: {}'.format(config['env'], config['seed']))
+ logger.info("---------------------------------------------")
+ logger.set_dir('./train_logs/{}_{}'.format(config['env'], config['seed']))
+
+ env = gym.make(args.env)
+ env = CompatWrapper(env)
+ env.seed(args.seed)
+
+ obs_dim = env.observation_space.shape[0]
+ act_dim = env.action_space.shape[0]
+ obs_dim += 1 # add 1 to obs dim for time step feature
+
+ scaler = Scaler(obs_dim)
+ model = MujocoModel(obs_dim, act_dim)
+ alg = PPO_Mujoco(model, act_dim=act_dim)
+ agent = MujocoAgent(alg, config)
+
+ parl.connect(config['xparl_addr'])
+ actors = [Actor(config) for _ in range(config["env_num"])]
+ # run a few episodes to initialize scaler
+ get_remote_trajectories(actors, scaler)
+
+ test_flag = 0
+ episode = 0
+ while episode < config['train_total_episodes']:
+ latest_params = agent.get_weights()
+ # setting the actor to the latest_params
+ for remote_actor in actors:
+ remote_actor.set_weights(latest_params)
+
+ trajectories = []
+ while len(trajectories) < config['episodes_per_batch']:
+ trajectories.extend(get_remote_trajectories(actors, scaler))
+ episode += len(trajectories)
+
+ train_obs, train_actions, train_advantages, train_discount_sum_rewards = build_train_data(
+ config, trajectories, agent)
+
+ policy_loss, kl, beta, lr_multiplier, entropy = agent.policy_learn(train_obs, train_actions, train_advantages)
+ value_loss, exp_var, old_exp_var = agent.value_learn(train_obs, train_discount_sum_rewards)
+
+ total_train_rewards = sum([np.sum(t['rewards']) for t in trajectories])
+ logger.info('Training: Episode {}, Avg train reward: {}, Policy loss: {}, KL: {}, Value loss: {}'.format(
+ episode, total_train_rewards / len(trajectories), policy_loss, kl, value_loss))
+ summary.add_scalar("train/episode_mean_reward", total_train_rewards / len(trajectories), episode)
+
+ if episode // config['test_every_episodes'] >= test_flag:
+ while episode // config['test_every_episodes'] >= test_flag:
+ test_flag += 1
+
+ avg_reward = run_evaluate_episodes(env, agent, scaler, config['eval_episode'])
+ summary.add_scalar('eval/episode_reward', avg_reward, episode)
+ logger.info('Evaluation over: {} episodes, Reward: {}'.format(config['eval_episode'], avg_reward))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--env', type=str, help='Mujoco environment name', default='Swimmer-v2')
+ parser.add_argument(
+ "--env_num", type=int, default=5, help="number of the environment, xparl is needed")
+ parser.add_argument('--episodes_per_batch', type=int, default=5, help='Number of episodes per training batch')
+ parser.add_argument('--train_total_episodes', type=int, default=int(100), help='maximum training episodes')
+ parser.add_argument(
+ '--test_every_episodes',
+ type=int,
+ default=int(50),
+ help='the step interval between two consecutive evaluations')
+ parser.add_argument('--seed', type=int, default=1, help='seed of the experiment')
+ args = parser.parse_args()
+
+ main()
diff --git a/examples/PPO/mujoco_config.py b/examples/PPO/mujoco_config.py
deleted file mode 100644
index 89eec5ab1..000000000
--- a/examples/PPO/mujoco_config.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-mujoco_config = {
- ## Commented parameters are set to default values in ppo
-
- #========== env config ==========
- 'env': 'HalfCheetah-v2', # environment name
- 'continuous_action': True, # action type of the environment
- 'env_num': 1, # number of the environment
- 'seed': None, # seed of the experiment
- 'xparl_addr': None, # xparl address for distributed training
-
- #========== training config ==========
- 'train_total_steps': int(1e6), # max training steps
- 'step_nums': 2048, # data collecting time steps (ie. T in the paper)
- 'num_minibatches': 32, # number of training minibatches per update.
- 'update_epochs': 10, # number of epochs for updating (ie K in the paper)
- 'eval_episode': 3,
- 'test_every_steps': int(5e3), # interval between evaluations
-
- #========== coefficient of ppo ==========
- 'initial_lr': 3e-4, # start learning rate
- 'lr_decay': True, # whether or not to use linear decay rl
- # 'eps': 1e-5, # Adam optimizer epsilon (default: 1e-5)
- 'clip_param': 0.2, # epsilon in clipping loss
- 'entropy_coef': 0.0, # Entropy coefficient (ie. c_2 in the paper)
- # 'value_loss_coef': 0.5, # Value loss coefficient (ie. c_1 in the paper)
- # 'max_grad_norm': 0.5, # Max gradient norm for gradient clipping
- # 'use_clipped_value_loss': True, # advantages normalization
- # 'clip_vloss': True, # whether or not to use a clipped loss for the value function
- # 'gamma': 0.99, # discounting factor
- # 'gae': True, # whether or not to use GAE
- # 'gae_lambda': 0.95, # Lambda parameter for calculating N-step advantage
-}
diff --git a/examples/PPO/mujoco_model.py b/examples/PPO/mujoco_model.py
deleted file mode 100644
index a45f086e8..000000000
--- a/examples/PPO/mujoco_model.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import parl
-import paddle
-import paddle.nn as nn
-import numpy as np
-
-
-class MujocoModel(parl.Model):
- """ The Model for Mujoco env
- Args:
- obs_space (Box): observation space.
- act_space (Box): action space.
- """
-
- def __init__(self, obs_space, act_space):
- super(MujocoModel, self).__init__()
-
- self.fc1 = nn.Linear(obs_space.shape[0], 64)
- self.fc2 = nn.Linear(64, 64)
-
- self.fc_value = nn.Linear(64, 1)
-
- self.fc_policy = nn.Linear(64, np.prod(act_space.shape))
- self.fc_pi_std = paddle.static.create_parameter(
- [1, np.prod(act_space.shape)],
- dtype='float32',
- default_initializer=nn.initializer.Constant(value=0))
-
- def value(self, obs):
- """ Get value network prediction
- Args:
- obs (np.array): current observation
- """
- out = paddle.tanh(self.fc1(obs))
- out = paddle.tanh(self.fc2(out))
- value = self.fc_value(out)
- return value
-
- def policy(self, obs):
- """ Get policy network prediction
- Args:
- obs (np.array): current observation
- """
- out = paddle.tanh(self.fc1(obs))
- out = paddle.tanh(self.fc2(out))
- action_mean = self.fc_policy(out)
-
- action_logstd = self.fc_pi_std
- action_std = paddle.exp(action_logstd)
- return action_mean, action_std
diff --git a/examples/PPO/requirements_atari.txt b/examples/PPO/requirements_atari.txt
index 20d2b18c1..164608db5 100644
--- a/examples/PPO/requirements_atari.txt
+++ b/examples/PPO/requirements_atari.txt
@@ -1,5 +1,5 @@
gym==0.18.0
paddlepaddle>=2.0.0
-parl>=2.1.1
+parl>=2.2.2
atari-py==0.2.6
opencv-python
diff --git a/examples/PPO/requirements_mujoco.txt b/examples/PPO/requirements_mujoco.txt
index 5ec7789bb..c25354382 100644
--- a/examples/PPO/requirements_mujoco.txt
+++ b/examples/PPO/requirements_mujoco.txt
@@ -1,4 +1,5 @@
-gym>=0.26.0
+gym==0.18.0
mujoco==2.2.2
+mujoco-py==2.1.2.14
paddlepaddle>=2.0.0
-parl>=2.1.1
+parl>=2.2.2
diff --git a/parl/algorithms/paddle/__init__.py b/parl/algorithms/paddle/__init__.py
index 7a60022fb..457ff5dcc 100644
--- a/parl/algorithms/paddle/__init__.py
+++ b/parl/algorithms/paddle/__init__.py
@@ -22,6 +22,7 @@
from parl.algorithms.paddle.a2c import *
from parl.algorithms.paddle.ddqn import *
from parl.algorithms.paddle.maddpg import *
-from parl.algorithms.paddle.ppo import *
+from parl.algorithms.paddle.ppo_atari import *
+from parl.algorithms.paddle.ppo_mujoco import *
from parl.algorithms.paddle.cql import *
from parl.algorithms.paddle.impala.impala import *
diff --git a/parl/algorithms/paddle/ppo.py b/parl/algorithms/paddle/ppo_atari.py
similarity index 67%
rename from parl/algorithms/paddle/ppo.py
rename to parl/algorithms/paddle/ppo_atari.py
index 1b54a294b..9146853a1 100644
--- a/parl/algorithms/paddle/ppo.py
+++ b/parl/algorithms/paddle/ppo_atari.py
@@ -17,13 +17,13 @@
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.optimizer as optim
-from paddle.distribution import Normal, Categorical
+from paddle.distribution import Categorical
from parl.utils.utils import check_model_method
-__all__ = ['PPO']
+__all__ = ['PPO_Atari']
-class PPO(parl.Algorithm):
+class PPO_Atari(parl.Algorithm):
def __init__(self,
model,
clip_param=0.1,
@@ -33,9 +33,8 @@ def __init__(self,
eps=1e-5,
max_grad_norm=0.5,
use_clipped_value_loss=True,
- norm_adv=True,
- continuous_action=False):
- """ PPO algorithm
+ norm_adv=True):
+ """ PPO algorithm for Atari
Args:
model (parl.Model): forward network of actor and critic.
@@ -47,7 +46,6 @@ def __init__(self,
max_grad_norm (float): max gradient norm for gradient clipping.
use_clipped_value_loss (bool): whether or not to use a clipped loss for the value function.
norm_adv (bool): whether or not to use advantages normalization.
- continuous_action (bool): whether or not is continuous action environment.
"""
# check model methods
check_model_method(model, 'value', self.__class__.__name__)
@@ -61,7 +59,6 @@ def __init__(self,
assert isinstance(max_grad_norm, float)
assert isinstance(use_clipped_value_loss, bool)
assert isinstance(norm_adv, bool)
- assert isinstance(continuous_action, bool)
self.clip_param = clip_param
self.value_loss_coef = value_loss_coef
@@ -69,24 +66,13 @@ def __init__(self,
self.max_grad_norm = max_grad_norm
self.use_clipped_value_loss = use_clipped_value_loss
self.norm_adv = norm_adv
- self.continuous_action = continuous_action
self.model = model
clip = nn.ClipGradByNorm(self.max_grad_norm)
self.optimizer = optim.Adam(
- parameters=self.model.parameters(),
- learning_rate=initial_lr,
- epsilon=eps,
- grad_clip=clip)
-
- def learn(self,
- batch_obs,
- batch_action,
- batch_value,
- batch_return,
- batch_logprob,
- batch_adv,
- lr=None):
+ parameters=self.model.parameters(), learning_rate=initial_lr, epsilon=eps, grad_clip=clip)
+
+ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logprob, batch_adv, lr=None):
""" update model with PPO algorithm
Args:
@@ -103,44 +89,34 @@ def learn(self,
entropy_loss (float): entropy loss
"""
values = self.model.value(batch_obs)
- if self.continuous_action:
- mean, std = self.model.policy(batch_obs)
- dist = Normal(mean, std)
- action_log_probs = dist.log_prob(batch_action).sum(1)
- dist_entropy = dist.entropy().sum(1)
- else:
- logits = self.model.policy(batch_obs)
- dist = Categorical(logits=logits)
- act_dim = logits.shape[-1]
- batch_action = paddle.to_tensor(batch_action, dtype='int64')
- actions_onehot = F.one_hot(batch_action, act_dim)
+ logits = self.model.policy(batch_obs)
+ dist = Categorical(logits=logits)
+
+ act_dim = logits.shape[-1]
+ batch_action = paddle.to_tensor(batch_action, dtype='int64')
+ actions_onehot = F.one_hot(batch_action, act_dim)
- action_log_probs = paddle.sum(
- F.log_softmax(logits) * actions_onehot, axis=-1)
- dist_entropy = dist.entropy()
+ action_log_probs = paddle.sum(F.log_softmax(logits) * actions_onehot, axis=-1)
+ dist_entropy = dist.entropy()
entropy_loss = dist_entropy.mean()
batch_adv = batch_adv
if self.norm_adv:
- batch_adv = (batch_adv - batch_adv.mean()) / (
- batch_adv.std() + 1e-8)
+ batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8)
ratio = paddle.exp(action_log_probs - batch_logprob)
surr1 = ratio * batch_adv
- surr2 = paddle.clip(ratio, 1.0 - self.clip_param,
- 1.0 + self.clip_param) * batch_adv
+ surr2 = paddle.clip(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch_adv
action_loss = -paddle.minimum(surr1, surr2).mean()
values = values.reshape([-1])
# calculate value loss using semi gradient TD
if self.use_clipped_value_loss:
- value_pred_clipped = batch_value + paddle.clip(
- values - batch_value, -self.clip_param, self.clip_param)
+ value_pred_clipped = batch_value + paddle.clip(values - batch_value, -self.clip_param, self.clip_param)
value_losses = (values - batch_return).pow(2)
value_losses_clipped = (value_pred_clipped - batch_return).pow(2)
- value_loss = 0.5 * paddle.maximum(value_losses,
- value_losses_clipped).mean()
+ value_loss = 0.5 * paddle.maximum(value_losses, value_losses_clipped).mean()
else:
value_loss = 0.5 * (values - batch_return).pow(2).mean()
@@ -168,23 +144,14 @@ def sample(self, obs):
"""
value = self.model.value(obs)
- if self.continuous_action:
- mean, std = self.model.policy(obs)
- dist = Normal(mean, std)
- action = dist.sample([1])
+ logits = self.model.policy(obs)
+ dist = Categorical(logits=logits)
+ action = dist.sample([1])
- action_log_probs = dist.log_prob(action).sum(-1)
- action_entropy = dist.entropy().sum(-1).mean()
- else:
- logits = self.model.policy(obs)
- dist = Categorical(logits=logits)
- action = dist.sample([1])
-
- act_dim = logits.shape[-1]
- actions_onehot = F.one_hot(action, act_dim)
- action_log_probs = paddle.sum(
- F.log_softmax(logits) * actions_onehot, axis=-1)
- action_entropy = dist.entropy()
+ act_dim = logits.shape[-1]
+ actions_onehot = F.one_hot(action, act_dim)
+ action_log_probs = paddle.sum(F.log_softmax(logits) * actions_onehot, axis=-1)
+ action_entropy = dist.entropy()
return value, action, action_log_probs, action_entropy
@@ -197,12 +164,9 @@ def predict(self, obs):
action (torch tensor): action, shape([batch_size] + action_shape),
noted that in the discrete case we take the argmax along the last axis as action
"""
- if self.continuous_action:
- action, _ = self.model.policy(obs)
- else:
- logits = self.model.policy(obs)
- probs = F.softmax(logits)
- action = paddle.argmax(probs, 1)
+ logits = self.model.policy(obs)
+ probs = F.softmax(logits)
+ action = paddle.argmax(probs, 1)
return action
def value(self, obs):
diff --git a/parl/algorithms/paddle/ppo_mujoco.py b/parl/algorithms/paddle/ppo_mujoco.py
new file mode 100644
index 000000000..f4b004010
--- /dev/null
+++ b/parl/algorithms/paddle/ppo_mujoco.py
@@ -0,0 +1,191 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import paddle
+import numpy as np
+from copy import deepcopy
+from parl.utils.utils import check_model_method
+
+__all__ = ['PPO_Mujoco']
+
+
+class PPO_Mujoco(parl.Algorithm):
+ def __init__(self, model, act_dim=None, loss_type='KLPEN', kl_targ=0.003, eta=50.0, clip_param=0.2, eps=1e-5):
+ """ PPO algorithm for Mujoco
+
+ Args:
+ model (parl.Model): model defining forward network of policy and value.
+ act_dim (int): dimension of the action space.
+ loss_type (string): loss type of PPO algorithm, 'CLIP' or 'KLPEN'.
+ kl_targ (float): D_KL target value.
+ eta (float): multiplier for D_KL-kl_targ hinge-squared loss.
+ clip_param (float): epsilon used in the CLIP loss.
+ eps (float): A small float value for numerical stability.
+ """
+ # check model methods and member variables
+ check_model_method(model, 'value', self.__class__.__name__)
+ check_model_method(model, 'policy', self.__class__.__name__)
+ assert hasattr(model, 'policy_model')
+ assert hasattr(model, 'value_model')
+ assert hasattr(model, 'policy_lr')
+ assert hasattr(model, 'value_lr')
+
+ assert isinstance(act_dim, int)
+ assert isinstance(kl_targ, float)
+ assert isinstance(eta, float)
+ assert isinstance(eps, float)
+ assert isinstance(clip_param, float)
+ assert loss_type == 'CLIP' or loss_type == 'KLPEN'
+ self.loss_type = loss_type
+ self.act_dim = act_dim
+ self.clip_param = clip_param
+ self.eta = eta
+ self.kl_targ = kl_targ
+
+ self.model = model
+ # Used to calculate probability of action in old policy
+ self.old_policy_model = deepcopy(model.policy_model)
+
+ self.policy_lr = self.model.policy_lr
+ self.value_lr = self.model.value_lr
+ self.policy_optimizer = paddle.optimizer.Adam(
+ parameters=self.model.policy_model.parameters(), learning_rate=self.policy_lr, epsilon=eps)
+ self.value_optimizer = paddle.optimizer.Adam(
+ parameters=self.model.value_model.parameters(), learning_rate=self.value_lr, epsilon=eps)
+
+ def _calc_logprob(self, actions, means, logvars):
+ """ Calculate log probabilities of actions, when given means and logvars
+ of normal distribution.
+ The constant sqrt(2 * pi) is omitted; it cancels out later in the probability ratio.
+
+ Args:
+ actions: shape (batch_size, act_dim)
+ means: shape (batch_size, act_dim)
+ logvars: shape (act_dim)
+
+ Returns:
+ logprob: shape (batch_size)
+ """
+ logp = -0.5 * paddle.sum(logvars)
+ logp += -0.5 * paddle.sum((paddle.square(actions - means) / paddle.exp(logvars)), axis=1)
+ logprob = logp
+ return logprob
+
+ def _calc_kl(self, means, logvars, old_means, old_logvars):
+ """ Calculate KL divergence between old and new distributions
+ See: https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback.E2.80.93Leibler_divergence
+
+ Args:
+ means: shape (batch_size, act_dim)
+ logvars: shape (act_dim)
+ old_means: shape (batch_size, act_dim)
+ old_logvars: shape (act_dim)
+
+ Returns:
+ kl: shape (batch_size)
+ entropy
+ """
+ log_det_cov_old = paddle.sum(old_logvars)
+ log_det_cov_new = paddle.sum(logvars)
+ tr_old_new = paddle.sum(paddle.exp(old_logvars - logvars))
+ kl = 0.5 * paddle.mean(
+ paddle.sum(paddle.square(means - old_means) / paddle.exp(logvars), axis=1) +
+ (log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim)
+
+ entropy = 0.5 * (self.act_dim * (np.log(2 * np.pi) + 1) + paddle.sum(logvars))
+
+ return kl, entropy
+
+ def value(self, obs):
+ """ Use value model of self.model to predict value of obs
+ """
+ return self.model.value(obs)
+
+ def predict(self, obs):
+ """ Use the policy model of self.model to predict means and logvars of actions
+ """
+ means, logvars = self.model.policy(obs)
+ return means
+
+ def sample(self, obs):
+ """ Use the policy model of self.model to sample actions
+ """
+ means, logvars = self.model.policy(obs)
+ sampled_act = means + (
+ paddle.exp(logvars / 2.0) * # stddev
+ paddle.standard_normal(shape=(self.act_dim, ), dtype='float32'))
+ return sampled_act
+
+ def policy_learn(self, batch_obs, batch_action, batch_adv, beta, lr_multiplier):
+ """ Learn policy model with:
+ 1. CLIP loss: Clipped Surrogate Objective
+ 2. KLPEN loss: Adaptive KL Penalty Objective
+ See: https://arxiv.org/pdf/1707.02286.pdf
+
+ Args:
+ batch_obs: Tensor, (batch_size, obs_dim)
+ batch_action: Tensor, (batch_size, act_dim)
+ batch_adv: Tensor (batch_size, )
+ beta: Tensor (1), adaptive KL penalty coefficient (used when loss_type is 'KLPEN').
+ lr_multiplier: Tensor (1)
+ """
+ old_means, old_logvars = self.old_policy_model.policy(batch_obs)
+ old_means.stop_gradient = True
+ old_logvars.stop_gradient = True
+
+ old_logprob = self._calc_logprob(batch_action, old_means, old_logvars)
+ old_logprob.stop_gradient = True
+
+ means, logvars = self.model.policy(batch_obs)
+ logprob = self._calc_logprob(batch_action, means, logvars)
+ kl, entropy = self._calc_kl(means, logvars, old_means, old_logvars)
+
+ if self.loss_type == "KLPEN":
+ loss1 = -(batch_adv * paddle.exp(logprob - old_logprob)).mean()
+ loss2 = (kl * beta).mean()
+ loss3 = self.eta * paddle.square(paddle.maximum(paddle.to_tensor(0.0), kl - 2.0 * self.kl_targ))
+ loss = loss1 + loss2 + loss3
+ elif self.loss_type == "CLIP":
+ ratio = paddle.exp(logprob - old_logprob)
+ surr1 = ratio * batch_adv
+ surr2 = paddle.clip(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch_adv
+ loss = -paddle.minimum(surr1, surr2).mean()
+ else:
+ raise ValueError("Policy loss type error, 'CLIP' or 'KLPEN'")
+
+ self.policy_optimizer.set_lr(self.policy_lr * lr_multiplier)
+
+ self.policy_optimizer.clear_grad()
+ loss.backward()
+ self.policy_optimizer.step()
+ return loss, kl, entropy
+
+ def value_learn(self, batch_obs, batch_return):
+ """ Learn the value model with square error cost
+ """
+ predict_val = self.model.value(batch_obs)
+ predict_val = predict_val.reshape([-1])
+
+ loss = (predict_val - batch_return).pow(2).mean()
+
+ self.value_optimizer.clear_grad()
+ loss.backward()
+ self.value_optimizer.step()
+ return loss
+
+ def sync_old_policy(self):
+ """ Synchronize weights of self.model.policy_model to self.old_policy_model
+ """
+ self.model.policy_model.sync_weights_to(self.old_policy_model)
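
As a side note on `_calc_logprob` above, the short NumPy sketch below (made-up values, not part of this PR) shows that the omitted `0.5 * act_dim * log(2 * pi)` term is a constant, so it cancels in the ratio `exp(logprob - old_logprob)` used by both loss types.

```
# Sketch: _calc_logprob drops the Gaussian normalization constant, which cancels
# in exp(logprob - old_logprob). The numbers below are made up for illustration.
import numpy as np

actions = np.array([[0.3, -0.1]])   # shape (batch_size, act_dim)
means = np.array([[0.0, 0.2]])
logvars = np.array([-1.0, -0.5])    # shape (act_dim,)

# Same computation as _calc_logprob, in NumPy
partial_logp = (-0.5 * np.sum(logvars)
                - 0.5 * np.sum(np.square(actions - means) / np.exp(logvars), axis=1))

# The full diagonal-Gaussian log-density adds only a constant term
act_dim = actions.shape[1]
full_logp = partial_logp - 0.5 * act_dim * np.log(2 * np.pi)

# The difference is constant, so probability ratios between old and new policies are unchanged.
print(full_logp - partial_logp)     # -> [-1.8378771...] for act_dim = 2
```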
diff --git a/parl/utils/rl_utils.py b/parl/utils/rl_utils.py
index 602d12c8e..e015fd577 100644
--- a/parl/utils/rl_utils.py
+++ b/parl/utils/rl_utils.py
@@ -15,7 +15,7 @@
import numpy as np
import scipy.signal
-__all__ = ['calc_discount_sum_rewards', 'calc_gae']
+__all__ = ['calc_discount_sum_rewards', 'calc_gae', 'Scaler']
def calc_discount_sum_rewards(rewards, gamma):
@@ -49,3 +49,50 @@ def calc_gae(rewards, values, next_value, gamma, lam):
tds = rewards + gamma * np.append(values[1:], next_value) - values
advantages = calc_discount_sum_rewards(tds, gamma * lam)
return advantages
+
+
+class Scaler(object):
+ """ Generate scale and offset based on running mean and stddev along axis=0
+
+ offset = running mean
+ scale = 1 / (stddev + 0.1) / 3 (i.e. 3x stddev = +/- 1.0)
+ """
+
+ def __init__(self, obs_dim):
+ """
+ Args:
+ obs_dim: dimension of axis=1
+ """
+ self.vars = np.zeros(obs_dim)
+ self.means = np.zeros(obs_dim)
+ self.cnt = 0
+ self.first_pass = True
+
+ def update(self, x):
+ """ Update running mean and variance (this is an exact method)
+ Args:
+ x: NumPy array, shape = (N, obs_dim)
+
+ see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled-
+ variance-of-two-groups-given-known-group-variances-mean
+ """
+ if self.first_pass:
+ self.means = np.mean(x, axis=0)
+ self.vars = np.var(x, axis=0)
+ self.cnt = x.shape[0]
+ self.first_pass = False
+ else:
+ n = x.shape[0]
+ new_data_var = np.var(x, axis=0)
+ new_data_mean = np.mean(x, axis=0)
+ new_data_mean_sq = np.square(new_data_mean)
+ new_means = ((self.means * self.cnt) + (new_data_mean * n)) / (self.cnt + n)
+ self.vars = (((self.cnt * (self.vars + np.square(self.means))) +
+ (n * (new_data_var + new_data_mean_sq))) / (self.cnt + n) - np.square(new_means))
+ self.vars = np.maximum(0.0, self.vars) # occasionally goes negative, clip
+ self.means = new_means
+ self.cnt += n
+
+ def get(self):
+ """ returns 2-tuple: (scale, offset) """
+ return 1 / (np.sqrt(self.vars) + 0.1) / 3, self.means
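
Finally, a short usage sketch for the `Scaler` added above, mirroring the `(obs - offset) * scale` pattern used in `examples/PPO/mujoco/actor.py`; the random observations are purely illustrative.

```
# Usage sketch for Scaler, following the normalization pattern in examples/PPO/mujoco/actor.py.
# The random data below is illustrative only.
import numpy as np
from parl.utils.rl_utils import Scaler

obs_dim = 4
scaler = Scaler(obs_dim)

batch = np.random.randn(100, obs_dim) * 2.0 + 1.0   # fake rollout observations, shape (N, obs_dim)
scaler.update(batch)                                # update running mean/variance

scale, offset = scaler.get()                        # scale = 1 / (stddev + 0.1) / 3, offset = running mean
normalized = (batch - offset) * scale               # center and scale, as actor.py does
print(normalized.mean(axis=0), normalized.std(axis=0))
```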