Skip to content

Application/traffic_light_control #665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions examples/Application/Traffic-Light-Control/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
## Reproduce Some Baselines of Traffic Light Control
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Baseline Algorithms For Traffic Light Control

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Based on PARL, we use the DDQN algorithm of deep RL to reproduce some baselines of the Traffic Light Control(TLC), reaching the same level of indicators as the papers in TLC benchmarks.

### Traffic Light Control Simulator Introduction

Please see [sumo](https://github.com/eclipse/sumo) or [cityflow](https://github.com/cityflow-project/CityFlow) to know more about the TLC simulator.
And we use the cityflow simuator in the experiments, as for how to install the cityflow, please refer [here](https://cityflow.readthedocs.io/en/latest/index.html) for more informations.

### Benchmark Result
Note that we set the yellow signal time to 5 seconds to clear the intersection, and the action intervals is set to 10 seconds as the papers, you can refer the `config.py` for details, you also can change the time as what you want. The different values of the times above may cause different results of the experiments.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for details -> for more details.
And remove the sentences after that. People may suspect that your implementations are not robust.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

You can download the data from [here](https://traffic-signal-control.github.io/) and [MPLight data](https://github.com/Chacha-Chen/MPLight/tree/master/data).
We use the average travel time of all vehicles to evaluate the performance of the signal control method in transportation.
Performances of presslight and FRAP on cityflow envrionments in training process after 300 episodes are shown below.

| average travel time| hz_1x1_tms-<br>xy_18041608| hz_1x1_bc-<br>tyc_18041608|syn_1x3_<br>gaussian|syn_2x2_<br>gaussian|anon_4_4_<br>750_0.6| anon_4_4<br>_750_0.3| anon_4_4<br>_700_0.6|anon_4_4<br>_700_0.3|
| :-----| :----: | :----: |:----: | :----: |:----: | :----: |:----: | :----: |
| max_pressure | 284.02 | 445.62 | 240.08 |316.67|589.03 | 536.89 |545.29 | 483.08 |
| presslight |110.62 | 189.97| 127.83| 184.58| 437.86| 357.10 |410.34 | 434.33|
| FRAP | 113.79 | 135.88 | 123.97| 166.45| 374.73 | 331.43 | 343.79| 300.77 |
| presslight* | 236.29| 244.87 |149.40| 953.78| -- | --| --| -- |
| FRAP* | 130.53| 159.54| 750.68| 713.48|--| -- |-- | -- |


Note that for the method `sotl`, different `t_min`, `min_green_vehicle` and `max_red_vehicle` configs may cause huge different results, which may not fair for sotl to compare its result with others, so we don't list the result of the `sotl` above.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also provide the implementation for that SOTL algorithm, but its performance heavily relies on the environment variables such as t_min and min_green_vehicle. We do not list its result here.


And results of the last two rows of the table ,`presslight*` and `FRAP*`, they are the results of the code [tlc-baselines](https://github.com/gjzheng93/tlc-baselines) provided from the paper authors' team. We run the [code](https://github.com/gjzheng93/tlc-baselines) just changing the yellow time and the action intervals to keep them same as our config as the papers without changing any other parameters. `--` in the table means the origins code doesn't perform well in the last four `anon_4X4` datas, the average travel time results of it will be more than 1000, maybe it will perform better than the `max_pressure`if you modify the other hyperparameters of the DQN agents, such as the buffer size, update_model_freq, the gamma or others.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yellow time -> yellow signal time

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


## How to use
### Dependencies
+ [parl>=1.4.3](https://github.com/PaddlePaddle/PARL)
+ torch==1.8.1+cu102
+ cityflow==0.1

### Training
First, download the data from [here](https://traffic-signal-control.github.io/) or [MPLight data](https://github.com/Chacha-Chen/MPLight/tree/master/data) and put them in the `data` directory. And the run the training script. The `train_presslight.py `for the presslight, each intersection has its own model as default(you can also choose to train with that all the intersections share one model in the script, just as what the paper MPLight used, it is suggested when the number of the intersections is large, just setting the `--is_share_model` to `True`).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the run the training script -> And run the training script

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

```bash
python train_presslight.py --is_share_model False
```

If you want the train the `FRAR`, you can run the script below:
```bash
python train_FRAP.py
```

If you want to compare the different results, you can load the right model path in the `config.py` and the right data path in the `config.json`, and then run:
```bash
python test.py
```

### Contents
+ agent
+ `agent.py`
The agent that uses the PARL agent mode, it will be used when training the RL methods such as `presslight` or `FRAP` and so on.
+ `max_pressure_agent.py` and `sotl_agnet.py`.The classic methods of the TLC.
+ data
+ You can get the data of the from here or download other data and put them here.
+ example
+ Put the `config.json` here, need to change the path of the roadnet the flow data in the `json` file.
+ model
+ Different algorithms have different models.
+ obs_reward
+ Different algorithms have different obs and rewards generators.


### Something about the Distributed Training
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the section if we do not provide parallel training algorithms.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.


We don't use the distributed traing or the parallel actors for collect the datas from the cityflow env, if you want to use the parallel actors with the cluster, you can refer to [here](https://github.com/PaddlePaddle/PARL/tree/develop/examples/A2C) or our [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) for details.

### Some Suggestions and Conclusions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the section. PARL will not provide suggestions for choosing the algorithm.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

+ The classic method `max_pressure`, `solt` or `greedy`(just set green lights to the roads with the most vehicles) can get the not bad baselines, when you use the RL method, you can compare to those baselines to make sure there is no mistakes in the RL code and the training process.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no mistakes -> there are no mistakes

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

+ As for the just one intersection roadnet data, from our experiences:
+ `presslight` can get the high baselines results, if you want to get better results, you can try `FRAP` in your own data, if the flow data and the roadnet is easy without so many vehicles, `presslight` maybe better.
+ If your roadnet contains hundreds intersections, it is unrealistic to make each model to each agent(intersection), you can choose to train with that all the intersections share one common model and one buffer. As for the complicated scene, the complicated model `FRAR`, `Colight`,`GAT` or `multi-agents` methods may be better.
+ The replay memory size and the gamma doesn't matter much from our experiences.
+ As the reward is hard or inappropriate to design, we suggest that the `ES` maybe a better choice, and we also have tested same data with the [ES](https://github.com/PaddlePaddle/PARL/tree/develop/benchmark/torch/ES), just use the negative average travel time as the fitness(rewards), it can get the better results when we create enough actors in the [cluster](https://parl.readthedocs.io/en/latest/parallel_training/setup.html).
+ The RL methods is just overfitting the env with the specific flow and roadnet data, maybe when evaluating the results we can test the model with different flow or roadnet data?


### Reference
+ [Parl](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
+ [Reinforcement Learning for Traffic Signal Control](https://traffic-signal-control.github.io/)
+ [Toward A Thousand Lights: Decentralized Deep Reinforcement Learning for Large-Scale Traffic Signal Control](https://chacha-chen.github.io/papers/chacha-AAAI2020.pdf)
+ [Traffic Light Control Baselines](https://github.com/zhc134/tlc-baselines)
+ [PressLight: Learning Max Pressure Control to Coordinate Traffic Signals in Arterial Network](http://personal.psu.edu/hzw77/publications/presslight-kdd19.pdf)
+ [PressLight](https://github.com/wingsweihua/presslight)
+ [Learning Phase Competition for Traffic Signal Control](http://www.personal.psu.edu/~gjz5038/paper/cikm2019_frap/cikm2019_frap_paper.pdf)
+ [frap-pub](https://github.com/gjzheng93/frap-pub)
64 changes: 64 additions & 0 deletions examples/Application/Traffic-Light-Control/agent/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import parl
import numpy as np


class Agent(parl.Agent):
def __init__(self, algorithm, config):
super(Agent, self).__init__(algorithm)

self.config = config
self.epsilon = self.config['epsilon']

def sample(self, obs):
# The epsilon-greedy action selector.
def sample_random(act_dim):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the simple function. Just call np.random.randint(0, act_dim).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

# Random samples
return np.random.randint(0, act_dim)

obs = paddle.to_tensor(obs, dtype='float32')
logits = self.alg.sample(obs)
act_dim = logits.shape[-1]
act_values = logits.numpy()
actions = np.argmax(act_values, axis=-1)
for i in range(obs.shape[0]):
if np.random.rand() <= self.epsilon:
actions[i] = sample_random(act_dim)
return actions

def predict(self, obs):

obs = paddle.to_tensor(obs, dtype='float32')
predict_actions = self.alg.predict(obs)
return predict_actions.numpy()

def learn(self, obs, actions, dones, rewards, next_obs):

obs = paddle.to_tensor(obs, dtype='float32')
actions = paddle.to_tensor(actions, dtype='float32')
dones = paddle.to_tensor(dones, dtype='float32')
next_obs = paddle.to_tensor(next_obs, dtype='float32')
rewards = paddle.to_tensor(rewards, dtype='float32')

Q_loss, pred_values, target_values, max_v_show_values, train_count, lr, epsilon = self.alg.learn(
obs, actions, dones, rewards, next_obs)

self.alg.sync_target(decay=self.config['decay'])
self.epsilon = epsilon

return Q_loss.numpy(), pred_values.numpy(), target_values.numpy(
), max_v_show_values.numpy(), train_count, lr, epsilon
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Third party code
#
# The following code is mainly referenced, modified and copied from:
# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline

import numpy as np


class MaxPressureAgent(object):
"""
Agent using MaxPressure method to control traffic light
"""

def __init__(self, world):
self.world = world

def predict(self, lane_vehicle_count):
actions = []
for I_id, I in enumerate(self.world.intersections):
action = I.current_phase
max_pressure = None
action = -1
for phase_id in range(len(I.phases)):
pressure = sum([
lane_vehicle_count[start] - lane_vehicle_count[end]
for start, end in I.phase_available_lanelinks[phase_id]
])
if max_pressure is None or pressure > max_pressure:
action = phase_id
max_pressure = pressure
actions.append(action)
return np.array(actions)
45 changes: 45 additions & 0 deletions examples/Application/Traffic-Light-Control/agent/sotl_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Third party code
#
# The following code is mainly referenced, modified and copied from:
# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline

import numpy as np


class SOTLAgent(object):
"""
Agent using Self-organizing Traffic Light(SOTL) Control method to control traffic light.
Note that different t_min, min_green_vehicle and max_red_vehicle may cause different results, which may not fair to compare to others.
"""

def __init__(self, world, t_min=3, min_green_vehicle=20,
max_red_vehicle=5):
self.world = world
# the minimum duration of time of one phase
self.t_min = t_min
# some threshold to deal with phase requests
self.min_green_vehicle = min_green_vehicle # 10
self.max_red_vehicle = max_red_vehicle # 30
self.action_dims = []
for i in self.world.intersections:
self.action_dims.append(len(i.phases))

def predict(self, lane_waiting_count):
actions = []
for I_id, I in enumerate(self.world.intersections):
action = I.current_phase
if I.current_phase_time >= self.t_min:
num_green_vehicles = sum([
lane_waiting_count[lane]
for lane in I.phase_available_startlanes[I.current_phase]
])
num_red_vehicles = sum(
[lane_waiting_count[lane] for lane in I.startlanes])
num_red_vehicles -= num_green_vehicles
if num_green_vehicles <= self.min_green_vehicle and num_red_vehicles > self.max_red_vehicle:
action = (action + 1) % self.action_dims[I_id]
actions.append(action)
return np.array(actions)

def get_reward(self):
return None
59 changes: 59 additions & 0 deletions examples/Application/Traffic-Light-Control/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

config = {

#========== env config ==========
'config_path_name':
'./examples/config_hz_2.json', # note that the path of the data can be modified in the json file.
'thread_num': 8,
'obs_fns': ['lane_count'],
'reward_fns': ['pressure'],
'is_only': False,
'average': None,
'action_interval': 10,
'metric_period': 3600, #3600
'yellow_phase_time': 5,

#========== learner config ==========
'gamma': 0.85, # also can be set to 0.95
'epsilon': 0.9,
'epsilon_min': 0.2,
'epsilon_decay': 0.99,
'start_lr': 0.00025,
'episodes': 200 + 100,
'algo': 'DQN', # DQN
'max_train_steps': int(1e6),
'lr_decay_interval': 100,
'epsilon_decay_interval': 100,
'sample_batch_size':
2048, # also can be set to 32, which doesn't matter much.
'learn_freq': 2, # update parameters every 2 or 5 steps
'decay': 0.995, # soft update of double DQN
'reward_normal_factor': 4, # rescale the rewards, also can be set to 20,
'train_count_log': 5, # add to the tensorboard
'is_show_log': False, # print in the screen
'step_count_log': 1000,

# save checkpoint frequent episode
'save_rate': 100,
'save_dir': './save_model/presslight',
'train_log_dir': './train_log/presslight',
'save_dir': './save_model/presslight4*4',
'train_log_dir': './train_log/presslight4*4',

# memory config
'memory_size': 20000,
'begin_train_mmeory_size': 3000
}
95 changes: 95 additions & 0 deletions examples/Application/Traffic-Light-Control/ddqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use the parl.algorithms.DDQN directly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some tricks used in the ddqn.py, such as grad clip, epsilon decay ,lr_decay, which don't use in the parl.algorithms.DDQN . If using the parl.algorithms.DDQN directly, maybe all the experiments should be run again to make sure that parl.algorithms.DDQN performs well.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

import copy
import numpy as np
import parl
from parl.utils.scheduler import LinearDecayScheduler


class DDQN(parl.Algorithm):
def __init__(self, model, config):

self.model = model

clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=40.0)
self.optimizer = paddle.optimizer.Adam(
learning_rate=config['start_lr'],
parameters=self.model.parameters(),
grad_clip=clip)

self.mse_loss = nn.MSELoss(reduction='mean')

self.config = config
self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
config['max_train_steps'])
self.lr = config['start_lr']
self.target_model = copy.deepcopy(model)

self.train_count = 0

self.epsilon = self.config['epsilon']
self.epsilon_min = self.config['epsilon_min']
self.epsilon_decay = self.config['epsilon_decay']

def sample(self, obs):
logits = self.model(obs)
return logits

def predict(self, obs):
logits = self.model(obs)
predict_actions = paddle.argmax(logits, axis=-1)
return predict_actions

def sync_target(self, decay=0.995):
# soft update
self.model.sync_weights_to(self.target_model, decay)

def learn(self, obs, actions, dones, rewards, next_obs):
# Update the Q network with the data sampled from the memory buffer.
if self.train_count > 0 and self.train_count % self.config[
'lr_decay_interval'] == 0:
self.lr = self.lr_scheduler.step(
step_num=self.config['lr_decay_interval'])
terminal = dones
actions_onehot = F.one_hot(
actions.astype('int'), num_classes=self.model.act_dim)
# shape of the pred_values: batch_size
pred_values = paddle.sum(self.model(obs) * actions_onehot, axis=-1)
greedy_action = self.model(next_obs).argmax(1)
with paddle.no_grad():
# target_model for evaluation, using the double DQN, the max_v_show just used for showing in the tensorborad
max_v_show = paddle.max(self.target_model(next_obs), axis=-1)
greedy_actions_onehot = F.one_hot(
greedy_action, num_classes=self.model.act_dim)
max_v = paddle.sum(
self.target_model(next_obs) * greedy_actions_onehot, axis=-1)
assert max_v.shape == rewards.shape
target = rewards + (1 - terminal) * self.config['gamma'] * max_v
Q_loss = 0.5 * self.mse_loss(pred_values, target)

# optimize
self.optimizer.clear_grad()
Q_loss.backward()
self.optimizer.step()
self.train_count += 1
if self.epsilon > self.epsilon_min and self.train_count % self.config[
'epsilon_decay_interval'] == 0:
self.epsilon *= self.epsilon_decay
return Q_loss, pred_values.mean(), target.mean(), max_v_show.mean(
), self.train_count, self.lr, self.epsilon
Loading