add RL4LMs summarization #1078
Open
dwyzzy wants to merge 39 commits into PaddlePaddle:develop from dwyzzy:sentence_review_summarization
Changes from 34 commits

Commits (39)
- `a1a4c4b` init file using files from RL4LMS (dwyzzy)
- `e706ed4` benchmark of RL4LMs v0.0 (dwyzzy)
- `f816028` benchmark of RL4LMs v0.1 (dwyzzy)
- `0293734` fix pg reward bug, remove no use warmstartup (dwyzzy)
- `02efdd9` merge models and buffers, add README.md (dwyzzy)
- `89c4efb` simplified code v0.0 (dwyzzy)
- `0b69359` remove distribution_wrapper.py and sample_util.py (dwyzzy)
- `23735cb` remove EvaluateActionsOutput, ValueOutput and PolicyOutput (dwyzzy)
- `bbdd102` use Reviewer and ReviewerGroup instead of Env (dwyzzy)
- `a9aef6b` use Reviewer and ReviewerGroup instead of Env (parl parallel) (dwyzzy)
- `bf3c625` use Reviewer and ReviewerGroup instead of Env (parl parallel version) (dwyzzy)
- `d452685` review using sentence (parl parallel version) (dwyzzy)
- `b943f1c` remove some '**config' and change rollout util (dwyzzy)
- `086ce6f` use instructor instead of reviewer, add examiner (dwyzzy)
- `3acf2c3` add requirements.txt (dwyzzy)
- `090b190` change code style (dwyzzy)
- `78f44b8` Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay)
- `b66f07e` change train.py style (dwyzzy)
- `0d8af33` Merge remote-tracking branch 'rl4lm_parl/sentence_review_summarizatio… (dwyzzy)
- `337ac75` change style (dwyzzy)
- `d0ced44` change style (dwyzzy)
- `151fcea` change code style (add copyright) (dwyzzy)
- `f91d2c9` bring for-batch-rollout loop out of rl4lms_ppo (dwyzzy)
- `dc1d835` change name of policy/value, obs-preprocess and add-to-buffer (dwyzzy)
- `a23e8fe` change config structure (dwyzzy)
- `c2be52f` change ppo code style according to parl ppo (dwyzzy)
- `b34ea18` yapf code style (dwyzzy)
- `02c8956` change code for PARL-RL4LMs summarization version 0.1 (dwyzzy)
- `760cc9d` change code style of PARL-RL4LMs summarization version 0.1 (dwyzzy)
- `1770e45` change unreasonable name to n_steps_per_instructor in config (dwyzzy)
- `b9c3e5c` add object for all classes, adjust add-to-buffer structure (dwyzzy)
- `59e02fa` change t5_ppo_config and README (dwyzzy)
- `4cd67e2` yapf code style (dwyzzy)
- `6af0e40` Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay)
- `da94226` change buffer add(), add save/load (dwyzzy)
- `68ec090` yapf code style (dwyzzy)
- `2b82da6` Merge remote-tracking branch 'rl4lm_parl/sentence_review_summarizatio… (dwyzzy)
- `21e99e8` evaluate at beginning (dwyzzy)
- `1704e4f` Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay)
README.md
@@ -0,0 +1,37 @@
## Reproduce Summarization-RLHF in RL4LMs using PARL

> Paper: [Is Reinforcement Learning (Not) for Natural Language Processing: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization](https://arxiv.org/abs/2210.01241)

### Background

- Summarization task in NLP: summarization is the task of producing a shorter version of a document that preserves most of the input's meaning.
- RLHF: short for Reinforcement Learning from Human Feedback, which uses human feedback to train RL algorithms. More information is available in the Hugging Face blog post [Illustrating Reinforcement Learning from Human Feedback (RLHF)](https://huggingface.co/blog/rlhf).

### Main contributions

- Build a new Summarization-RLHF framework using PARL
- Use PARL for parallel training

### How to use

#### Install dependencies

```bash
pip install -r requirements.txt
```

#### Start training

```bash
# start xparl
xparl start --port 8811 --cpu_num 10

# start training
python train.py
```
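The xparl port and CPU count above need to line up with the instructor settings that `train.py` reads. Below is a minimal sketch of the relevant fields, using the dict keys that `InstructorGroup` (in the instructor module further down) actually reads; the surrounding config layout and the concrete values are assumptions, not taken from this PR:

```python
# Hypothetical excerpt of the instructor section of the training config.
instructor_config = {
    "parl_master_address": "localhost:8811",  # must match `xparl start --port 8811`
    "n_instructors": 10,                      # should not exceed --cpu_num passed to xparl
    "max_prompt_length": 512,                 # illustrative value
    "max_episode_length": 100,                # illustrative value
    "terminate_on_eos": True,
    "prompt_truncation_side": "left",
    "context_start_token": None,
    "reward_fn": {"id": "rouge"},             # assumed reward config format
}
```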
### Code Reference

- Official code: [RL4LMs](https://github.com/allenai/RL4LMs)
- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)
@@ -0,0 +1,280 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections import OrderedDict, deque

import numpy as np
import torch
import parl
from gym import spaces
from gym.spaces.dict import Dict as DictSpace
from gym.spaces.discrete import Discrete

from rl4lms_utils import Observation, build_datapool, build_tokenizer, build_reward_fn


def _flatten_obs(obs, space, n_instructor=None):
    if n_instructor is not None:
        return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_instructor, -1, len(obs[0][k]))))
                            for k in space.spaces.keys()])
    return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])

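# Note: with wait=False, calls to this remote actor's methods return future objects
# immediately; results are fetched later with .get() (see InstructorGroup below),
# which is what lets the instructors run in parallel.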
@parl.remote_class(wait=False)
class Instructor(object):
    def __init__(
            self,
            reward_config=None,
            tokenizer_config=None,
            datapool_config=None,
            max_episode_length=512,
            max_prompt_length=None,
            terminate_on_eos=False,
            context_start_token=None,
            prompt_truncation_side="left",
            waiting_time_idx=0,
    ):
        """
        Instructor who gives reward.

        Args:
            max_episode_length (int, optional): maximum number of steps per episode. Defaults to 512.
            max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None.
            terminate_on_eos (bool, optional): whether to terminate the episode on EOS. Defaults to False.
            context_start_token (bool, optional): start token for the context (for encoder-decoder models).
            prompt_truncation_side (str): truncation side for the prompt text. Defaults to "left".
        """
        # stagger startup: too many Instructors loading datasets at the same time may cause problems
        time.sleep(waiting_time_idx * 90)
        tokenizer = build_tokenizer(tokenizer_config)
        samples = build_datapool(datapool_config, remote_train=True)["train"]
        reward_function = build_reward_fn(reward_config)
        self.tokenizer = tokenizer
        self.reward_function = reward_function
        self.max_steps = max_episode_length
        self._max_text_length = (max_prompt_length if max_prompt_length else tokenizer.model_max_length)
        self._terminate_on_eos = terminate_on_eos
        self._context_start_token = context_start_token
        self._prompt_truncation_side = prompt_truncation_side

        # set the observation and action space here
        self._vocab_size = tokenizer.vocab_size
        self.observation_space = DictSpace({
            # while creating rollout buffers, observations are concatenated for each key
            "prompt_or_input_encoded_pt":
            spaces.Box(low=0, high=self._vocab_size, shape=(self._max_text_length, )),
            "prompt_or_input_attention_mask_pt":
            spaces.Box(low=0, high=1, shape=(self._max_text_length, )),
            "context_encoded_pt":
            spaces.Box(low=0, high=self._vocab_size, shape=(self.max_steps, )),
            "context_attention_mask_pt":
            spaces.Box(low=0, high=1, shape=(self.max_steps, )),
            "input_encoded_pt":
            spaces.Box(
                low=0,
                high=self._vocab_size,
                shape=(self._max_text_length + self.max_steps, ),
            ),
            "input_attention_mask_pt":
            spaces.Box(low=0, high=1, shape=(self._max_text_length + self.max_steps, )),
        })
        self.action_space = Discrete(n=self._vocab_size)
        # see https://github.com/huggingface/transformers/issues/4875 : T5/mT5 checkpoints pad the
        # embedding matrix beyond the tokenizer vocab size for better GPU efficiency, so enlarge
        # the action space to the padded vocab size
        if 'mt5' in self.tokenizer.name_or_path:
            n = 250112
            self.action_space = Discrete(n=n)
        elif 't5' in self.tokenizer.name_or_path:
            n = 32128
            self.action_space = Discrete(n=n)
        self.samples_for_replaying = deque()
        for sample, weight in samples:
            self.samples_for_replaying.append(sample)

        # check the tokenizer and add padding tokens
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"  # TBD: configure this
        self.tokenizer.truncation_side = "left"  # TBD: configure this

        # init tracking variables
        self.__current_sample = None
        self.__current_obs = None
        self.__time_step = None

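    # Per-step observations are dicts of fixed-size integer arrays keyed as above;
    # "input_encoded_pt" presumably holds the prompt tokens followed by everything
    # generated so far (padded to max_prompt_length + max_episode_length), per the
    # Observation helper in rl4lms_utils, which is not shown in this diff.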
    def get_new_obs_and_feedback_one_step(self, action):
        self.__time_step += 1

        # previous obs
        previous_obs = self.__current_obs

        # just update the context tensor and get the new observation
        self.__current_obs = self.__current_obs.update(action, self.tokenizer)

        # decide if the episode is finished or not
        done = (action == self.tokenizer.eos_token_id
                and self._terminate_on_eos) or (self.__time_step == self.max_steps)

        # compute reward
        reward = self.reward_function(
            previous_obs,
            action,
            self.__current_obs,
            done,
            self.__current_obs.meta_info,
        )

        # populate additional info
        info = {
            "output": self.__current_obs.context_text,
            "action_history": self.__current_obs.action_history,
            "reference_text": self.__current_obs.target_or_reference_texts,
            "prompt_text": self.__current_obs.prompt_or_input_text,
            "prev_output": previous_obs.context_text,
            "meta_info": previous_obs.meta_info,
        }

        if done:
            # save final observation where user can get it, then reset
            info["terminal_observation"] = self.__current_obs.to_dict()
            observation = self.ask()
            return (observation, reward, done, info)
        else:
            return (self.__current_obs.to_dict(), reward, done, info)

    def get_new_obs_and_feedback_sentence(self, sentence):
        res = []
        for token in sentence:
            one_step_res = self.get_new_obs_and_feedback_one_step(token)
            res.append(one_step_res)
        return res

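    # The sentence-level API returns one (obs, reward, done, info) tuple per generated
    # token, e.g. get_new_obs_and_feedback_sentence([t1, t2]) -> [(obs1, r1, d1, i1),
    # (obs2, r2, d2, i2)]; InstructorGroup._instructors_feedback_sentence relies on this
    # shape when it stacks and reshapes the feedback.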
    def ask(self, sample=None):
        """
        Reset the instructor and start a new episode.
        """
        # get a new sample if not provided
        if sample is None:
            sample = np.random.choice(a=self.samples_for_replaying, size=min(len(self.samples_for_replaying), 1))[0]
        self.__current_sample = sample

        # init the observation
        self.__current_obs = Observation.init_from_sample(
            sample,
            self.tokenizer,
            self._max_text_length,
            self.max_steps,
            self._prompt_truncation_side,
            self._context_start_token,
            sample.meta_data,
        )

        # start the time step counter
        self.__time_step = 0

        dict_observation = self.__current_obs.to_dict()
        return dict_observation

    def get_obs_and_action_space(self):
        return (self.observation_space, self.action_space)

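# InstructorGroup fans work out to the remote Instructors above: ask() gathers one fresh
# prompt per instructor, and feedback_sentense() scores one generated sentence per
# instructor, collecting the asynchronous futures before reshaping the results.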
class InstructorGroup(object):
    def __init__(
            self,
            instructor_config=None,
            tokenizer=None,
            datapool_config=None,
            tokenizer_config=None,
    ):
        self.n_instructors = instructor_config["n_instructors"]
        # remote instructors need to use config to initialize due to serialization problems
        instructor_kwargs = {
            "reward_config": instructor_config["reward_fn"],
            "tokenizer_config": tokenizer_config,
            "datapool_config": datapool_config,
            "max_prompt_length": instructor_config["max_prompt_length"],
            "max_episode_length": instructor_config["max_episode_length"],
            "terminate_on_eos": instructor_config["terminate_on_eos"],
            "prompt_truncation_side": instructor_config["prompt_truncation_side"],
            "context_start_token": instructor_config["context_start_token"]
        }
        self.tokenizer = tokenizer
        self._remote_instructors = self._create_instructors(instructor_kwargs, instructor_config["parl_master_address"])

        # due to serialization problems, build the obs space and action space here
        self._vocab_size = tokenizer.vocab_size
        self.observation_space = DictSpace({
            # while creating rollout buffers, observations are concatenated for each key
            "prompt_or_input_encoded_pt":
            spaces.Box(low=0, high=self._vocab_size, shape=(instructor_kwargs["max_prompt_length"], )),
            "prompt_or_input_attention_mask_pt":
            spaces.Box(low=0, high=1, shape=(instructor_kwargs["max_prompt_length"], )),
            "context_encoded_pt":
            spaces.Box(low=0, high=self._vocab_size, shape=(instructor_kwargs["max_episode_length"], )),
            "context_attention_mask_pt":
            spaces.Box(low=0, high=1, shape=(instructor_kwargs["max_episode_length"], )),
            "input_encoded_pt":
            spaces.Box(
                low=0,
                high=self._vocab_size,
                shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"], ),
            ),
            "input_attention_mask_pt":
            spaces.Box(
                low=0,
                high=1,
                shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"], )),
        })
        self.action_space = Discrete(n=self._vocab_size)

    def ask(self):
        future_object_ids = [remote_instructor.ask() for remote_instructor in self._remote_instructors]
        sample_questions = [future_object.get() for future_object in future_object_ids]
        return _flatten_obs(sample_questions, self.observation_space)

    def feedback_sentense(self, gen_output):
        sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = \
            self._instructors_feedback_sentence(gen_output.step_wise_actions)

        return sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos

    def _instructors_feedback_sentence(self, all_sentences):
        all_sentences = torch.stack(all_sentences).cpu().numpy().transpose(1, 0)
        future_object_ids = [
            self._remote_instructors[i].get_new_obs_and_feedback_sentence(all_sentences[i])
            for i in range(self.n_instructors)
        ]

        feedback_res = np.stack([future_object.get() for future_object in future_object_ids])

        obs, rews, dones, infos = zip(*feedback_res.reshape(-1, 4))
        return _flatten_obs(obs, self.observation_space, self.n_instructors), \
            np.stack(rews).reshape(self.n_instructors, -1), np.stack(dones).reshape(self.n_instructors, -1), \
            np.stack(infos).reshape(self.n_instructors, -1)

    def _create_instructors(self, instructor_kwargs, parl_port=None):
        parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"])
        return [
            Instructor(
                reward_config=instructor_kwargs["reward_config"],
                tokenizer_config=instructor_kwargs["tokenizer_config"],
                datapool_config=instructor_kwargs["datapool_config"],
                max_episode_length=instructor_kwargs["max_episode_length"],
                max_prompt_length=instructor_kwargs["max_prompt_length"],
                terminate_on_eos=instructor_kwargs["terminate_on_eos"],
                context_start_token=instructor_kwargs["context_start_token"],
                prompt_truncation_side=instructor_kwargs["prompt_truncation_side"],
                waiting_time_idx=idx,
            ) for idx in range(self.n_instructors)
        ]
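As a rough usage sketch of how a training loop is expected to drive `InstructorGroup`: the config values, the prebuilt `tokenizer`/config objects, and the `gen_output` object below are illustrative stand-ins, not taken from this PR (the real ones come from `train.py` and the agent's generation step, which are not shown in this view):

```python
# Illustrative only: keys follow the instructor_config accesses above; values are made up.
instructor_config = {
    "n_instructors": 2,
    "reward_fn": {"id": "rouge"},             # assumed reward config format
    "max_prompt_length": 512,
    "max_episode_length": 100,
    "terminate_on_eos": True,
    "prompt_truncation_side": "left",
    "context_start_token": None,
    "parl_master_address": "localhost:8811",
}

instructor_group = InstructorGroup(
    instructor_config=instructor_config,
    tokenizer=tokenizer,                      # assumed to come from build_tokenizer(tokenizer_config)
    datapool_config=datapool_config,
    tokenizer_config=tokenizer_config,
)

obs = instructor_group.ask()                  # dict of stacked arrays, one row per instructor
# gen_output.step_wise_actions: a list of per-step action tensors from the policy's
# generation pass, one tensor of shape (n_instructors,) per generated token.
new_obs, rewards, dones, infos = instructor_group.feedback_sentense(gen_output)
# rewards/dones/infos come back with shape (n_instructors, n_generated_tokens)
```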
requirements.txt
@@ -0,0 +1,11 @@
parl>=2.1.1
datasets==2.10.1
torch==1.11.0
torchvision==0.12.0
transformers==4.18.0
charset-normalizer==3.0.1
gym==0.21.0
cchardet==2.1.7
nltk==3.7
gem-metrics @ git+https://github.com/GEM-benchmark/GEM-metrics.git@431a8174bd6b3637e8d6118bfad2983e39e99733
bert-score==0.3.11