From a1a4c4b39e4a8074d740a103d068288d1d7be65e Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Wed, 1 Mar 2023 19:22:41 +0800 Subject: [PATCH 01/34] init file using files from RL4LMS --- benchmark/torch/RL4LMs/agents/__init__.py | 0 benchmark/torch/RL4LMs/algorithms/__init__.py | 1 + .../torch/RL4LMs/algorithms/rl4lm_ppo.py | 5 + .../RL4LMs/configs/summarization/t5_ppo.yml | 93 + benchmark/torch/RL4LMs/env/__init__.py | 1 + benchmark/torch/RL4LMs/env/text_gen_env.py | 177 + benchmark/torch/RL4LMs/models/__init__.py | 2 + benchmark/torch/RL4LMs/models/base_model.py | 635 +++ .../torch/RL4LMs/models/seq2seq_model.py | 329 ++ .../torch/RL4LMs/summarization/__init__.py | 1 + .../summarization/rl4lms_summa_agent.py | 435 ++ .../summarization/rl4lms_summa_model.py | 7 + benchmark/torch/RL4LMs/train.py | 79 + benchmark/torch/RL4LMs/trainers.py | 219 ++ benchmark/torch/RL4LMs/utils/__init__.py | 24 + benchmark/torch/RL4LMs/utils/buffer.py | 698 ++++ benchmark/torch/RL4LMs/utils/data_pool.py | 116 + benchmark/torch/RL4LMs/utils/data_wrapper.py | 327 ++ .../RL4LMs/utils/distribution_wrapper.py | 68 + .../torch/RL4LMs/utils/evaluation_util.py | 125 + .../utils/huggingface_generation_util.py | 3492 +++++++++++++++++ benchmark/torch/RL4LMs/utils/kl_controller.py | 32 + benchmark/torch/RL4LMs/utils/metric_util.py | 644 +++ benchmark/torch/RL4LMs/utils/registry.py | 189 + benchmark/torch/RL4LMs/utils/reward_util.py | 446 +++ benchmark/torch/RL4LMs/utils/sample_util.py | 40 + benchmark/torch/RL4LMs/utils/tracker.py | 154 + benchmark/torch/RL4LMs/utils/type_wrapper.py | 7 + benchmark/torch/RL4LMs/utils/warm_start.py | 147 + 29 files changed, 8493 insertions(+) create mode 100644 benchmark/torch/RL4LMs/agents/__init__.py create mode 100644 benchmark/torch/RL4LMs/algorithms/__init__.py create mode 100644 benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py create mode 100644 benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml create mode 100644 benchmark/torch/RL4LMs/env/__init__.py create mode 100644 benchmark/torch/RL4LMs/env/text_gen_env.py create mode 100644 benchmark/torch/RL4LMs/models/__init__.py create mode 100644 benchmark/torch/RL4LMs/models/base_model.py create mode 100644 benchmark/torch/RL4LMs/models/seq2seq_model.py create mode 100644 benchmark/torch/RL4LMs/summarization/__init__.py create mode 100644 benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py create mode 100644 benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py create mode 100644 benchmark/torch/RL4LMs/train.py create mode 100644 benchmark/torch/RL4LMs/trainers.py create mode 100644 benchmark/torch/RL4LMs/utils/__init__.py create mode 100644 benchmark/torch/RL4LMs/utils/buffer.py create mode 100644 benchmark/torch/RL4LMs/utils/data_pool.py create mode 100644 benchmark/torch/RL4LMs/utils/data_wrapper.py create mode 100644 benchmark/torch/RL4LMs/utils/distribution_wrapper.py create mode 100644 benchmark/torch/RL4LMs/utils/evaluation_util.py create mode 100644 benchmark/torch/RL4LMs/utils/huggingface_generation_util.py create mode 100644 benchmark/torch/RL4LMs/utils/kl_controller.py create mode 100644 benchmark/torch/RL4LMs/utils/metric_util.py create mode 100644 benchmark/torch/RL4LMs/utils/registry.py create mode 100644 benchmark/torch/RL4LMs/utils/reward_util.py create mode 100644 benchmark/torch/RL4LMs/utils/sample_util.py create mode 100644 benchmark/torch/RL4LMs/utils/tracker.py create mode 100644 benchmark/torch/RL4LMs/utils/type_wrapper.py create mode 100644 benchmark/torch/RL4LMs/utils/warm_start.py diff --git a/benchmark/torch/RL4LMs/agents/__init__.py b/benchmark/torch/RL4LMs/agents/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmark/torch/RL4LMs/algorithms/__init__.py b/benchmark/torch/RL4LMs/algorithms/__init__.py new file mode 100644 index 000000000..8d9429824 --- /dev/null +++ b/benchmark/torch/RL4LMs/algorithms/__init__.py @@ -0,0 +1 @@ +from .rl4lm_ppo import RL4LMPPO \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py b/benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py new file mode 100644 index 000000000..ee2592f1a --- /dev/null +++ b/benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py @@ -0,0 +1,5 @@ +from parl.algorithms.torch.ppo import PPO + + +class RL4LMPPO(PPO): + pass \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml new file mode 100644 index 000000000..de73bbb3a --- /dev/null +++ b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml @@ -0,0 +1,93 @@ + + + +tokenizer: + model_name: t5-base + padding_side: left + truncation_side: left + pad_token_as_eos_token: False + +reward_fn: + id: rouge + args: + rouge_type: "rouge1" + +datapool: + id: cnn_daily_mail + args: + prompt_prefix: "Summarize: " + + +env: + ## CHANGE FOR DEBUG ## +# n_envs: 10 + n_envs: 2 + ## CHANGE FOR DEBUG ## + args: + max_prompt_length: 512 + max_episode_length: 100 + terminate_on_eos: True + prompt_truncation_side: "right" + context_start_token: 0 + +alg: + id: ppo + args: +# n_steps: 512 + #####CHNAGE FOR DEBUG######## + n_steps: 5 + #####CHANGE FOR DEBUG######## + batch_size: 16 + verbose: 1 + learning_rate: 0.000002 + n_epochs: 5 + ent_coef: 0.0 + kl_div: + coeff: 0.001 + target_kl: 0.2 + policy: + id: seq2seq_lm_actor_critic_policy + args: + model_name: t5-base + apply_model_parallel: True + prompt_truncation_side: "right" + generation_kwargs: + do_sample: True + top_k: 50 + min_length: 50 + max_new_tokens: 100 + +train_evaluation: + eval_batch_size: 100 + n_iters: 100 + eval_every: 10 + save_every: 1 + metrics: + - id: meteor + args: {} + - id: rouge + - id: bleu + args: {} + - id: bert_score + args: + language: en + # - id: bleurt + # args: + # config_name: bleurt-large-512 + - id: diversity + args: {} + # - id: summaCZS + # args: + # granularity: sentence + # use_ent: True + # use_con: False + # - id: summaCConv + # args: + # granularity: sentence + generation_kwargs: + do_sample: True + top_k: 0 + temperature: 0.7 + min_length: 50 + max_new_tokens: 100 + diff --git a/benchmark/torch/RL4LMs/env/__init__.py b/benchmark/torch/RL4LMs/env/__init__.py new file mode 100644 index 000000000..40764a3b1 --- /dev/null +++ b/benchmark/torch/RL4LMs/env/__init__.py @@ -0,0 +1 @@ +from .text_gen_env import TextGenEnv \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/env/text_gen_env.py b/benchmark/torch/RL4LMs/env/text_gen_env.py new file mode 100644 index 000000000..faf9eafb1 --- /dev/null +++ b/benchmark/torch/RL4LMs/env/text_gen_env.py @@ -0,0 +1,177 @@ +from cmath import inf +from typing import Dict, Tuple, Optional, List + +import torch +from gym import Env, spaces +from gym.spaces.dict import Dict as DictSpace +from gym.spaces.discrete import Discrete +from benchmark.torch.RL4LMs.utils import Sample, Observation, PrioritySampler +from benchmark.torch.RL4LMs.utils import RewardFunction, BatchedRewardFunction +from transformers import AutoTokenizer + + +class TextGenEnv(Env): + def __init__( + self, + tokenizer: AutoTokenizer, + reward_function: RewardFunction, + samples: Tuple[List[Sample], float], + max_episode_length: int = 512, + priority_scale: float = 0.0, + max_prompt_length: Optional[int] = None, + terminate_on_eos: bool = False, + context_start_token: Optional[int] = None, + prompt_truncation_side: str = "left", + ): + """ + A generic RL environment to generate textual sequences. + For eg: text generation, summarization, machine translation, text simplification + Args: + tokenizer (AutoTokenizer): pre-trained tokenizer + reward_function (RewardFunction): reward functiom + samples (Tuple[List[Sample], float]): list of samples + max_episode_length (int, optional): Max steps to the model Defaults to 512. + priority_scale (float, optional): weight for the priority sampler Defaults to 0.0. + max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. + terminate_on_eos (bool, optional): whether to terminate on EOS. Defaults to False. + context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) + prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") + """ + self.tokenizer = tokenizer + self.reward_function = reward_function + self.max_steps = max_episode_length + self._max_text_length = ( + max_prompt_length if max_prompt_length else tokenizer.model_max_length + ) + self._terminate_on_eos = terminate_on_eos + self._context_start_token = context_start_token + self._prompt_truncation_side = prompt_truncation_side + super().__init__() + + # set the observation and action space here + self._vocab_size = tokenizer.vocab_size + self.observation_space = DictSpace( + { + # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited + # while creating rollout buffers, observations are concatenated for each key + "prompt_or_input_encoded_pt": spaces.Box( + low=0, high=self._vocab_size, shape=(self._max_text_length,) + ), + "prompt_or_input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self._max_text_length,) + ), + "context_encoded_pt": spaces.Box( + low=0, high=self._vocab_size, shape=(self.max_steps,) + ), + "context_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self.max_steps,) + ), + "input_encoded_pt": spaces.Box( + low=0, + high=self._vocab_size, + shape=(self._max_text_length + self.max_steps,), + ), + "input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self._max_text_length + self.max_steps,) + ), + } + ) + self.action_space = Discrete(n=self._vocab_size) + # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency + if 'mt5' in self.tokenizer.name_or_path: + n = 250112 + self.action_space = Discrete(n=n) + elif 't5' in self.tokenizer.name_or_path: + n = 32128 + self.action_space = Discrete(n=n) + self.sampler_for_replaying = PrioritySampler(priority_scale=priority_scale) + for sample, weight in samples: + self.sampler_for_replaying.add(sample, weight) + + # check the tokenizer and add padding tokens + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" # TBD: configure this + self.tokenizer.truncation_side = "left" # TBD: configure this + + # init tracking variables + self.__current_sample = None + self.__current_obs = None + self.__time_step = None + + def step(self, action: int) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: + self.__time_step += 1 + + # previous obs + previous_obs = self.__current_obs + + # just update the context tensor and gets the new observation + self.__current_obs = self.__current_obs.update(action, self.tokenizer) + + # decide if the episode is finished or not + done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or ( + self.__time_step == self.max_steps + ) + + # compute reward + if not isinstance(self.reward_function, BatchedRewardFunction): + reward = ( + None + if self.reward_function is None + else self.reward_function( + previous_obs, + action, + self.__current_obs, + done, + self.__current_obs.meta_info, + ) + ) + else: + reward = -inf # will be overridden later + + # populate additional info + info = { + "output": self.__current_obs.context_text, + "action_history": self.__current_obs.action_history, + "reference_text": self.__current_obs.target_or_reference_texts, + "prompt_text": self.__current_obs.prompt_or_input_text, + "prev_output": previous_obs.context_text, + "meta_info": previous_obs.meta_info, + } + + return self.__current_obs.to_dict(), reward, done, info + + def reset(self, sample: Sample = None) -> Dict[str, torch.tensor]: + """ + Resets the environment and starts a new episode + """ + # gets a new sample if not provided + if sample is None: + sample = self.sampler_for_replaying.sample(size=1)[0] + self.__current_sample = sample + + # init the observation + self.__current_obs = Observation.init_from_sample( + sample, + self.tokenizer, + self._max_text_length, + self.max_steps, + self._prompt_truncation_side, + self._context_start_token, + sample.meta_data, + ) + + # start the time step counter + self.__time_step = 0 + + dict_observation = self.__current_obs.to_dict() + return dict_observation + + def render(self): + pass + + def close(self): + pass + + def add_sample(self, sample: Sample, weight: int = 1.0): + self.sampler_for_replaying.add(sample, weight) diff --git a/benchmark/torch/RL4LMs/models/__init__.py b/benchmark/torch/RL4LMs/models/__init__.py new file mode 100644 index 000000000..0509d06e7 --- /dev/null +++ b/benchmark/torch/RL4LMs/models/__init__.py @@ -0,0 +1,2 @@ +from .base_model import BasePolicy, LMActorCriticPolicy +from .seq2seq_model import Seq2SeqLMModel \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/models/base_model.py b/benchmark/torch/RL4LMs/models/base_model.py new file mode 100644 index 000000000..cca3053f3 --- /dev/null +++ b/benchmark/torch/RL4LMs/models/base_model.py @@ -0,0 +1,635 @@ +from abc import abstractmethod, ABC +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +from gym.spaces import Discrete +from gym.spaces.dict import Dict as DictSpace +from torch.distributions import Categorical +from transformers import AutoTokenizer, PreTrainedModel +from transformers.modeling_utils import unwrap_model +from torch import nn + +import gym +import numpy as np + +from benchmark.torch.RL4LMs.utils import ( + Schedule, TensorDict, + + CategoricalDistribution, + + EvaluateActionsOutput, PolicyOutput, RefPolicyOutput, ValueOutput, + GenerationInputs, GenerationOutputs, PolicyType +) + + +# refer to stable_baselines3.common.policies +class BaseModel(nn.Module, ABC): + """ + The base model object: makes predictions in response to observations. + + In the case of policies, the prediction is an action. In the case of critics, it is the + estimated value of the observation. + + :param observation_space: The observation space of the environment + :param action_space: The action space of the environment + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``torch.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + # features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor: Optional[nn.Module] = None, + normalize_images: bool = True, + optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__() + + if optimizer_kwargs is None: + optimizer_kwargs = {} + + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + + self.observation_space = observation_space + self.action_space = action_space + self.features_extractor = features_extractor + self.normalize_images = normalize_images + + self.optimizer_class = optimizer_class + self.optimizer_kwargs = optimizer_kwargs + self.optimizer = None # type: Optional[torch.optim.Optimizer] + + # self.features_extractor_class = features_extractor_class + self.features_extractor_kwargs = features_extractor_kwargs + + @abstractmethod + def forward(self, *args, **kwargs): + pass + + # def _update_features_extractor( + # self, + # net_kwargs: Dict[str, Any], + # features_extractor: Optional[BaseFeaturesExtractor] = None, + # ) -> Dict[str, Any]: + # """ + # Update the network keyword arguments and create a new features extractor object if needed. + # If a ``features_extractor`` object is passed, then it will be shared. + # + # :param net_kwargs: the base network keyword arguments, without the ones + # related to features extractor + # :param features_extractor: a features extractor object. + # If None, a new object will be created. + # :return: The updated keyword arguments + # """ + # net_kwargs = net_kwargs.copy() + # if features_extractor is None: + # # The features extractor is not shared, create a new one + # features_extractor = self.make_features_extractor() + # net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) + # return net_kwargs + # + # def make_features_extractor(self) -> BaseFeaturesExtractor: + # """Helper method to create a features extractor.""" + # return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + # + # def extract_features(self, obs: torch.Tensor) -> torch.Tensor: + # """ + # Preprocess the observation if needed and extract features. + # + # :param obs: + # :return: + # """ + # assert self.features_extractor is not None, "No features extractor was set" + # preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) + # return self.features_extractor(preprocessed_obs) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the model when loading it from disk. + + :return: The dictionary to pass to the as kwargs constructor when reconstruction this model. + """ + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + # Passed to the constructor by child class + # squash_output=self.squash_output, + # features_extractor=self.features_extractor + normalize_images=self.normalize_images, + ) + + # @property + # def device(self) -> torch.device: + # """Infer which device this policy lives on by inspecting its parameters. + # If it has no parameters, the 'cpu' device is used as a fallback. + # + # :return:""" + # for param in self.parameters(): + # return param.device + # return get_device("cpu") + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: + """ + torch.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) + + # @classmethod + # def load(cls, path: str, device: Union[torch.device, str] = "auto") -> "BaseModel": + # """ + # Load model from patorch. + # + # :param path: + # :param device: Device on which the policy should be loaded. + # :return: + # """ + # device = get_device(device) + # saved_variables = torch.load(path, map_location=device) + # + # # Allow to load policy saved with older version of SB3 + # if "sde_net_arch" in saved_variables["data"]: + # warnings.warn( + # "sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.", + # DeprecationWarning, + # ) + # del saved_variables["data"]["sde_net_arch"] + # + # # Create policy object + # model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable + # # Load weights + # model.load_state_dict(saved_variables["state_dict"]) + # model.to(device) + # return model + + def load_from_vector(self, vector: np.ndarray) -> None: + """ + Load parameters from a 1D vector. + + :param vector: + """ + torch.nn.utils.vector_to_parameters(torch.FloatTensor(vector).to(self.device), self.parameters()) + + def parameters_to_vector(self) -> np.ndarray: + """ + Convert the parameters to a 1D vector. + + :return: + """ + return torch.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() + + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. + + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.train(mode) + # + # def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[torch.Tensor, bool]: + # """ + # Convert an input observation to a PyTorch tensor that can be fed to a model. + # Includes sugar-coating to handle different observations (e.g. normalizing images). + # + # :param observation: the input observation + # :return: The observation as PyTorch tensor + # and whether the observation is vectorized or not + # """ + # vectorized_env = False + # if isinstance(observation, dict): + # # need to copy the dict as the dict in VecFrameStack will become a torch tensor + # observation = copy.deepcopy(observation) + # for key, obs in observation.items(): + # obs_space = self.observation_space.spaces[key] + # if is_image_space(obs_space): + # obs_ = maybe_transpose(obs, obs_space) + # else: + # obs_ = np.array(obs) + # vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) + # # Add batch dimension if needed + # observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape) + # + # elif is_image_space(self.observation_space): + # # Handle the different cases for images + # # as PyTorch use channel first format + # observation = maybe_transpose(observation, self.observation_space) + # + # else: + # observation = np.array(observation) + # + # if not isinstance(observation, dict): + # # Dict obs need to be handled separately + # vectorized_env = is_vectorized_observation(observation, self.observation_space) + # # Add batch dimension if needed + # observation = observation.reshape((-1,) + self.observation_space.shape) + # + # observation = obs_as_tensor(observation, self.device) + # return observation, vectorized_env + + +class BasePolicy(BaseModel): + """The base policy object. + + Parameters are mostly the same as `BaseModel`; additions are documented below. + + :param args: positional arguments passed through to `BaseModel`. + :param kwargs: keyword arguments passed through to `BaseModel`. + :param squash_output: For continuous actions, whether the output is squashed + or not using a ``tanh()`` function. + """ + + def __init__(self, *args, squash_output: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self._squash_output = squash_output + + @staticmethod + def _dummy_schedule(progress_remaining: float) -> float: + """(float) Useful for pickling policy.""" + del progress_remaining + return 0.0 + + @property + def squash_output(self) -> bool: + """(bool) Getter for squash_output.""" + return self._squash_output + + @staticmethod + def init_weights(module: nn.Module, gain: float = 1) -> None: + """ + Orthogonal initialization (used in PPO and A2C) + """ + if isinstance(module, (nn.Linear, nn.Conv2d)): + nn.init.orthogonal_(module.weight, gain=gain) + if module.bias is not None: + module.bias.data.fill_(0.0) + + @abstractmethod + def _predict(self, observation: torch.Tensor, deterministic: bool = False) -> torch.Tensor: + """ + Get the action according to the policy for a given observation. + + By default provides a dummy implementation -- not all BasePolicy classes + implement this, e.g. if they are a Critic in an Actor-Critic method. + + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + # TODO (GH/1): add support for RNN policies + # if state is None: + # state = self.initial_state + # if episode_start is None: + # episode_start = [False for _ in range(self.n_envs)] + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + with torch.no_grad(): + actions = self._predict(observation, deterministic=deterministic) + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions[0] + + return actions, state + + def scale_action(self, action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [low, high] to [-1, 1] + (no need for symmetric action space) + + :param action: Action to scale + :return: Scaled action + """ + low, high = self.action_space.low, self.action_space.high + return 2.0 * ((action - low) / (high - low)) - 1.0 + + def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [-1, 1] to [low, high] + (no need for symmetric action space) + + :param scaled_action: Action to un-scale + """ + low, high = self.action_space.low, self.action_space.high + return low + (0.5 * (scaled_action + 1.0) * (high - low)) + +class LMActorCriticPolicy(BasePolicy): + def __init__( + self, + observation_space: DictSpace, + action_space: Discrete, + lr_schedule: Schedule, + model_name: str, + optimizer_kwargs: Dict[str, Any] = {}, + weight_decay: float = 1e-6, + use_sde: bool = None, + apply_model_parallel: bool = True, + optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, + generation_kwargs: Dict[str, Any] = {}, + prompt_truncation_side: str = "left", + ): + """ + + Args: + observation_space (DictSpace): Observation space + action_space (Discrete): Action space + lr_schedule (Schedule): Learning rate schedule + model_name (str): name of the causal or seq2seq model from transformers library + optimizer_kwargs (Dict[str, Any], optional): optimizer kwargs. Defaults to {}. + weight_decay (float, optional): weight decay. Defaults to 1e-6. + use_sde (bool, optional): Use state-dependent exploration. Defaults to None. (Unused parameter from stable-baselines3) + apply_model_parallel (bool, optional): whether to apply model parallel. Defaults to True. + optimizer_class (torch.optim.Optimizer, optional): Optimizer class. Defaults to torch.optim.AdamW. + generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}. + prompt_truncation_side (str, optional): truncation side for prompt text. Defaults to "left". + """ + super().__init__(observation_space, action_space) + self._action_space = action_space + self._apply_model_parallel = apply_model_parallel + self._build_model_heads(model_name) + self._setup_optimizer(optimizer_kwargs, weight_decay, optimizer_class) + self._action_dist = CategoricalDistribution(self._action_space.n) + self._generation_kwargs = generation_kwargs + self._prompt_truncation_side = prompt_truncation_side + + def _setup_optimizer( + self, + optimizer_kwargs: Dict[str, Any], + weight_decay: float, + optimizer_class: torch.optim, + ): + params = list(self.named_parameters()) + + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in params if not any(nd in n for nd in no_decay)], + "weight_decay": weight_decay, + }, + { + "params": [p for n, p in params if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + self.optimizer = optimizer_class( + optimizer_grouped_parameters, **optimizer_kwargs + ) + + def forward(self, *args, **kwargs): + # dummy just to comply with base policy + pass + + @staticmethod + def _predict( + self, observation: Dict[str, torch.tensor], deterministic: bool = False + ) -> torch.Tensor: + # dummy just to comply with base policy + pass + + def is_encoder_decoder(self, model: PreTrainedModel): + return unwrap_model(model).config.is_encoder_decoder + + def generate( + self, + tokenizer: AutoTokenizer, + texts: List[str] = None, + max_prompt_length: int = None, + input_ids: torch.tensor = None, + attention_mask: torch.tensor = None, + gen_kwargs: Dict[str, Any] = None, + ) -> GenerationOutputs: + + # if it different from rollout gen kwargs + if gen_kwargs is None: + gen_kwargs = self._generation_kwargs + + # switch to eval + self._policy_model.eval() + + if ( + input_ids is None + and attention_mask is None + and texts is not None + and max_prompt_length is not None + ): + # override truncation side for prompt + prev_truncation_side = tokenizer.truncation_side + tokenizer.truncation_side = self._prompt_truncation_side + encodings = tokenizer( + texts, + padding="max_length", + max_length=max_prompt_length, + return_tensors="pt", + return_attention_mask=True, + truncation=True, + ) + input_ids = encodings.input_ids + attention_mask = encodings.attention_mask + tokenizer.truncation_side = prev_truncation_side + + # if min_length argument is set and if policy is not a seq2seq LM (ie. causal LM) + # then it has to be adjusted to input_size + min_length + if "min_length" in gen_kwargs.keys() and not self.is_encoder_decoder( + self._policy_model + ): + generation_kwargs_ = deepcopy(gen_kwargs) + generation_kwargs_["min_length"] = ( + input_ids.shape[1] + gen_kwargs["min_length"] + ) + else: + generation_kwargs_ = gen_kwargs + + # generate + gen_output = unwrap_model(self._policy_model).generate( + inputs=input_ids.to(self.get_policy_first_device()), + attention_mask=attention_mask.to(self.get_policy_first_device()), + return_dict_in_generate=True, + output_scores=True, + **generation_kwargs_, + ) + + # number of tokens generated + seq_length = len(gen_output["scores"]) + + # get only the generated text (excluding prompt) + gen_tokens = gen_output["sequences"][:, -seq_length:] + + # to texts + gen_texts = [ + tokenizer.decode(output, skip_special_tokens=True) + for output in gen_tokens.tolist() + ] + + # extract scores (logits) + step_wise_logprobs = [] + step_wise_actions = [] + for step, logits in enumerate(gen_output["scores"]): + raw_logits, _ = logits + actions_at_step = gen_tokens[:, step] + distribution = Categorical(logits=raw_logits) + log_probs = distribution.log_prob(actions_at_step) + step_wise_logprobs.append(log_probs) + step_wise_actions.append(actions_at_step) + + gen_output = GenerationOutputs( + step_wise_logprobs, step_wise_actions, gen_tokens, gen_texts + ) + return gen_output + + def get_language_model(self): + return unwrap_model(self._policy_model) + + # Following methods need to be implemented by sub-classing + @abstractmethod + def _build_model_heads(self, model_name: str): + """ + Builds policy and value models + and sets self._policy_model and self._value_model + """ + raise NotImplementedError + + @abstractmethod + def forward_policy( + self, + obs: TensorDict, + actions: torch.tensor, + past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ) -> PolicyOutput: + """ + Performs a forward pass on the policy and gets log_probs, entropy etc + corresponding to specified observation, actions + + This is invoked during rollout generation + + Args: + obs (TensorDict): observation + actions (torch.tensor): actions + past_model_kwargs (Optional[Dict[str, torch.tensor]], optional): Any cached past model activations which can be used for sequential foward passes. + Defaults to None. + """ + raise NotImplementedError + + @abstractmethod + def forward_value( + self, + obs: TensorDict, + past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ) -> ValueOutput: + """ + Performs a forward pass on the value network and gets values corresponding to observations + + This is invoked during rollout generation + + Args: + obs (TensorDict): observation + past_model_kwargs (Optional[Dict[str, torch.tensor]], optional): Any cached past model activations which can be used for sequential foward passes. + Defaults to None. + """ + raise NotImplementedError + + @abstractmethod + def evaluate_actions( + self, obs: torch.Tensor, actions: torch.Tensor + ) -> EvaluateActionsOutput: + """ + Evaluates specified + and returns log_probs, values, entropy + + This is invoked for each mini-batch in rollout buffer during training iteration + """ + raise NotImplementedError + + @abstractmethod + def get_log_probs_ref_model( + self, + obs: TensorDict, + action: torch.tensor, + past_model_kwargs: Dict[str, Any] = None, + ) -> RefPolicyOutput: + """ + Performs a forward pass on the reference policy and gets log_probs + corresponding to specified observation, actions + + This is invoked during rollout generation to compute KL rewards + + Args: + obs (TensorDict): observation + past_model_kwargs (Optional[Dict[str, torch.tensor]], optional): Any cached past model activations which can be used for sequential foward passes. + Defaults to None. + """ + raise NotImplementedError + + @abstractmethod + def get_policy_first_device(self) -> torch.device: + """ + Returns the first device of the policy. Used in the case of model parallel + """ + raise NotImplementedError + + @abstractmethod + def get_policy_type(self) -> PolicyType: + """ + Returns the type of policy (causal or seq2seq) + """ + raise NotImplementedError + + @abstractmethod + def get_inputs_for_generation(self, obs: TensorDict) -> GenerationInputs: + """ + Extracts the prompt inputs and attention masks which is used as seed for generation + """ + raise NotImplementedError diff --git a/benchmark/torch/RL4LMs/models/seq2seq_model.py b/benchmark/torch/RL4LMs/models/seq2seq_model.py new file mode 100644 index 000000000..ba9cb1a74 --- /dev/null +++ b/benchmark/torch/RL4LMs/models/seq2seq_model.py @@ -0,0 +1,329 @@ +from typing import Any, Dict, Optional, List, Union +import torch +from gym.spaces import Discrete +from gym.spaces.dict import Dict as DictSpace +from torch import nn +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from copy import deepcopy + +from transformers.modeling_utils import unwrap_model +from benchmark.torch.RL4LMs.utils import ( + override_generation_routines, + + ActorCriticWarmStartMixin, + + TensorDict, Schedule, + + GenerationInputs, PolicyOutput, RefPolicyOutput, ValueOutput, + PolicyType, EvaluateActionsOutput, GenerationOutputs, +) + +from base_model import LMActorCriticPolicy + + +class Seq2SeqLMModel(LMActorCriticPolicy, ActorCriticWarmStartMixin): + def __init__( + self, + observation_space: DictSpace, + action_space: Discrete, + lr_schedule: Schedule, + model_name: str, + optimizer_kwargs: Dict[str, Any] = {}, + weight_decay: float = 1e-6, + use_sde: bool = None, + apply_model_parallel: bool = True, + optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, + generation_kwargs: Dict[str, Any] = {}, + prompt_truncation_side: str = "left", + state_dict: Dict[str, Any] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + model_name, + optimizer_kwargs, + weight_decay, + use_sde, + apply_model_parallel, + optimizer_class, + generation_kwargs, + prompt_truncation_side, + ) + self.load_from_dict(state_dict) + + def _build_model_heads(self, model_name: str): + self._policy_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + self._policy_model.__class__ = override_generation_routines(type(self._policy_model)) + + self._value_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + self._ref_model = deepcopy(self._policy_model).eval() + + self._value_head = nn.Linear( + self._value_model.config.hidden_size, 1, bias=False + ) + + # apply model parallel + if torch.cuda.is_available(): + if self._apply_model_parallel and self._policy_model.is_parallelizable: + self._policy_model.parallelize() + self._ref_model.parallelize() + self._value_model.parallelize() + self._value_head = self._value_head.to(self.device) + else: # else defaults to data parallel + self._policy_model = torch.nn.DataParallel(self._policy_model) + self._ref_model = torch.nn.DataParallel(self._ref_model) + self._value_model = torch.nn.DataParallel(self._value_model) + self._value_head = torch.nn.DataParallel( + self._value_head.to(self.device) + ) + + def forward_policy( + self, + obs: TensorDict, + actions: torch.tensor, + past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ) -> PolicyOutput: + + # Temp workaround for Seq2seq policy + past_model_kwargs = None + + if past_model_kwargs is None: + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( + self._policy_model + )._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs + ) + + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model( + self._policy_model + )._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name + ) + + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] + else: + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = past_model_kwargs.pop("decoder_attention_mask") + + # all set to get into auto-regressive mode + # prepare all of the model inputs for the decoder + batch_size = input_ids.shape[0] + model_inputs = unwrap_model(self._policy_model).prepare_inputs_for_generation( + input_ids, **past_model_kwargs + ) + + # and forward pass to get next token logits + outputs = self._policy_model( + **model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True + ) + next_token_logits = outputs.logits[:, -1, :] + + # get log probs + dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + log_prob = dist.log_prob(actions) + entropy = dist.entropy() + + # update the model kwargs for further generation + past_model_kwargs = unwrap_model( + self._policy_model + )._update_model_kwargs_for_generation( + outputs, + past_model_kwargs, + is_encoder_decoder=unwrap_model( + self._policy_model + ).config.is_encoder_decoder, + ) + past_model_kwargs["decoder_attention_mask"] = torch.cat( + (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), + dim=-1, + ) + + policy_output = PolicyOutput( + actions, log_prob, log_prob, entropy, past_model_kwargs + ) + + return policy_output + + def forward_value( + self, + obs: TensorDict, + past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ) -> ValueOutput: + # Temp workaround for Seq2seq policy + past_model_kwargs = None + + if past_model_kwargs is None: + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( + self._value_model + )._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs + ) + + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model( + self._value_model + )._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name + ) + + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] + else: + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = past_model_kwargs.pop("decoder_attention_mask") + + # all set to get into auto-regressive mode + # prepare all of the model inputs for the decoder + batch_size = input_ids.shape[0] + model_inputs = unwrap_model(self._value_model).prepare_inputs_for_generation( + input_ids, **past_model_kwargs + ) + + # and forrward pass to get hidden states + outputs = self._value_model( + **model_inputs, + output_hidden_states=True, + decoder_attention_mask=decoder_attn_mask, + return_dict=True + ) + + # get decoder's last hidden state + last_tokens_hidden = outputs.decoder_hidden_states[-1][:, -1, :].to(self.device) + values = self._value_head.forward(last_tokens_hidden) + + # update the model kwargs for further generation + past_model_kwargs = unwrap_model( + self._value_model + )._update_model_kwargs_for_generation( + outputs, + past_model_kwargs, + is_encoder_decoder=unwrap_model( + self._value_model + ).config.is_encoder_decoder, + ) + past_model_kwargs["decoder_attention_mask"] = torch.cat( + (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), + dim=-1, + ) + + value_output = ValueOutput(values, past_model_kwargs) + return value_output + + def evaluate_actions( + self, obs: torch.Tensor, actions: torch.Tensor + ) -> EvaluateActionsOutput: + + policy_outputs = self.forward_policy(obs=obs, actions=actions) + value_outputs = self.forward_value(obs) + + eval_outputs = EvaluateActionsOutput( + values=value_outputs.values, + log_prob=policy_outputs.log_probs, + entropy=policy_outputs.entropy, + ) + return eval_outputs + + def to(self, device: str): + if self._apply_model_parallel: + self._value_head = self._value_head.to(device) + return self + else: + return super().to(device) + + def get_log_probs_ref_model( + self, + obs: TensorDict, + action: torch.tensor, + model_kwarpast_model_kwargsgs: Dict[str, Any] = None, + ) -> RefPolicyOutput: + # Temp workaround for Seq2seq policy + past_model_kwargs = None + + if past_model_kwargs is None: + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( + self._ref_model + )._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs + ) + + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model( + self._ref_model + )._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name + ) + + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] + else: + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = past_model_kwargs.pop("decoder_attention_mask") + + # all set to get into auto-regressive mode + # prepare all of the model inputs for the decoder + batch_size = input_ids.shape[0] + model_inputs = unwrap_model(self._ref_model).prepare_inputs_for_generation( + input_ids, **past_model_kwargs + ) + + # and forward pass to get next token logits + outputs = self._ref_model( + **model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True + ) + next_token_logits = outputs.logits[:, -1, :] + + # get log probs + dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + log_prob = dist.log_prob(action) + + # update the model kwargs for further generation + past_model_kwargs = unwrap_model( + self._ref_model + )._update_model_kwargs_for_generation( + outputs, + past_model_kwargs, + is_encoder_decoder=unwrap_model(self._ref_model).config.is_encoder_decoder, + ) + past_model_kwargs["decoder_attention_mask"] = torch.cat( + (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), + dim=-1, + ) + + ref_policy_output = RefPolicyOutput(log_prob, past_model_kwargs) + + return ref_policy_output + + def get_policy_first_device(self): + return ( + self._policy_model.get_encoder().first_device + if self._apply_model_parallel + else self.device + ) + + def get_inputs_for_generation(self, obs: TensorDict) -> GenerationInputs: + + generation_inputs = GenerationInputs( + obs["prompt_or_input_encoded_pt"], obs["prompt_or_input_attention_mask_pt"] + ) + return generation_inputs + + def get_policy_type(self): + return PolicyType.SEQ2SEQ diff --git a/benchmark/torch/RL4LMs/summarization/__init__.py b/benchmark/torch/RL4LMs/summarization/__init__.py new file mode 100644 index 000000000..dcf74dbe4 --- /dev/null +++ b/benchmark/torch/RL4LMs/summarization/__init__.py @@ -0,0 +1 @@ +from .rl4lms_summa_agent import RL4LMsSummaAgent diff --git a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py new file mode 100644 index 000000000..a829300ac --- /dev/null +++ b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py @@ -0,0 +1,435 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import torch +import numpy as np + +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Type, Union +import numpy as np +import torch +from benchmark.torch.RL4LMs.utils import DictRolloutBuffer, RolloutBuffer, TransitionInfo, TensorDict,\ + BatchedRewardFunction, RewardFunction, PolicyOutput, RefPolicyOutput, ValueOutput, \ + MaskableDictRolloutBuffer, OnPolicyWarmStartMixin, KLController, Tracker + +from transformers import PreTrainedTokenizer + + + +def obs_as_tensor( + obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: torch.device +) -> Union[torch.Tensor, TensorDict]: + """ + Moves the observation to the given device. + + :param obs: + :param device: PyTorch device + :return: PyTorch tensor of the observation on a desired device. + """ + if isinstance(obs, np.ndarray): + return torch.as_tensor(obs).to(device) + elif isinstance(obs, dict): + return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + else: + raise Exception(f"Unrecognized type of observation {type(obs)}") + + + + + +def unpack_observations(obs_tensor, n_envs: int): + """ + Unpacks vectorized dict observations into separate dict observations + """ + unpacked_obs = [] + keys = obs_tensor.keys() + for env_ix in range(n_envs): + obs_dict = {} + for key in keys: + obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() + unpacked_obs.append(obs_dict) + return unpacked_obs + + +def compute_batched_rewards( + episode_wise_transitions: List[List[TransitionInfo]], reward_fn: RewardFunction +): + # first collect all the prompts, ref and gen texts + prompts = [] + reference_texts = [] + generated_texts = [] + is_dones = [] + indices = [] + meta_infos = [] + for env_ix, transitions in enumerate(episode_wise_transitions): + for trans_ix, transition in enumerate(transitions): + done = transition.done + info = transition.info + prompts.append(info["prompt_text"]) + reference_texts.append(info["reference_text"]) + generated_texts.append(info["output"]) + is_dones.append(done) + meta_infos.append(info["meta_info"]) + indices.append((env_ix, trans_ix)) + + # compute rewards all at once + rewards = reward_fn(prompts, generated_texts, reference_texts, is_dones, meta_infos) + # rewards = rewards.numpy().flatten() + + # override the rewards in transitions + for (env_ix, trans_ix), reward in zip(indices, rewards): + episode_wise_transitions[env_ix][trans_ix].task_reward = reward + episode_wise_transitions[env_ix][trans_ix].total_reward = ( + reward + episode_wise_transitions[env_ix][trans_ix].kl_reward + ) + + +def wrap_onpolicy_alg( + alg_class, + alg_kwargs: Dict[str, Any], + kl_coeff: float, + tracker: Tracker, + target_kl: float = None, + norm_reward: bool = False, +): + class OnPolicyAlgText(alg_class, OnPolicyWarmStartMixin): + def __init__( + self, + alg_kwargs: Dict[str, Any], + kl_coeff: float, + tracker: Tracker, + target_kl: float = None, + norm_reward: bool = False, + ): + alg_kwargs["tracker"] = tracker + super().__init__(**alg_kwargs) + self._kl_controller = KLController(kl_coeff, target_kl) + self.tracker = tracker + self._norm_reward = norm_reward + # flattened rollout buffer + self.rollout_buffer = MaskableDictRolloutBuffer( + self.n_steps * self.env.num_envs, + self.observation_space, + self.action_space, + device=self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=1, + ) + self.reward_fn = self.env.get_attr("reward_function", 0)[0] + + def get_policy_kwargs( + self, + obs: TensorDict, + action: torch.tensor, + past_state: Dict[str, torch.tensor], + action_mask: torch.tensor, + ): + + policy_kwargs = { + "obs": obs, + "actions": action, + "past_model_kwargs": past_state, + } + if action_mask is not None: + policy_kwargs["action_masks"] = action_mask + return policy_kwargs + + def generate_batch( + self, + rollout_buffer: DictRolloutBuffer, + tokenizer: PreTrainedTokenizer, + max_steps: int, + rollout_info: Dict[str, Any], + ): + # if rollout buffer is already full, do not continue + if rollout_buffer.full: + return + + # start parallel episodes + current_obs = self.env.reset() + episode_starts = np.ones((self.env.num_envs,), dtype=bool) + + # generate text using the model + obs_tensor = obs_as_tensor(current_obs, self.device) + generation_inputs = self.policy.get_inputs_for_generation(obs_tensor) + gen_output = self.policy.generate( + input_ids=generation_inputs.inputs, + attention_mask=generation_inputs.attention_masks, + tokenizer=tokenizer, + ) + + # process them one step at a time to collect rollout info + episode_wise_transitions = [[] for _ in range(self.env.num_envs)] + ep_terminated = np.zeros((self.env.num_envs,), dtype=bool) + value_past_state = None + ref_past_state = None + policy_past_state = None + masks = ( + gen_output.action_masks + if gen_output.action_masks is not None + else [None] * len(gen_output.step_wise_logprobs) + ) + + for actions_tensor, _, action_mask in zip( + gen_output.step_wise_actions, gen_output.step_wise_logprobs, masks + ): + # if all episodes are done, just break and do not continue + if np.all(ep_terminated): + break + + # evaluate actions with actions from rollout + with torch.no_grad(): + obs_tensor = obs_as_tensor(current_obs, self.device) + + # get log probs (TBD: generalize this a bit) + policy_kwargs = self.get_policy_kwargs( + obs_tensor, actions_tensor, policy_past_state, action_mask + ) + + policy_outputs: PolicyOutput = self.policy.forward_policy( + **policy_kwargs + ) + raw_log_probs, log_probs, policy_past_state = ( + policy_outputs.raw_log_probs, + policy_outputs.log_probs, + policy_outputs.past_model_kwargs, + ) + + # sanity check + assert torch.all( + torch.isfinite(log_probs) + ), "Infinite values in log probs" + + # sanity check + assert torch.all( + torch.isfinite(raw_log_probs) + ), "Infinite values in log probs" + + # get values + value_outputs: ValueOutput = self.policy.forward_value( + obs_tensor, value_past_state + ) + values, value_past_state = ( + value_outputs.values, + value_outputs.past_model_kwargs, + ) + + # get reference log probs + ref_policy_outputs: RefPolicyOutput = ( + self.policy.get_log_probs_ref_model( + obs_tensor, actions_tensor, ref_past_state + ) + ) + ref_log_probs, ref_past_state = ( + ref_policy_outputs.log_probs, + ref_policy_outputs.past_model_kwargs, + ) + + # sanity check + assert torch.all( + torch.isfinite(ref_log_probs) + ), "Infinite values in log probs" + + # compute KL rewards + kl_div = raw_log_probs - ref_log_probs + kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div + + # step into env to get rewards + actions = actions_tensor.cpu().numpy() + new_obs, rewards, dones, infos = self.env.step(actions) + + self.num_timesteps += self.env.num_envs + + # compute total rewards + total_rewards = rewards + kl_rewards.cpu().numpy() + + # unpack individual observations + unpacked_obs = unpack_observations(obs_tensor, self.env.num_envs) + + # store episode wise transitions separately + for env_ix in range(self.env.num_envs): + # only if not terminated already + if not ep_terminated[env_ix]: + transtion = TransitionInfo( + observation=unpacked_obs[env_ix], + action=actions[env_ix], + task_reward=rewards[env_ix], + total_reward=total_rewards[env_ix], + kl_div=kl_div.cpu().numpy()[env_ix], + episode_start=episode_starts[env_ix], + value=values[env_ix].cpu(), + log_prob=log_probs[env_ix].cpu(), + done=dones[env_ix], + ref_log_prob=ref_log_probs[env_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[env_ix], + action_mask=action_mask[env_ix].cpu().numpy() + if action_mask is not None + else None, + info=infos[env_ix], + ) + + episode_wise_transitions[env_ix].append(transtion) + + # mark this episode to terminated if done occurs once + if dones[env_ix]: + ep_terminated[env_ix] = True + + episode_starts = np.zeros((self.env.num_envs,), dtype=bool) + current_obs = new_obs + + # now we flush all episode wise info to the 1-D buffer + rollout_info = self._add_to_buffer( + rollout_buffer, episode_wise_transitions, rollout_info + ) + return rollout_info + + def _add_to_buffer( + self, rollout_buffer, episode_wise_transitions, rollout_info + ): + # if the reward function is batchable, we override the rewards here + if isinstance(self.reward_fn, BatchedRewardFunction): + compute_batched_rewards(episode_wise_transitions, self.reward_fn) + + advantages_computed = False + for ep_ix, transitions in enumerate(episode_wise_transitions): + ep_length = len(transitions) + total_reward = 0.0 + total_kl_reward = 0.0 + for transition_ix, transition in enumerate(transitions): + total_reward += transition.task_reward + total_kl_reward += transition.kl_reward + rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) + rollout_info["rollout_info/log_prob"].append(transition.log_prob) + rollout_info["rollout_info/ref_log_prob"].append( + transition.ref_log_prob + ) + rollout_info["rollout_info/values"].append(transition.value.numpy()) + + if not rollout_buffer.full: + rollout_buffer.add( + transition.observation, + transition.action, + transition.total_reward, + transition.episode_start, + transition.value, + transition.log_prob, + action_masks=transition.action_mask, + ) + + # if the buffer is full, compute advantages + if rollout_buffer.full and not advantages_computed: + + # normalize the rewards + if self._norm_reward: + mean = rollout_buffer.rewards.mean() + std = rollout_buffer.rewards.std() + rollout_buffer.rewards = (rollout_buffer.rewards - mean) / ( + std + 1e-8 + ) + + # we fetch the last value for the last time step + # values come from the next transitions's values + next_values = ( + transitions[transition_ix + 1].value + if (transition_ix + 1) < ep_length + else torch.tensor([0.0]) + ) + + rollout_buffer.compute_returns_and_advantage( + last_values=next_values, dones=transition.done + ) + advantages_computed = True + + rollout_info["rollout_info/ep_rew"].append(total_reward) + rollout_info["rollout_info/ep_lens"].append(ep_length) + rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) + return rollout_info + + def collect_rollouts( + self, + env, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + ) -> bool: + # max episode steps + max_steps = env.unwrapped.get_attr("max_steps", [0])[0] + + # get tokenizer + tokenizer = env.unwrapped.get_attr("tokenizer", [0]) + tokenizer = tokenizer[0] + + # Switch to eval mode + self.policy.set_training_mode(False) + + # reset rollout buffer and stats + rollout_buffer.reset() + + # start the rollout process + rollout_info = { + "rollout_info/ep_rew": [], + "rollout_info/kl_div_mean": [], + "rollout_info/ep_lens": [], + "rollout_info/ep_kl_rew": [], + "rollout_info/log_prob": [], + "rollout_info/ref_log_prob": [], + "rollout_info/values": [], + } + while not rollout_buffer.full: + # generate batch of rollouts + rollout_info = self.generate_batch( + rollout_buffer, tokenizer, max_steps, rollout_info + ) + + # aggregate rollout info + aggregated_rollout_info = {} + for key, values in rollout_info.items(): + aggregated_rollout_info[key] = np.mean(values).item() + aggregated_rollout_info[f"{key}_std"] = np.std(values).item() + aggregated_rollout_info[ + "rollout_info/kl_coeff" + ] = self._kl_controller.kl_coeff + + if self.tracker is not None: + self.tracker.log_rollout_infos(aggregated_rollout_info) + + # adapt the KL coeff + self._kl_controller.step( + torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) + ) + return True + + # instantiate the wrapped alg + alg = OnPolicyAlgText(alg_kwargs, kl_coeff, tracker, target_kl, norm_reward) + return alg + + + + + +class RL4LMsSummaAgent(parl.Agent): + def __init__(self, algorithm, config): + super(RL4LMsSummaAgent, self).__init__(algorithm) + self.dataset = None + self.config = config + + def learn(self, *args, **kwargs): + pass + + def predict(self, *args, **kwargs): + pass + + def sample(self, *args, **kwargs): + pass diff --git a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py new file mode 100644 index 000000000..7bcc5588a --- /dev/null +++ b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py @@ -0,0 +1,7 @@ +import parl +import torch +import torch.nn as nn + + +class RL4LMsSummaModel(parl.Model): + pass \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py new file mode 100644 index 000000000..8ee888815 --- /dev/null +++ b/benchmark/torch/RL4LMs/train.py @@ -0,0 +1,79 @@ +import os +from argparse import ArgumentParser + +import yaml +import collections +from trainers import OnPolicyTrainer +from utils import Tracker + + +def recursive_dict_update(d, u): + for k, v in u.items(): + if isinstance(v, collections.Mapping): + d[k] = recursive_dict_update(d.get(k, {}), v) + else: + d[k] = v + return d + + +def main(config): + + # load tracker + tracker = Tracker( + config["base_path_to_store_results"], + config, + config["project_name"], + config["experiment_name"], + config["entity_name"], + False, + ) + + # instantiate the trainer here + # TODO: currently only complete ppo + if "ppo" == config["alg"]["id"]: + trainer = OnPolicyTrainer( + tokenizer_config=config["tokenizer"], + datapool_config=config["datapool"], + reward_config=config["reward_fn"], + env_config=config["env"], + on_policy_alg_config=config["alg"], + train_eval_config=config["train_evaluation"], + tracker=tracker, + ) + else: + raise NotImplementedError + trainer.train_and_eval() + + + + + + +if __name__ == '__main__': + parser = ArgumentParser(description="Fine-tune LM to generate controlled text") + parser.add_argument("--config_path", type=str, help="path to the config file") + parser.add_argument( + "--project_name", type=str, help="project name", default="rl4lm_exps" + ) + parser.add_argument( + "--experiment_name", + type=str, + help="experiment name", + default="rl4lm_experiment", + ) + parser.add_argument( + "--base_path_to_store_results", + type=str, + help="Base path to store experiment results", + default=os.getcwd(), + ) + args = parser.parse_args() + + # load the config file + with open(args.config_path, "r") as fp: + config = yaml.safe_load(fp) + + recursive_dict_update(config, args) + + main(config) + diff --git a/benchmark/torch/RL4LMs/trainers.py b/benchmark/torch/RL4LMs/trainers.py new file mode 100644 index 000000000..78bd390ac --- /dev/null +++ b/benchmark/torch/RL4LMs/trainers.py @@ -0,0 +1,219 @@ +import os +from functools import partial +from typing import Any, Dict, List +import numpy as np + +from benchmark.torch.RL4LMs.utils import Sample +from benchmark.torch.RL4LMs.env import TextGenEnv +from rl4lms.envs.text_generation.evaluation_utils import evaluate_on_samples +from rl4lms.envs.text_generation.logging_utils import Tracker +from rl4lms.envs.text_generation.registry import (DataPoolRegistry, + MetricRegistry, + RewardFunctionRegistry, + PolicyRegistry, + AlgorithmRegistry, + WrapperRegistry) +from rl4lms.envs.text_generation.reward import RewardFunction +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import SubprocVecEnv +from transformers import (AutoTokenizer, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + Trainer, + TrainingArguments, + DataCollatorForLanguageModeling, + DataCollatorForSeq2Seq) + +from rl4lms.envs.text_generation.warm_start import TrainerWarmStartMixin + + + + +def build_tokenizer(tokenizer_config: Dict[str, Any]): + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_config["model_name"]) + if tokenizer.pad_token is None and tokenizer_config.get("pad_token_as_eos_token", True): + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = tokenizer_config.get( + "padding_side", "left") + tokenizer.truncation_side = tokenizer_config.get( + "truncation_side", "left") + return tokenizer + + +def build_reward_fn(reward_config: Dict[str, Any]): + reward_fn = RewardFunctionRegistry.get(reward_config["id"], + reward_config.get("args", {})) + return reward_fn + + +def build_metrics(metric_configs: List[Dict[str, Any]]): + metrics = [MetricRegistry.get(metric_config["id"], metric_config.get("args", {})) + for metric_config in metric_configs] + return metrics + + +def build_datapool(datapool_config: Dict[str, Any]): + + def _get_datapool_by_split(split: str): + kwargs = datapool_config.get("args", {}) + kwargs["split"] = split + dp_split = DataPoolRegistry.get(datapool_config["id"], kwargs) + return dp_split + + train_datapool = _get_datapool_by_split("train") + val_datapool = _get_datapool_by_split("val") + test_datapool = _get_datapool_by_split("test") + + samples_by_split = { + "train": [(sample, weight) + for sample, weight in train_datapool], + "val": [sample for sample, _ in val_datapool], + "test": [sample for sample, _ in test_datapool] + } + return samples_by_split + + +def build_env(env_config: Dict[str, Any], + reward_fn: RewardFunction, + tokenizer: AutoTokenizer, + train_samples: List[Sample]): + # vectoried env + env_kwargs = { + "reward_function": reward_fn, + "tokenizer": tokenizer, + "samples": train_samples, + } + env_kwargs = {**env_kwargs, **env_config.get("args", {})} + env = make_vec_env(TextGenEnv, + n_envs=env_config.get( + "n_envs", 1), + vec_env_cls=SubprocVecEnv, + env_kwargs=env_kwargs) + return env + + +def build_alg(alg_config: Dict[str, Any], + env: TextGenEnv, + tracker: Tracker, + policy_state: Dict[str, Any], + alg_state: Dict[str, Any]): + # TBD - move these to a registry once the experimentation is done + # Also switch to Sb3 algos when possible with minimal code adaptations + policy_config = alg_config["policy"] + policy_cls = PolicyRegistry.get(policy_config["id"]) + alg_cls = AlgorithmRegistry.get(alg_config["id"]) + + policy_args = policy_config["args"] + policy_args["state_dict"] = policy_state + alg_kwargs = { + "policy": policy_cls, + "env": env, + "policy_kwargs": policy_args, + } + alg_kwargs = {**alg_kwargs, **alg_config.get("args")} + wrapper = WrapperRegistry.get(alg_config["id"]) + alg = wrapper(alg_cls, alg_kwargs, + alg_config["kl_div"]["coeff"], tracker, + alg_config["kl_div"].get("target_kl", None), + alg_config["kl_div"].get("norm_reward", False)) + alg.load_from_dict(alg_state) + return alg + + +class OnPolicyTrainer(TrainerWarmStartMixin): + """ + A generic trainer for training LMs with onpolicy algorithms from SB3 + """ + + def __init__(self, + tokenizer_config: Dict[str, Any], + datapool_config: Dict[str, Any], + reward_config: Dict[str, Any], + env_config: Dict[str, Any], + on_policy_alg_config: Dict[str, Any], + train_eval_config: Dict[str, Any], + tracker: Tracker = None, + experiment_name: str = '' + ): + self._tokenizer_config = tokenizer_config + self._datapool_config = datapool_config + self._reward_config = reward_config + self._env_config = env_config + self._on_policy_alg_config = on_policy_alg_config + self._train_eval_config = train_eval_config + self._tracker = tracker + self._experiment_name = experiment_name + self._setup() + + def _setup(self): + # load trainer state from available previous checkpoint if available + self.load_trainer_state(self._tracker) + + # build components + self._tokenizer = build_tokenizer(self._tokenizer_config) + self._reward_fn = build_reward_fn(self._reward_config) + self._metrics = build_metrics( + self._train_eval_config.get("metrics", [])) + self._samples_by_split = build_datapool(self._datapool_config) + self._env = build_env(self._env_config, self._reward_fn, + self._tokenizer, self._samples_by_split["train"]) + self._alg = build_alg(self._on_policy_alg_config, + self._env, self._tracker, + self._policy_state_dict, + self._alg_state_dict) + + # extract train params + self._max_episode_length = self._env_config["args"]["max_episode_length"] + self._max_prompt_length = self._env_config["args"]["max_prompt_length"] + self._eval_batch_size = self._train_eval_config["eval_batch_size"] + self._n_iters = int(self._train_eval_config["n_iters"]) + self._n_steps_per_iter = self._env.num_envs * self._alg.n_steps + + # gen kwargs for evaluation (if it is different from rollout gen kwargs) + self._eval_gen_kwargs = self._train_eval_config.get( + "generation_kwargs", None) + + def _evaluate_on_datapools(self, epoch: int, + splits: List[str] = ["val", "test"]): + for split in splits: + evaluate_on_samples(policy=self._alg.policy, + tokenizer=self._tokenizer, + samples=self._samples_by_split[split], + batch_size=self._eval_batch_size, + max_prompt_length=self._max_prompt_length, + metrics=self._metrics, + epoch=epoch, + split_name=split, + tracker=self._tracker, + gen_kwargs=self._eval_gen_kwargs) + + def train_and_eval(self): + # evaluate on val and test set before fine-tuning once + iter_start = self._trainer_state["current_iter"] + self._evaluate_on_datapools(epoch=iter_start) + + # train for given number of iters + for epoch in range(iter_start, self._n_iters): + # current state + self._trainer_state["current_iter"] = epoch + + # inner rollout and learn loop for on-policy algorithm + self._alg.learn(self._n_steps_per_iter) + + # save the policy checkpoint + if (epoch + 1) % self._train_eval_config.get("save_every", 20) == 0: + self.save_trainer_state( + self._tracker, self._alg.policy, self._trainer_state) + + # evaluate on val set in the given intervals + if (epoch + 1) % self._train_eval_config["eval_every"] == 0: + self._evaluate_on_datapools(epoch=epoch, splits=["val"]) + + # finally evaluate on val and test samples + self._evaluate_on_datapools(epoch=epoch) + + # save model here - we save only the language model + if self._tracker is not None: + self._tracker.save_auto_model( + self._alg.policy.get_language_model()) \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py new file mode 100644 index 000000000..ec9908986 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -0,0 +1,24 @@ +from .data_wrapper import EvaluateActionsOutput, PolicyOutput, \ + RefPolicyOutput, ValueOutput, GenerationInputs, GenerationOutputs,\ + PolicyType, Sample, Observation, TransitionInfo + + +from .huggingface_generation_util import override_generation_routines + +from .warm_start import ActorCriticWarmStartMixin, OnPolicyWarmStartMixin + +from .type_wrapper import TensorDict, Schedule + +from .distribution_wrapper import CategoricalDistribution + +from .reward_util import RewardFunction, BatchedRewardFunction + +from .sample_util import PrioritySampler + +from .buffer import DictRolloutBuffer, RolloutBuffer,\ + MaskableDictRolloutBuffer, MaskableRolloutBuffer + +from .kl_controller import KLController + +from .tracker import Tracker + diff --git a/benchmark/torch/RL4LMs/utils/buffer.py b/benchmark/torch/RL4LMs/utils/buffer.py new file mode 100644 index 000000000..380dc1435 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/buffer.py @@ -0,0 +1,698 @@ +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Union, Tuple + +import numpy as np +import torch +from gym import spaces + +from .data_wrapper import RolloutBufferSamples, DictRolloutBufferSamples,\ + MaskableRolloutBufferSamples, MaskableDictRolloutBufferSamples + +try: + # Check memory used by replay buffer when possible + import psutil +except ImportError: + psutil = None + + +def get_action_dim(action_space: spaces.Space) -> int: + """ + Get the dimension of the action space. + + :param action_space: + :return: + """ + if isinstance(action_space, spaces.Box): + return int(np.prod(action_space.shape)) + elif isinstance(action_space, spaces.Discrete): + # Action is an int + return 1 + elif isinstance(action_space, spaces.MultiDiscrete): + # Number of discrete actions + return int(len(action_space.nvec)) + elif isinstance(action_space, spaces.MultiBinary): + # Number of binary actions + return int(action_space.n) + else: + raise NotImplementedError(f"{action_space} action space is not supported") + + +def get_obs_shape( + observation_space: spaces.Space, +) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]: + """ + Get the shape of the observation (useful for the buffers). + + :param observation_space: + :return: + """ + if isinstance(observation_space, spaces.Box): + return observation_space.shape + elif isinstance(observation_space, spaces.Discrete): + # Observation is an int + return (1,) + elif isinstance(observation_space, spaces.MultiDiscrete): + # Number of discrete features + return (int(len(observation_space.nvec)),) + elif isinstance(observation_space, spaces.MultiBinary): + # Number of binary features + return (int(observation_space.n),) + elif isinstance(observation_space, spaces.Dict): + return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} + + else: + raise NotImplementedError(f"{observation_space} observation space is not supported") + + +class BaseBuffer(ABC): + """ + Base class that represent a buffer (rollout or replay) + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + to which the values will be converted + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[torch.device, str] = "cpu", + n_envs: int = 1, + ): + super().__init__() + self.buffer_size = buffer_size + self.observation_space = observation_space + self.action_space = action_space + self.obs_shape = get_obs_shape(observation_space) + + self.action_dim = get_action_dim(action_space) + self.pos = 0 + self.full = False + self.device = device + self.n_envs = n_envs + + @staticmethod + def swap_and_flatten(arr: np.ndarray) -> np.ndarray: + """ + Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) + to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) + to [n_steps * n_envs, ...] (which maintain the order) + + :param arr: + :return: + """ + shape = arr.shape + if len(shape) < 3: + shape = shape + (1,) + return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) + + def size(self) -> int: + """ + :return: The current size of the buffer + """ + if self.full: + return self.buffer_size + return self.pos + + def add(self, *args, **kwargs) -> None: + """ + Add elements to the buffer. + """ + raise NotImplementedError() + + def extend(self, *args, **kwargs) -> None: + """ + Add a new batch of transitions to the buffer + """ + # Do a for loop along the batch axis + for data in zip(*args): + self.add(*data) + + def reset(self) -> None: + """ + Reset the buffer. + """ + self.pos = 0 + self.full = False + + def sample(self, batch_size: int, env = None): + """ + :param batch_size: Number of element to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: + """ + upper_bound = self.buffer_size if self.full else self.pos + batch_inds = np.random.randint(0, upper_bound, size=batch_size) + return self._get_samples(batch_inds, env=env) + + @abstractmethod + def _get_samples( + self, batch_inds: np.ndarray, env = None + ) -> RolloutBufferSamples: + """ + :param batch_inds: + :param env: + :return: + """ + raise NotImplementedError() + + def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: + """ + Convert a numpy array to a PyTorch tensor. + Note: it copies the data by default + + :param array: + :param copy: Whether to copy or not the data + (may be useful to avoid changing things be reference) + :return: + """ + if copy: + return torch.tensor(array).to(self.device) + return torch.as_tensor(array).to(self.device) + + @staticmethod + def _normalize_obs( + obs: Union[np.ndarray, Dict[str, np.ndarray]], + env = None, + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + if env is not None: + return env.normalize_obs(obs) + return obs + + @staticmethod + def _normalize_reward(reward: np.ndarray, env = None) -> np.ndarray: + if env is not None: + return env.normalize_reward(reward).astype(np.float32) + return reward + + + +class RolloutBuffer(BaseBuffer): + """ + Rollout buffer used in on-policy algorithms like A2C/PPO. + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[torch.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + self.gae_lambda = gae_lambda + self.gamma = gamma + self.observations, self.actions, self.rewards, self.advantages = None, None, None, None + self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None + self.generator_ready = False + self.reset() + + def reset(self) -> None: + + self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super().reset() + + def compute_returns_and_advantage(self, last_values: torch.Tensor, dones: np.ndarray) -> None: + """ + Post-processing step: compute the lambda-return (TD(lambda) estimate) + and GAE(lambda) advantage. + + Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S)) + where R is the sum of discounted reward with value bootstrap + (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization. + + The TD(lambda) estimator has also two special cases: + - TD(1) is Monte-Carlo estimate (sum of discounted rewards) + - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1})) + + For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. + + :param last_values: state value estimation for the last step (one for each env) + :param dones: if the last step was a terminal step (one bool for each env). + """ + # Convert to numpy + last_values = last_values.clone().cpu().numpy().flatten() + + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_values = last_values + else: + next_non_terminal = 1.0 - self.episode_starts[step + 1] + next_values = self.values[step + 1] + delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] + last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + self.advantages[step] = last_gae_lam + # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" + # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA + self.returns = self.advantages + self.values + + def add( + self, + obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: torch.Tensor, + log_prob: torch.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. + """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space, spaces.Discrete): + obs = obs.reshape((self.n_envs,) + self.obs_shape) + + self.observations[self.pos] = np.array(obs).copy() + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + _tensor_names = [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env = None) -> RolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) + return RolloutBufferSamples(*tuple(map(self.to_torch, data))) + + + +class DictRolloutBuffer(RolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to Monte-Carlo advantage estimate when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[torch.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + + super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + + self.gae_lambda = gae_lambda + self.gamma = gamma + self.observations, self.actions, self.rewards, self.advantages = None, None, None, None + self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None + self.generator_ready = False + self.reset() + + def reset(self) -> None: + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + self.observations = {} + for key, obs_input_shape in self.obs_shape.items(): + self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super(RolloutBuffer, self).reset() + + def add( + self, + obs: Dict[str, np.ndarray], + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: torch.Tensor, + log_prob: torch.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. + """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]).copy() + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env = None) -> DictRolloutBufferSamples: + + return DictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + ) + + +class MaskableRolloutBuffer(RolloutBuffer): + """ + Rollout buffer that also stores the invalid action masks associated with each observation. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__(self, *args, **kwargs): + self.action_masks = None + super().__init__(*args, **kwargs) + + def reset(self) -> None: + if isinstance(self.action_space, spaces.Discrete): + mask_dims = self.action_space.n + elif isinstance(self.action_space, spaces.MultiDiscrete): + mask_dims = sum(self.action_space.nvec) + elif isinstance(self.action_space, spaces.MultiBinary): + mask_dims = 2 * self.action_space.n # One mask per binary outcome + else: + raise ValueError( + f"Unsupported action space {type(self.action_space)}") + + self.mask_dims = mask_dims + self.action_masks = np.ones( + (self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) + + super().reset() + + def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: + """ + :param action_masks: Masks applied to constrain the choice of possible actions. + """ + if action_masks is not None: + self.action_masks[self.pos] = action_masks.reshape( + (self.n_envs, self.mask_dims)) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "action_masks", + ]: + self.__dict__[tensor] = self.swap_and_flatten( + self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx: start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env = None) -> MaskableRolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + self.action_masks[batch_inds].reshape(-1, self.mask_dims), + ) + return MaskableRolloutBufferSamples(*map(self.to_torch, data)) + + + + + +class MaskableDictRolloutBuffer(DictRolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[torch.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.action_masks = None + super().__init__(buffer_size, observation_space, + action_space, device, gae_lambda, gamma, n_envs=n_envs) + + def reset(self) -> None: + if isinstance(self.action_space, spaces.Discrete): + mask_dims = self.action_space.n + elif isinstance(self.action_space, spaces.MultiDiscrete): + mask_dims = sum(self.action_space.nvec) + elif isinstance(self.action_space, spaces.MultiBinary): + mask_dims = 2 * self.action_space.n # One mask per binary outcome + else: + raise ValueError( + f"Unsupported action space {type(self.action_space)}") + + self.mask_dims = mask_dims + self.action_masks = np.ones( + (self.buffer_size, self.n_envs, self.mask_dims)) # .to(self.device) + + super().reset() + + def add(self, *args, action_masks: Optional[torch.Tensor] = None, **kwargs) -> None: + """ + :param action_masks: Masks applied to constrain the choice of possible actions. + """ + if action_masks is not None: + self.action_masks[self.pos] = action_masks.reshape( + (self.n_envs, self.mask_dims)) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", + "advantages", "returns", "action_masks"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten( + self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx: start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env = None) -> MaskableDictRolloutBufferSamples: + + return MaskableDictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for ( + key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + action_masks=self.to_torch( + self.action_masks[batch_inds].reshape(-1, self.mask_dims)), + ) \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/data_pool.py b/benchmark/torch/RL4LMs/utils/data_pool.py new file mode 100644 index 000000000..ad7de7769 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/data_pool.py @@ -0,0 +1,116 @@ +from datasets import load_dataset +from .data_wrapper import Sample +from typing import Any, List, Dict +import random +from abc import abstractclassmethod +from tqdm import tqdm +from nltk.tokenize import word_tokenize + +class TextGenPool: + def __init__(self, samples: List[Sample]): + self._samples = samples + + def __len__(self): + return len(self._samples) + + def __getitem__(self, ix: int) -> Sample: + if ix >= len(self): + raise StopIteration + sample = self._samples[ix] + return sample, 1.0 + + def sample(self) -> Sample: + random_sample = random.choice(self._samples) + return random_sample + + @abstractclassmethod + def prepare(cls, **args) -> 'TextGenPool': + """ + A factory method to instantiate data pool + """ + raise NotImplementedError + + def split(self, split_ratios: List[float]) -> List['TextGenPool']: + start_ix = 0 + pools = [] + for ratio in split_ratios: + count = int(len(self) * ratio) + end_ix = start_ix + count + pools.append(type(self)(self._samples[start_ix: end_ix])) + start_ix = end_ix + return pools + +class CommonGen(TextGenPool): + @classmethod + def prepare(cls, split: str, + concept_separator_token: str = " ", + concept_end_token=" ", + prefix: str = "summarize: ") -> 'TextGenPool': + ds = load_dataset("gem", "common_gen") + samples = [] + split_id = CommonGen.gen_split_name(split) + for ix, item in enumerate(ds[split_id]): + concepts = concept_separator_token.join(item["concepts"]) + concepts = prefix + concepts + concepts += concept_end_token + if item["target"] == "": + # just to avoid breaking of metric computation + item["target"] = "empty reference" + targets = [item["target"]] + sample = Sample(id=f"{split}_{ix}", + prompt_or_input_text=concepts, + references=targets, + meta_data={ + "concepts": item["concepts"] + } + ) + samples.append(sample) + pool_instance = cls(samples) + return pool_instance + + @staticmethod + def gen_split_name(split: str): + if split == "train": + split_name = "train" + elif split == "val": + split_name = "validation" + elif split == "test": + split_name = "test" + else: + raise NotImplementedError + return split_name + + + +class CNNDailyMail(TextGenPool): + @classmethod + def prepare(cls, + split: str, + prompt_suffix: str = "", + prompt_prefix: str = "", + truncate_article: int = None, + max_size: int = None): + dataset = load_dataset("cnn_dailymail", "3.0.0") + dataset_split = CommonGen.gen_split_name(split) + samples = [] + for ix, item in tqdm(enumerate(dataset[dataset_split]), + desc="Tokenizing dataset", + total=len(dataset[dataset_split])): + + if truncate_article is not None: + tokens = word_tokenize(item["article"]) + tokens = tokens[:truncate_article] + item["article"] = " ".join(tokens) + + sample = Sample(id=f"{split}_{ix}", + prompt_or_input_text=prompt_prefix + + item["article"] + prompt_suffix, + references=[item["highlights"]] + ) + samples.append(sample) + + if max_size is not None and ix == (max_size-1): + break + + pool_instance = cls(samples) + return pool_instance \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/data_wrapper.py b/benchmark/torch/RL4LMs/utils/data_wrapper.py new file mode 100644 index 000000000..4917dd49f --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/data_wrapper.py @@ -0,0 +1,327 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional +from transformers import AutoTokenizer +from copy import deepcopy +from .type_wrapper import TensorDict +import torch +from typing import NamedTuple +import torch +import numpy as np + + +@dataclass +class TransitionInfo: + observation: TensorDict + action: np.ndarray + task_reward: np.ndarray + total_reward: np.ndarray + kl_div: np.ndarray + episode_start: np.ndarray + value: torch.Tensor + log_prob: torch.Tensor + done: np.ndarray + ref_log_prob: torch.Tensor + kl_reward: np.ndarray + action_mask: np.ndarray + info: Dict[str, Any] + + +class MaskableRolloutBufferSamples(NamedTuple): + observations: torch.Tensor + actions: torch.Tensor + old_values: torch.Tensor + old_log_prob: torch.Tensor + advantages: torch.Tensor + returns: torch.Tensor + action_masks: torch.Tensor + +class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples): + observations: TensorDict + actions: torch.Tensor + old_values: torch.Tensor + old_log_prob: torch.Tensor + advantages: torch.Tensor + returns: torch.Tensor + action_masks: torch.Tensor + + +class RolloutBufferSamples(NamedTuple): + observations: torch.Tensor + actions: torch.Tensor + old_values: torch.Tensor + old_log_prob: torch.Tensor + advantages: torch.Tensor + returns: torch.Tensor + + +class DictRolloutBufferSamples(RolloutBufferSamples): + observations: TensorDict + actions: torch.Tensor + old_values: torch.Tensor + old_log_prob: torch.Tensor + advantages: torch.Tensor + returns: torch.Tensor + + +@dataclass(init=True) +class Sample: + id: str + prompt_or_input_text: str + references: List[str] + meta_data: Dict[str, Any] = None + + + +class PolicyType(Enum): + CAUSAL = 0 + SEQ2SEQ = 1 + + +@dataclass +class EvaluateActionsOutput: + """ + Dataclass for the output of the method policy.evaluate_actions(). + This is invoked during training phase for each mini-batch in the rollout buffer + """ + + # values of the given state + values: torch.tensor + # log prob of chosen actions + log_prob: torch.tensor + # entropy of action dist + entropy: torch.tensor + + +@dataclass +class PolicyOutput: + """ + Dataclass for the output of the method policy.foward_policy() + """ + + # chosen actions by policy + actions: torch.tensor + # raw log probs corresponding to chosen actions + raw_log_probs: torch.tensor + # processed log probs (eg: after action masking) for chosen actions + log_probs: torch.tensor + # entropy of action dist + entropy: torch.tensor + # cached policy activations for sequential forward passes + past_model_kwargs: torch.tensor + + +@dataclass +class RefPolicyOutput: + """ + Dataclass for the output of the method policy.get_ref_log_probs() + """ + + # ref log_probs for corresponding observation and chosen action + log_probs: torch.tensor + # cached policy activations for sequential forward passes + past_model_kwargs: torch.tensor + + +@dataclass +class ValueOutput: + """ + Dataclass for the output of the method policy.forward_value() + """ + + # values corresponding to given state + values: torch.tensor + # cached value activations for sequential forward passes + past_model_kwargs: Dict[str, torch.tensor] + + +@dataclass +class GenerationInputs: + # prompt inputs + inputs: torch.tensor + # prompt attention masks + attention_masks: torch.tensor + + +@dataclass +class GenerationOutputs: + # log probs at each time step + step_wise_logprobs: List[List[torch.tensor]] + # actions at each time step + step_wise_actions: List[torch.tensor] + # generated tokens + gen_tokens: List[List[int]] + # generated texts + gen_texts: List[str] + # action masks + action_masks: List[torch.tensor] = None + + +@dataclass +class Observation: + # encoded input + prompt_or_input_encoded_pt: torch.tensor + # attention mask for the input + prompt_or_input_attention_mask_pt: torch.tensor + # input text + prompt_or_input_text: str + # encoded context + context_encoded_pt: torch.tensor + # attention mask for the context + context_attention_mask_pt: torch.tensor + # context text + context_text: str + # reference texts + target_or_reference_texts: List[str] + + # concatenated input + input_encoded_pt: torch.tensor + input_attention_mask_pt: torch.tensor + + # list of actions + action_history: List[str] + + # other meta info + meta_info: Dict[str, Any] + + def to_dict(self) -> Dict[str, torch.tensor]: + """ + For stable baselines (only return tensor items) + """ + dict_obs = { + "prompt_or_input_encoded_pt": self.prompt_or_input_encoded_pt.numpy().flatten(), + "prompt_or_input_attention_mask_pt": self.prompt_or_input_attention_mask_pt.numpy().flatten(), + "context_encoded_pt": self.context_encoded_pt.numpy().flatten(), + "context_attention_mask_pt": self.context_attention_mask_pt.numpy().flatten(), + "input_encoded_pt": self.input_encoded_pt.numpy().flatten(), + "input_attention_mask_pt": self.input_attention_mask_pt.numpy().flatten() + + } + return dict_obs + + @staticmethod + def _concat(prompt: torch.tensor, prompt_mask: torch.tensor, + context: torch.tensor, context_mask: torch.tensor, + pad_token: int): + + prompt_ = prompt[:, prompt_mask.flatten().bool().tolist()] + context_ = context[:, context_mask.flatten().bool().tolist()] + actual_size = prompt_.shape[1] + context_.shape[1] + + full_size = prompt.shape[1] + context.shape[1] + concatenated = torch.full( + (full_size,), fill_value=pad_token).reshape(1, -1) + concatenated_mask = torch.zeros((1, full_size)).int() + + concatenated[:, full_size - + actual_size:] = torch.cat((prompt_, context_), dim=1) + concatenated_mask[:, full_size - + actual_size:] = 1 + return concatenated, concatenated_mask + + def update(self, action: int, tokenizer: AutoTokenizer) -> "Observation": + """ + Updates the observation using the given action + """ + + # update the action history + current_action_history = deepcopy(self.action_history) + current_action_history.append(tokenizer._convert_id_to_token(action)) + + # get the current context + current_context = deepcopy(self.context_encoded_pt) + current_context_attention_mask = deepcopy( + self.context_attention_mask_pt) + + # just shift the context (also the attention mask) to left by 1 + current_context[:, 0:-1] = current_context[:, 1:].clone() + current_context_attention_mask[:, 0:- + 1] = current_context_attention_mask[:, 1:].clone() + + # add the action always at the end (assumes left padding) + current_context[:, -1] = action + current_context_attention_mask[:, -1] = 1 + + # decode the context + context_text = tokenizer.decode( + current_context.flatten(), skip_special_tokens=True) + + # concatenate and still keep the left padding + input_encoded_pt, input_attention_mask_pt = Observation._concat( + self.prompt_or_input_encoded_pt, self.prompt_or_input_attention_mask_pt, + current_context, current_context_attention_mask, + tokenizer.pad_token_id) + + # and create a new observation + obs = Observation(self.prompt_or_input_encoded_pt, + self.prompt_or_input_attention_mask_pt, + self.prompt_or_input_text, + current_context, + current_context_attention_mask, + context_text, + self.target_or_reference_texts, + input_encoded_pt, + input_attention_mask_pt, + current_action_history, + self.meta_info) + + return obs + + @classmethod + def init_from_sample(cls, sample: Sample, + tokenizer: AutoTokenizer, + max_input_length: int, + max_context_length: int, + prompt_truncation_side: str, + context_start_token: int = None, + meta_info: Dict[str, Any] = None): + # encode the prompt text + # override truncation side for prompt + prev_truncation_side = tokenizer.truncation_side + tokenizer.truncation_side = prompt_truncation_side + prompt_outputs = tokenizer(sample.prompt_or_input_text, + padding="max_length", + max_length=max_input_length, + return_tensors="pt", + return_attention_mask=True, + truncation=True) + tokenizer.truncation_side = prev_truncation_side + + # for seq2seq models, context should be initialized to start token if provided + if context_start_token is not None: + context_outputs = tokenizer("", + padding="max_length", + max_length=max_context_length, + return_tensors="pt", + return_attention_mask=True) + context_outputs.input_ids = torch.ones(1, max_context_length, dtype=torch.int32) * tokenizer.pad_token_id + context_outputs.input_ids[:, -1] = context_start_token + context_outputs.attention_mask = torch.zeros(1, max_context_length, dtype=torch.int32) + context_outputs.attention_mask[:, -1] = 1 + else: + context_outputs = tokenizer("", + padding="max_length", + max_length=max_context_length, + return_tensors="pt", + return_attention_mask=True) + + # concatenate + input_encoded_pt, input_attention_mask_pt = Observation._concat( + prompt_outputs.input_ids, prompt_outputs.attention_mask, + context_outputs.input_ids, context_outputs.attention_mask, + tokenizer.pad_token_id) + + obs = Observation(prompt_or_input_encoded_pt=prompt_outputs.input_ids, + prompt_or_input_attention_mask_pt=prompt_outputs.attention_mask, + prompt_or_input_text=sample.prompt_or_input_text, + context_encoded_pt=context_outputs.input_ids, + context_attention_mask_pt=context_outputs.attention_mask, + input_encoded_pt=input_encoded_pt, + input_attention_mask_pt=input_attention_mask_pt, + context_text="", + target_or_reference_texts=sample.references, + action_history=[], + meta_info=meta_info) + + return obs + diff --git a/benchmark/torch/RL4LMs/utils/distribution_wrapper.py b/benchmark/torch/RL4LMs/utils/distribution_wrapper.py new file mode 100644 index 000000000..bcb5bca5a --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/distribution_wrapper.py @@ -0,0 +1,68 @@ +# refer to stable_baselines3.common.distributions +from torch import nn +from torch.distributions import Categorical +from typing import Tuple +import torch + +class CategoricalDistribution: + """ + Categorical distribution for discrete actions. + + :param action_dim: Number of discrete actions + """ + + def __init__(self, action_dim: int): + super().__init__() + self.action_dim = action_dim + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits of the Categorical distribution. + You can then get probabilities using a softmax. + + :param latent_dim: Dimension of the last layer + of the policy network (before the action layer) + :return: + """ + action_logits = nn.Linear(latent_dim, self.action_dim) + return action_logits + + def proba_distribution(self, action_logits: torch.Tensor) -> "CategoricalDistribution": + self.distribution = Categorical(logits=action_logits) + return self + + def log_prob(self, actions: torch.Tensor) -> torch.Tensor: + return self.distribution.log_prob(actions) + + def entropy(self) -> torch.Tensor: + return self.distribution.entropy() + + def sample(self) -> torch.Tensor: + return self.distribution.sample() + + def mode(self) -> torch.Tensor: + return torch.argmax(self.distribution.probs, dim=1) + + + def actions_from_params(self, action_logits: torch.Tensor, deterministic: bool = False) -> torch.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + + def get_actions(self, deterministic: bool = False) -> torch.Tensor: + """ + Return actions according to the probability distribution. + + :param deterministic: + :return: + """ + if deterministic: + return self.mode() + return self.sample() \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/evaluation_util.py b/benchmark/torch/RL4LMs/utils/evaluation_util.py new file mode 100644 index 000000000..5bb317d71 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/evaluation_util.py @@ -0,0 +1,125 @@ +from typing import Any, Dict, List + +from benchmark.torch.RL4LMs.models import BasePolicy +from tqdm import tqdm +from transformers import AutoTokenizer + +from . import Sample +from .metric_util import BaseMetric + + +def get_batch(samples: List[Sample], batch_size: int): + current_ix = 0 + n_samples = len(samples) + while current_ix < n_samples: + current_batch = samples[current_ix : current_ix + batch_size] + yield current_batch + current_ix += batch_size + + +def evaluate_on_samples( + policy: BasePolicy, + tokenizer: AutoTokenizer, + samples: List[Sample], + batch_size: int, + max_prompt_length: int, + metrics: List[BaseMetric], + epoch: int, + split_name: str, + # tracker: Tracker = None, + tracker = None, # TODO: change tracker to parl logging + dt_control_token: str = "", + gen_kwargs: Dict[str, Any] = None, +): + # generate text by batch + all_generated_texts = [] + all_ref_texts = [] + all_prompt_texts = [] + all_meta_infos = [] + ###########CHANGE FOR DEBUG############ + tem = [] + for i in range(200): + tem.append(samples[i]) + samples = tem + ###########CHANGE FOR DEBUG############ + + + + n_samples = len(samples) + for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"): + batch_generated_texts = generate_text( + policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs + ) + batch_ref_texts = [sample.references for sample in batch] + batch_prompt_texts = [sample.prompt_or_input_text for sample in batch] + batch_meta_infos = [sample.meta_data for sample in batch] + all_generated_texts.extend(batch_generated_texts) + all_ref_texts.extend(batch_ref_texts) + all_prompt_texts.extend(batch_prompt_texts) + all_meta_infos.extend(batch_meta_infos) + + # compute metrics + corpus_level_metrics = {} + sample_scores_by_metric = {} + if metrics is not None: + for metric in metrics: + metric_dict = metric.compute( + all_prompt_texts, + all_generated_texts, + all_ref_texts, + all_meta_infos, + policy.get_language_model(), + split_name, + ) + + for metric_key, (sample_scores, corpus_score) in metric_dict.items(): + if sample_scores is None: + sample_scores = ["n/a"] * n_samples + corpus_level_metrics[metric_key] = corpus_score + sample_scores_by_metric[metric_key] = sample_scores + + # aggregate sample metric scores + sample_predictions_dict = [] + for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate( + zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts) + ): + sample_prediction = { + "split_name": split_name, + "sample_id": sample.id, + "prompt_text": prompt_text, + "generated_text": generated_text, + "ref_text": "".join( + [ + f"" + ref_text + f"" + for ref_ix, ref_text in enumerate(ref_texts) + ] + ), + } + for metric_key, sample_scores in sample_scores_by_metric.items(): + sample_prediction[metric_key] = sample_scores[ix] + sample_predictions_dict.append(sample_prediction) + + + # TODO: change tracker to parl logging + # if tracker is not None: + # # log the entire predictions + # tracker.log_predictions(epoch, split_name, sample_predictions_dict) + # # log the corpus level scores + # tracker.log_metrics(epoch, split_name, corpus_level_metrics) + + +def generate_text( + policy: BasePolicy, + tokenizer: AutoTokenizer, + samples: List[Sample], + max_prompt_length: int, + dt_control_token: str, + gen_kwargs: Dict[str, Any], +): + prompt_texts = [ + dt_control_token + sample.prompt_or_input_text for sample in samples + ] + generated_texts = policy.generate( + tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs + ).gen_texts + return generated_texts diff --git a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py new file mode 100644 index 000000000..a83cd2284 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py @@ -0,0 +1,3492 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers.generation_utils import GenerationMixin +import inspect +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch import nn + +from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint +from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from transformers.generation_logits_process import ( + EncoderNoRepeatNGramLogitsProcessor, + ExponentialDecayLengthPenalty, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, + LogitsProcessorList, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, +) +from transformers.generation_stopping_criteria import ( + MaxLengthCriteria, + MaxTimeCriteria, + StoppingCriteria, + StoppingCriteriaList, + validate_stopping_criteria, +) +from transformers.pytorch_utils import torch_int_div +from transformers.utils import ModelOutput, logging + + +logger = logging.get_logger(__name__) + + +@dataclass +class GreedySearchDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using greedy search. + + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each + tensor of shape `(batch_size, config.vocab_size)`). + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class GreedySearchEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention + weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape + `(batch_size, config.vocab_size)`). + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class SampleDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using sampling. + + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each + tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`). + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, + sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class SampleEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of + the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape + `(batch_size*num_return_sequences, config.vocab_size)`). + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape + `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length, + sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSearchDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using beam search. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape + `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). + beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped + tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSearchEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights + of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, + config.vocab_size)`). + beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped + tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, + sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSampleDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using beam sample. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape + `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). + beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped + tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSampleEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention + weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, + config.vocab_size)`). + beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped + tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size*num_beams, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, + GreedySearchDecoderOnlyOutput] +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] +BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, + BeamSearchDecoderOnlyOutput] +BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, + BeamSampleDecoderOnlyOutput] + + +class GenerationMixinWithRawScores: + """ + A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + + The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], + if `constraints!=None` or `force_words_ids!=None`. + """ + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if ( + self.config.is_encoder_decoder + and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name + ): + input_name = self.encoder.main_input_name + else: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items( + ) if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside " + f"{input_name} which is not allowed." + f"Make sure to either pass {inputs} or {input_name}=..." + ) + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. models with `input_ids` can also make use of `inputs_embeds` + if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # 4. Only encoder-decoder models can have non `input_ids` input format + if not self.config.is_encoder_decoder and input_name != "input_ids": + raise ValueError( + f"If {input_name} is passed as model-specific keyword " + "input then model has to be an encoder-decoder and not a " + f"{self.__class__.__name__}." + ) + + # 5. if `inputs` is still None, try to create `input_ids` from BOS token + if inputs is None: + inputs = self._prepare_input_ids_for_generation( + bos_token_id, model_kwargs.get("encoder_outputs")) + + return inputs, input_name, model_kwargs + + def _can_retrieve_inputs_from_name( + self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor] + ) -> torch.Tensor: + """ + If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved + from name + """ + can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( + inspect.signature(self.forward).parameters.keys() + ) + + if can_retrieve_inputs and inputs is not None: + raise ValueError( + f"Cannot only pass one of {name} and {self.main_input_name}") + + return can_retrieve_inputs + + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: + """ + Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. + """ + return {"input_ids": input_ids} + + def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: + """ + Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. + """ + return logits + + def _prepare_input_ids_for_generation( + self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] + ) -> torch.LongTensor: + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.last_hidden_state.size()[:-1] + return torch.ones(shape, dtype=torch.long, device=self.device) * -100 + + if bos_token_id is None: + raise ValueError( + "`bos_token_id` has to be defined when no `input_ids` are provided.") + return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_attention_mask_for_generation( + self, + inputs: torch.Tensor, + pad_token_id: int, + eos_token_id: int, + ) -> torch.LongTensor: + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ + torch.int, torch.long] + is_pad_token_in_inputs = (pad_token_id is not None) and ( + pad_token_id in inputs) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + # Check if input is input_ids and padded -> only then is attention_mask defined + if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: + return inputs.ne(pad_token_id).long() + else: + return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device) + + def _prepare_encoder_decoder_kwargs_for_generation( + self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None + ) -> Dict[str, Any]: + # 1. get encoder + encoder = self.get_encoder() + + # 2. prepare encoder args and encoder kwargs from model kwargs + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs["encoder_outputs"]: ModelOutput = encoder( + **encoder_kwargs) + + return model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + return model_kwargs.pop("decoder_input_ids") + else: + decoder_start_token_id = self._get_decoder_start_token_id( + decoder_start_token_id, bos_token_id) + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id + + def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "decoder_start_token_id") + and self.config.decoder.decoder_start_token_id is not None + ): + return self.config.decoder.decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None + ): + return self.config.decoder.bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, + expand_size).view(-1).to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select( + 0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select( + 0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx.to( + encoder_outputs.last_hidden_state.device) + ) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + + @staticmethod + def _update_model_kwargs_for_generation( + outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False + ) -> Dict[str, Any]: + # update past + if "past_key_values" in outputs: + model_kwargs["past"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past"] = outputs.past_buckets_states + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + def _reorder_cache(self, past, beam_idx): + raise NotImplementedError( + f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" + ) + + def _get_logits_warper( + self, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + temperature: Optional[float] = None, + num_beams: Optional[int] = None, + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances + used for multinomial sampling. + """ + + # init warp parameters + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + typical_p = typical_p if typical_p is not None else self.config.typical_p + temperature = temperature if temperature is not None else self.config.temperature + # instantiate warpers list + warpers = LogitsProcessorList() + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if temperature is not None and temperature != 1.0: + warpers.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + warpers.append(TopKLogitsWarper( + top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if top_p is not None and top_p < 1.0: + warpers.append(TopPLogitsWarper( + top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if typical_p is not None and typical_p < 1.0: + warpers.append(TypicalLogitsWarper( + mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + return warpers + + def _get_logits_processor( + self, + repetition_penalty: float, + no_repeat_ngram_size: int, + encoder_no_repeat_ngram_size: int, + input_ids_seq_length: int, + encoder_input_ids: torch.LongTensor, + bad_words_ids: List[List[int]], + min_length: int, + max_length: int, + eos_token_id: int, + forced_bos_token_id: int, + forced_eos_token_id: int, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], + num_beams: int, + num_beam_groups: int, + diversity_penalty: float, + remove_invalid_values: bool, + exponential_decay_length_penalty: Tuple, + logits_processor: Optional[LogitsProcessorList], + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] + instances used to modify the scores of the language model head. + """ + processors = LogitsProcessorList() + + # init warp parameters + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + encoder_no_repeat_ngram_size = ( + encoder_no_repeat_ngram_size + if encoder_no_repeat_ngram_size is not None + else self.config.encoder_no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + min_length = min_length if min_length is not None else self.config.min_length + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty + forced_bos_token_id = ( + forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id + ) + forced_eos_token_id = ( + forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id + ) + remove_invalid_values = ( + remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values + ) + exponential_decay_length_penalty = ( + exponential_decay_length_penalty + if exponential_decay_length_penalty is not None + else self.config.exponential_decay_length_penalty + ) + # instantiate processors list + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if diversity_penalty is not None and diversity_penalty > 0.0: + processors.append( + HammingDiversityLogitsProcessor( + diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups + ) + ) + if repetition_penalty is not None and repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor( + penalty=repetition_penalty)) + if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: + processors.append( + NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) + if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: + if self.config.is_encoder_decoder: + processors.append(EncoderNoRepeatNGramLogitsProcessor( + encoder_no_repeat_ngram_size, encoder_input_ids)) + else: + raise ValueError( + "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" + ) + if bad_words_ids is not None: + processors.append(NoBadWordsLogitsProcessor( + bad_words_ids, eos_token_id)) + if min_length is not None and eos_token_id is not None and min_length > 0: + processors.append(MinLengthLogitsProcessor( + min_length, eos_token_id)) + if prefix_allowed_tokens_fn is not None: + processors.append(PrefixConstrainedLogitsProcessor( + prefix_allowed_tokens_fn, num_beams // num_beam_groups)) + if forced_bos_token_id is not None: + processors.append( + ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) + if forced_eos_token_id is not None: + processors.append(ForcedEOSTokenLogitsProcessor( + max_length, forced_eos_token_id)) + if remove_invalid_values is True: + processors.append(InfNanRemoveLogitsProcessor()) + if exponential_decay_length_penalty is not None: + processors.append( + ExponentialDecayLengthPenalty( + exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) + ) + processors = self._merge_criteria_processor_list( + processors, logits_processor) + return processors + + def _get_stopping_criteria( + self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] + ) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if max_length is not None: + criteria.append(MaxLengthCriteria(max_length=max_length)) + if max_time is not None: + criteria.append(MaxTimeCriteria(max_time=max_time)) + criteria = self._merge_criteria_processor_list( + criteria, stopping_criteria) + return criteria + + def _merge_criteria_processor_list( + self, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "stopping criteria" if isinstance( + custom, StoppingCriteria) else "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, " + f"but it has already been created with the values {default}. {default} has been created by passing the " + "corresponding arguments to generate or by the model's config default values. " + f"If you just want to change the default values of {object_type} consider passing them as arguments " + f"to `generate` instead of using a custom {object_type}." + ) + default_list.extend(custom_list) + return default_list + + def compute_beam_search_raw_logits( + self, + sequences: torch.Tensor, + scores: Tuple[torch.Tensor], + beam_indices: torch.Tensor, + eos_token_id: int = None, + ): + """Compute raw logits for beam search""" + + if not self.config.is_encoder_decoder: + raise NotImplementedError( + "Beam Search raw logits code is implemented only for enoder-decoder only models") + + # since sequences can be shorter than scores (probably due to beam search finalization) + # we always have to generate raw_logits only for generated sequences + # cut off the start tokens from generated + sequences = sequences.clone() + sequences = sequences[:, 1:] + gen_steps = sequences.shape[1] + + # align scores and beam indices according to gen_steps + # scores(gen_steps x(batch_size * num_beams) x vocab_size) + scores = scores[:gen_steps] + scores = torch.stack(scores) + _, _, vocab_size = scores.shape + + beam_indices = torch.tensor(beam_indices).T.to(scores.device) + beam_indices = beam_indices[:gen_steps, :] + batch_size = beam_indices.shape[1] + + # gen_steps x batch_size x vocab_size + beam_indices = beam_indices.unsqueeze(-1).repeat(1, 1, vocab_size) + step_wise_logits = scores.gather(dim=1, index=beam_indices) + assert step_wise_logits.shape == torch.Size( + (gen_steps, batch_size, vocab_size)) + + # finally convert to tuples + step_wise_logits = [(step_wise_logits[t], None) + for t in range(gen_steps)] + return step_wise_logits + + @ torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + encoder_no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + max_time: Optional[float] = None, + max_new_tokens: Optional[int] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), + constraints: Optional[List[Constraint]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, + **model_kwargs, + ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. The method supports the following + generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling + [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or + `force_words_ids!=None`. + + + + Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as + defined in the model's config (`config.json`) which in turn defaults to the + [`~modeling_utils.PretrainedConfig`] of the model. + + + + Most of these parameters are explained in more detail in [this blog + post](https://huggingface.co/blog/how-to-generate). + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + max_length (`int`, *optional*, defaults to `model.config.max_length`): + The maximum length of the sequence to be generated. + max_new_tokens (`int`, *optional*, defaults to None): + The maximum numbers of tokens to generate, ignore the current number of tokens. Use either + `max_new_tokens` or `max_length` but not both, they serve the same purpose. + min_length (`int`, *optional*, defaults to 10): + The minimum length of the sequence to be generated. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + early_stopping (`bool`, *optional*, defaults to `False`): + Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. + num_beams (`int`, *optional*, defaults to 1): + Number of beams for beam search. 1 means no beam search. + temperature (`float`, *optional*, defaults to 1.0): + The value used to module the next token probabilities. + top_k (`int`, *optional*, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher + are kept for generation. + repetition_penalty (`float`, *optional*, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + bos_token_id (`int`, *optional*): + The id of the *beginning-of-sequence* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + no_repeat_ngram_size (`int`, *optional*, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): + If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the + `decoder_input_ids`. + bad_words_ids(`List[List[int]]`, *optional*): + List of token ids that are not allowed to be generated. In order to get the token ids of the words that + should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, + add_special_tokens=False).input_ids`. + force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): + List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple + list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, + this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), + where one can allow different forms of each word. + num_return_sequences(`int`, *optional*, defaults to 1): + The number of independently computed returned sequences for each element in the batch. + max_time(`float`, *optional*, defaults to None): + The maximum amount of time you allow the computation to run for in seconds. generation will still + finish the current pass after allocated time has been passed. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens + that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape + as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) + decoder_start_token_id (`int`, *optional*): + If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. + use_cache: (`bool`, *optional*, defaults to `True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + num_beam_groups (`int`, *optional*, defaults to 1): + Number of groups to divide `num_beams` into in order to ensure diversity among different groups of + beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + diversity_penalty (`float`, *optional*, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is + enabled. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and a + model's config. If a logit processor is passed that is already created with the arguments or a model's + config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + model's config. If a stopping criteria is passed that is already created with the arguments or a + model's config an error is thrown. This feature is intended for advanced users. + constraints (`List[Constraint]`, *optional*): + Custom constraints that can be added to the generation to ensure that the output will contain the use + of certain tokens as defined by `Constraint` objects, in the most sensible way possible. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + forced_bos_token_id (`int`, *optional*): + The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful + for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be + the target language token. + forced_eos_token_id (`int`, *optional*): + The id of the token to force as the last generated token when `max_length` is reached. + remove_invalid_values (`bool`, *optional*): + Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to + crash. Note that using `remove_invalid_values` can slow down generation. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + exponential_decay_length_penalty (`tuple(int, float)`, *optional*): + This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been + generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates + where penalty starts and `decay_factor` represents the factor of exponential decay + + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model + is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs + should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation_utils.GreedySearchDecoderOnlyOutput`], + - [`~generation_utils.SampleDecoderOnlyOutput`], + - [`~generation_utils.BeamSearchDecoderOnlyOutput`], + - [`~generation_utils.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation_utils.GreedySearchEncoderDecoderOutput`], + - [`~generation_utils.SampleEncoderDecoderOutput`], + - [`~generation_utils.BeamSearchEncoderDecoderOutput`], + - [`~generation_utils.BeamSampleEncoderDecoderOutput`] + + Examples: + + Greedy Decoding: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # generate up to 30 tokens + >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] + ``` + + Multinomial Sampling: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # sample up to 30 tokens + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] + ``` + + Beam-search decoding: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + + >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + + >>> sentence = "Paris is one of the densest populated areas in Europe." + >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids + + >>> outputs = model.generate(input_ids) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] + ```""" + # 1. Set generation parameters if not already defined + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + num_beams = num_beams if num_beams is not None else self.config.num_beams + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups + do_sample = do_sample if do_sample is not None else self.config.do_sample + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + if eos_token_id is None and hasattr(self.config, "decoder"): + eos_token_id = self.config.decoder.eos_token_id + + if pad_token_id is None and eos_token_id is not None: + # special case if pad_token_id is not defined + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + pad_token_id = eos_token_id + + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # 2. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, bos_token_id, model_kwargs) + batch_size = inputs_tensor.shape[0] + + # 3. Define other model kwargs + model_kwargs["output_attentions"] = output_attentions + model_kwargs["output_hidden_states"] = output_hidden_states + model_kwargs["use_cache"] = use_cache + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, pad_token_id, eos_token_id + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 4. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + input_ids_seq_length = input_ids.shape[-1] + + # 5. Prepare `max_length` depending on other stopping criteria + # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` + if max_length is None and max_new_tokens is not None: + max_length = max_new_tokens + input_ids_seq_length + elif max_length is not None and max_new_tokens is not None: + # Both are set, this is odd, raise a warning + warnings.warn( + "Both `max_length` and `max_new_tokens` have been set " + f"but they serve the same purpose. `max_length` {max_length} " + f"will take priority over `max_new_tokens` {max_new_tokens}.", + UserWarning, + ) + # default to config if still None + max_length = max_length if max_length is not None else self.config.max_length + + if input_ids_seq_length >= max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. " + "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." + ) + + # 6. determine generation mode + is_constraint_gen_mode = constraints is not None or force_words_ids is not None + is_greedy_gen_mode = ( + (num_beams == 1) and (num_beam_groups == + 1) and do_sample is False and not is_constraint_gen_mode + ) + is_sample_gen_mode = ( + (num_beams == 1) and (num_beam_groups == + 1) and do_sample is True and not is_constraint_gen_mode + ) + is_beam_gen_mode = ( + (num_beams > 1) and (num_beam_groups == + 1) and do_sample is False and not is_constraint_gen_mode + ) + is_beam_sample_gen_mode = ( + (num_beams > 1) and (num_beam_groups == + 1) and do_sample is True and not is_constraint_gen_mode + ) + is_group_beam_gen_mode = (num_beams > 1) and ( + num_beam_groups > 1) and not is_constraint_gen_mode + + if num_beam_groups > num_beams: + raise ValueError( + "`num_beam_groups` has to be smaller or equal to `num_beams`") + if is_group_beam_gen_mode and do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + # 7. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + bad_words_ids=bad_words_ids, + min_length=min_length, + max_length=max_length, + eos_token_id=eos_token_id, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + diversity_penalty=diversity_penalty, + remove_invalid_values=remove_invalid_values, + exponential_decay_length_penalty=exponential_decay_length_penalty, + logits_processor=logits_processor, + ) + + # 8. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria + ) + + # 9. go into different generation modes + if is_greedy_gen_mode: + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + ) + + # 10. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 10. prepare logits warper + logits_warper = self._get_logits_warper( + top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams + ) + + # 11. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 12. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_gen_mode: + if num_return_sequences > num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now.") + + # 10. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + # 12. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 10. prepare logits warper + logits_warper = self._get_logits_warper( + top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now.") + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * num_return_sequences, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + ) + + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_beams * num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if num_return_sequences > num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if num_beams % num_beam_groups != 0: + raise ValueError( + "`num_beams` should be divisible by `num_beam_groups` for group beam search.") + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now.") + + # 10. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + max_length=stopping_criteria.max_length, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=num_beam_groups, + ) + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + # 12. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if num_return_sequences > num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now.") + + if num_beams <= 1: + raise ValueError( + "`num_beams` needs to be greater than 1 for constrained genertation.") + + if do_sample: + raise ValueError( + "`do_sample` needs to be false for constrained generation.") + + if num_beam_groups is not None and num_beam_groups > 1: + raise ValueError( + "`num_beam_groups` not supported yet for constrained generation.") + + final_constraints = [] + if constraints is not None: + final_constraints = constraints + + if force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {force_words_ids}." + ) + + if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: + typeerror() + + for word_ids in force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any(not isinstance(token_ids, list) for token_ids in word_ids): + typeerror() + if any( + any((not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 10. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + # 12. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[GreedySearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be + used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "It might be possible to" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> outputs = model.greedy_search( + ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if ( + return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get( + "attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get( + "hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + cur_len = input_ids.shape[-1] + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( + outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # pre-process distribution + next_tokens_scores = logits_processor( + input_ids, next_token_logits, model_inputs) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + \ + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + cur_len = cur_len + 1 + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (next_tokens != eos_token_id).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GreedySearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return GreedySearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + + def sample( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if ( + return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get( + "attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get( + "hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + cur_len = input_ids.shape[-1] + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits_raw = outputs.logits[:, -1, :].clone() + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor( + input_ids, next_token_logits, model_inputs=model_inputs) + next_token_scores = logits_warper( + input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += ((next_token_logits_raw, next_token_scores),) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( + outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + \ + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + cur_len = cur_len + 1 + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (next_tokens != eos_token_id).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return SampleEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return SampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + + def beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[BeamSearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **beam search decoding** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + + >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length) + if len(stopping_criteria) == 0: + warnings.warn( + "You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if ( + return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if ( + return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get( + "attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get( + "hidden_states") if output_hidden_states else None + ) + + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + next_token_logits_raw = next_token_logits.clone() + + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor( + input_ids, next_token_scores, model_inputs=model_inputs) + next_token_scores = next_token_scores_processed + \ + beam_scores[:, None].expand_as(next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_logits_raw,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( + outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view( + batch_size, num_beams * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = torch_int_div(next_tokens, vocab_size) + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat( + [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache( + model_kwargs["past"], beam_idx) + + if return_dict_in_generate and output_scores: + beam_indices = tuple( + (beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + else: + num_return_sequences = beam_scorer.num_beam_hyps_to_keep + # return only as many indices as sequences + beam_indices = tuple( + (beam_indices[i * num_beams: i * num_beams + + num_return_sequences] for i in range(batch_size)) + ) + beam_indices = sum(beam_indices, ()) + + step_wise_raw_logits = self.compute_beam_search_raw_logits( + sequence_outputs["sequences"].clone(), + scores, + beam_indices, + eos_token_id) + + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=step_wise_raw_logits, # raw logits + beam_indices=beam_indices, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + def beam_sample( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[BeamSampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **beam search multinomial + sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> outputs = model.beam_sample( + ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if ( + return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if ( + return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get( + "attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get( + "hidden_states") if output_hidden_states else None + ) + + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits_raw = outputs.logits[:, -1, :] + + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation( + next_token_logits_raw, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor( + input_ids, next_token_logits, model_inputs=model_inputs) + next_token_scores = next_token_scores_processed + \ + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + # return raw scores instead of post-processed + scores += ((next_token_logits_raw, next_token_scores),) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( + outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view( + batch_size, num_beams * vocab_size) + + probs = nn.functional.softmax(next_token_scores, dim=-1) + + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) + next_token_scores = torch.gather( + next_token_scores, -1, next_tokens) + + next_token_scores, _indices = torch.sort( + next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = torch_int_div(next_tokens, vocab_size) + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat( + [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache( + model_kwargs["past"], beam_idx) + + if return_dict_in_generate and output_scores: + beam_indices = tuple( + (beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + else: + num_return_sequences = beam_scorer.num_beam_hyps_to_keep + # return only as many indices as sequences + beam_indices = tuple( + (beam_indices[i * num_beams: i * num_beams + + num_return_sequences] for i in range(batch_size)) + ) + beam_indices = sum(beam_indices, ()) + + if self.config.is_encoder_decoder: + return BeamSampleEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSampleDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + def group_beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **diverse beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + + model_kwargs: + Additional model specific kwargs that will be forwarded to the `forward` function of the model. If + model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSearchDecoderOnlyOutput`] if [`~generation_utils.BeamSearchDecoderOnlyOutput`] if + `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a + [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... HammingDiversityLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run diverse beam search using 6 beams + >>> num_beams = 6 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... num_beam_groups=3, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + + >>> outputs = model.group_beam_search( + ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + device = input_ids.device + + batch_beam_size, cur_len = input_ids.shape + + if return_dict_in_generate and output_scores: + beam_indices = [tuple(() for _ in range( + num_sub_beams * batch_size)) for _ in range(num_beam_groups)] + else: + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if ( + return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get( + "attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get( + "hidden_states") if output_hidden_states else None + ) + + beam_scores = torch.full( + (batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # predicted tokens in cur_len step + current_tokens = torch.zeros( + batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros( + batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + if output_scores: + processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + next_token_logits_raw = outputs.logits[batch_group_indices, -1, :] + + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation( + next_token_logits_raw, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx, model_inputs=model_inputs + ) + next_token_scores = next_token_scores_processed + \ + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as( + next_token_scores_processed) + + if output_scores: + processed_score[batch_group_indices] = next_token_logits_raw + + # reshape for beam search + next_token_scores = next_token_scores.view( + batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch_int_div(next_tokens, vocab_size) + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + if return_dict_in_generate and output_scores: + beam_indices[beam_group_idx] = tuple( + beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) + ) + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat( + [group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * + torch_int_div(beam_idx, group_size) + + group_start_idx + (beam_idx % group_size) + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (processed_score,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( + outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + input_ids = torch.cat( + [input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache( + model_kwargs["past"], reordering_indices) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + else: + beam_indices = sum(beam_indices, ()) + num_return_sequences = beam_scorer.num_beam_hyps_to_keep + # return only as many indices as sequences + beam_indices = tuple( + (beam_indices[i * num_beams: i * num_beams + + num_return_sequences] for i in range(batch_size)) + ) + beam_indices = sum(beam_indices, ()) + + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + def constrained_beam_search( + self, + input_ids: torch.LongTensor, + constrained_beam_scorer: ConstrainedBeamSearchScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = None, + **model_kwargs, + ) -> Union[BeamSearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **constrained beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + constrained_beam_scorer (`ConstrainedBeamSearchScorer`): + A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation, while satisfying a list of positive constraints. For more information, the + documentation of [`ConstrainedBeamSearchScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... ConstrainedBeamSearchScorer, + ... PhrasalConstraint, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> constraint_str = "Sie" + >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token + >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] + + + >>> # instantiate beam scorer + >>> beam_scorer = ConstrainedBeamSearchScorer( + ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + + >>> outputs = model.constrained_beam_search( + ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt sind Sie?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length) + if len(stopping_criteria) == 0: + warnings.warn( + "You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if ( + return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get( + "attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get( + "hidden_states") if output_hidden_states else None + ) + + batch_size = len(constrained_beam_scorer._beam_hyps) + num_beams = constrained_beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits_raw = outputs.logits[:, -1, :] + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation( + next_token_logits_raw, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor( + input_ids, next_token_scores, model_inputs=model_inputs) + + scores_for_all_vocab = next_token_scores_processed.clone() + + next_token_scores = next_token_scores_processed + \ + beam_scores[:, None].expand_as(next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += ((next_token_logits_raw, next_token_scores),) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( + outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view( + batch_size, num_beams * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = (next_tokens / vocab_size).long() + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = constrained_beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + scores_for_all_vocab, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat( + [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache( + model_kwargs["past"], beam_idx) + + # increase cur_len + cur_len = cur_len + 1 + + if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = constrained_beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + +def top_k_top_p_filtering( + logits: torch.FloatTensor, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +) -> torch.FloatTensor: + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + + Args: + logits: logits distribution shape (batch size, vocabulary size) + top_k (`int`, *optional*, defaults to 0): + If > 0, only keep the top k tokens with highest probability (top-k filtering) + top_p (`float`, *optional*, defaults to 1.0): + If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus + filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimumber of tokens we keep per batch example in the output. + + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( + None, logits + ) + + if 0 <= top_p <= 1.0: + logits = TopPLogitsWarper( + top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) + + return logits + + +def override_generation_routines(cls): + bases = list(cls.__bases__) + for base_ix in range(len(bases)): + if bases[base_ix] == GenerationMixin: + bases[base_ix] = GenerationMixinWithRawScores + + # recursively look up + if bases[base_ix] != object: + bases[base_ix] = override_generation_routines(bases[base_ix]) + + cls.__bases__ = tuple(bases) + return cls diff --git a/benchmark/torch/RL4LMs/utils/kl_controller.py b/benchmark/torch/RL4LMs/utils/kl_controller.py new file mode 100644 index 000000000..ad2d3a7ab --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/kl_controller.py @@ -0,0 +1,32 @@ +from typing import Optional, Dict, Any +import torch + + +class KLController: + def __init__(self, kl_coeff: float, target_kl: Optional[float] = None) -> None: + self._kl_coeff = kl_coeff + self._target_kl = target_kl + + def step(self, kl_div: torch.tensor): + """ + Adapts the KL coeff + """ + if self._target_kl is not None: + diff_to_target = (kl_div - self._target_kl) / self._target_kl + e_t = torch.clip(diff_to_target, -0.2, 0.2).item() + self._kl_coeff = self._kl_coeff * (1 + 0.1 * e_t) + + @property + def kl_coeff(self): + return self._kl_coeff + + def get_state_dict(self) -> Dict[str, Any]: + state = { + "target_kl": self._target_kl, + "current_kl_coeff": self._kl_coeff + } + return state + + def load_from_state_dict(self, state_dict: Dict[str, Any]): + self._kl_coeff = state_dict["current_kl_coeff"] + self._target_kl = state_dict["target_kl"] \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/metric_util.py b/benchmark/torch/RL4LMs/utils/metric_util.py new file mode 100644 index 000000000..e06c4aace --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/metric_util.py @@ -0,0 +1,644 @@ +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from transformers import PreTrainedModel +import torch +from typing import List, Dict, Tuple, Any +from abc import abstractmethod +import numpy as np +from datasets import load_metric +from gem_metrics.msttr import MSTTR +from gem_metrics.ngrams import NGramStats +from gem_metrics.texts import Predictions +from tqdm import tqdm +import copy +import rouge +import json +from tempfile import TemporaryDirectory +import subprocess +import os +import jsonlines + +# Cider, Spice, SummaCConv, SummaCZS, compute_parent, + + +def compute_bleu(predicted_texts: List[str], + raw_tables: List[dict]): + + def _read_results(path): + try: + with open(path) as fp: + score = json.load(fp)["score"]/100 + except: + score = 0.0 + return score + + with TemporaryDirectory() as temp_dir: + + # write tables + target_path = os.path.join(temp_dir, "samples.jsonl") + with jsonlines.open(target_path, "w") as writer: + for table in raw_tables: + writer.write(table) + + # write gen texts + prediction_path = os.path.join(temp_dir, "predictions.txt") + with open(prediction_path, "w") as fp: + predicted_texts = '\n'.join(predicted_texts) + fp.write(predicted_texts) + + cmd = ['bash', 'totto_bleu_eval.sh', + '-p', prediction_path, + '-t', target_path, + '--output_dir', temp_dir, + ] + subprocess.check_call(cmd, + cwd=os.path.dirname(os.path.abspath(__file__)), + stdout=subprocess.DEVNULL) + + # read the results back + bleu_overall = _read_results( + os.path.join(temp_dir, "bleu_overall.json")) + bleu_overlap = _read_results( + os.path.join(temp_dir, "bleu_overlap.json")) + bleu_non_overlap = _read_results( + os.path.join(temp_dir, "bleu_non_overlap.json")) + return bleu_overall, bleu_overlap, bleu_non_overlap + + + + +class BaseMetric: + @abstractmethod + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ): + """ + Returns a dict where key is the metric name and value is again a dict consisting of tuple of individual scores (if any) and corpus level score + + eg. { + metric_name: (individual_scores, corpus_level_score) + "metric_1": ([0.5, 0.5, 0.8], 0.1) + } + + """ + raise NotImplementedError + + +class LearnedRewardMetric(BaseMetric): + def __init__( + self, + model_name: str, + label_ix: int, + batch_size: int, + include_prompt_for_eval: bool = True, + ) -> None: + super().__init__() + self._device = "cuda" if torch.cuda.is_available() else "cpu" + self._tokenizer = AutoTokenizer.from_pretrained(model_name) + self._tokenizer.truncation_side = "left" + self._model = AutoModelForSequenceClassification.from_pretrained(model_name).to( + self._device + ) + self._label_ix = label_ix + self._batch_size = batch_size + self._include_prompt_for_eval = include_prompt_for_eval + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Dict[str, float]: + all_scores = [] + current_ix = 0 + n_texts = len(generated_texts) + while current_ix < n_texts: + batch_gen_texts = generated_texts[ + current_ix : current_ix + self._batch_size + ] + batch_prompt_texts = prompt_texts[ + current_ix : current_ix + self._batch_size + ] + + if self._include_prompt_for_eval: + batch_gen_texts = [ + (prompt + gen) + for gen, prompt in zip(batch_gen_texts, batch_prompt_texts) + ] + encoded = self._tokenizer( + batch_gen_texts, return_tensors="pt", truncation=True, padding=True + ) + with torch.no_grad(): + outputs = self._model( + input_ids=encoded.input_ids.to(self._device), + attention_mask=encoded.attention_mask.to(self._device), + ) + scores = torch.softmax(outputs.logits, dim=1) + scores = scores[:, self._label_ix].tolist() + all_scores.extend(scores) + current_ix += self._batch_size + + metric_dict = { + "semantic/learned_automodel_metric": (all_scores, np.mean(all_scores)) + } + return metric_dict + + +class MeteorMetric(BaseMetric): + def __init__(self) -> None: + super().__init__() + self._metric = load_metric("meteor") + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ): + + score = self._metric.compute( + predictions=generated_texts, references=reference_texts + )["meteor"] + + metric_dict = {"lexical/meteor": (None, score)} + return metric_dict + + +class RougeMetric(BaseMetric): + def __init__(self, use_single_ref: bool = True) -> None: + super().__init__() + self._metric = load_metric("rouge") + self._use_single_ref = use_single_ref + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ): + if self._use_single_ref: + # TBD: this is required for CNN/DM dataset, without this we get low scores + # TBD: needs investigation + ref_texts = [ref[0] for ref in reference_texts] + else: + ref_texts = reference_texts + + metric_results = self._metric.compute( + predictions=generated_texts, references=ref_texts, use_stemmer=True + ) + score_keys = ["rouge1", "rouge2", "rougeL", "rougeLsum"] + metric_dict = {} + for rouge_type in score_keys: + rouge_score = metric_results[rouge_type].mid.fmeasure + metric_dict[f"lexical/rouge_{rouge_type}"] = (None, rouge_score) + return metric_dict + + +class BERTScoreMetric(BaseMetric): + def __init__(self, language: str) -> None: + super().__init__() + self._metric = load_metric("bertscore") + self._language = language + # since models are loaded heavily on cuda:0, use the last one to avoid memory + self._last_gpu = f"cuda:{torch.cuda.device_count() - 1}" + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + with torch.no_grad(): + metric_results = self._metric.compute( + predictions=generated_texts, + references=reference_texts, + lang=self._language, + device=self._last_gpu, + ) + bert_scores = metric_results["f1"] + corpus_level_score = np.mean(bert_scores) + metric_dict = {"semantic/bert_score": (bert_scores, corpus_level_score)} + return metric_dict + + +class BLEUMetric(BaseMetric): + def __init__(self) -> None: + super().__init__() + self._metric = load_metric("bleu") + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + + tokenized_predictions = [] + tokenized_reference_texts = [] + for prediction, refs in zip(generated_texts, reference_texts): + tokenized_prediction = prediction.split() + tokenized_refs = [ref.split() for ref in refs] + tokenized_predictions.append(tokenized_prediction) + tokenized_reference_texts.append(tokenized_refs) + + try: + metric_results = self._metric.compute( + predictions=tokenized_predictions, references=tokenized_reference_texts + ) + bleu_score = metric_results["bleu"] + metric_dict = {"lexical/bleu": (None, bleu_score)} + return metric_dict + except Exception as e: + return {"lexical/bleu": (None, "n/a")} + + +class BLEURTMetric(BaseMetric): + def __init__(self, config_name: str = None) -> None: + super().__init__() + self._metric = load_metric("bleurt", config_name=config_name) + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + metric_results = self._metric.compute( + predictions=generated_texts, references=reference_texts + ) + corpus_score = np.mean(metric_results["scores"]) + metric_dict = {"semantic/bleurt": (metric_results["scores"], corpus_score)} + return metric_dict + + +def get_generated_and_predictions( + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + split_name: str, +): + split_name = "" if split_name is None else split_name + preds = {} + refs = {} + for ix, (prompt_text, gen_text, ref_text) in enumerate( + zip(prompt_texts, generated_texts, reference_texts) + ): + preds[split_name + prompt_text] = [gen_text] + refs[split_name + prompt_text] = ref_text + return preds, refs + + +def get_individual_scores( + prompt_texts: List[str], split_name: str, scores_dict: Dict[str, float] +): + split_name = "" if split_name is None else split_name + scores = [] + for prompt_text in prompt_texts: + scores.append(scores_dict.get(split_name + prompt_text, "n/a")) + return scores + + + + +class DiversityMetrics(BaseMetric): + def __init__(self, window_size: int = 100) -> None: + self._msttr_metric = MSTTR(window_size=window_size) + self._n_gram_metric = NGramStats() + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + + predictions = Predictions(data={"filename": "", "values": generated_texts}) + diversity_metrics = {} + msttr_metrics = self._msttr_metric.compute(None, predictions) + n_gram_metrics = self._n_gram_metric.compute(None, predictions) + + for key, value in msttr_metrics.items(): + diversity_metrics[f"diversity_metrics/{key}"] = (None, value) + for key, value in n_gram_metrics.items(): + diversity_metrics[f"diversity_metrics/{key}"] = (None, value) + + return diversity_metrics + + +# class SummaCZSMetric(BaseMetric): +# """ +# Consistency metric for summarization +# +# https://github.com/tingofurro/summac/ +# """ +# +# def __init__(self, **kwargs) -> None: +# super().__init__() +# self._scorer = SummaCZS(**kwargs) +# +# def compute( +# self, +# prompt_texts: List[str], +# generated_texts: List[str], +# reference_texts: List[List[str]], +# meta_infos: List[Dict[str, Any]] = None, +# model: PreTrainedModel = None, +# split_name: str = None, +# ) -> Tuple[List[float], float]: +# metric_results = self._scorer.score(prompt_texts, generated_texts) +# corpus_score = np.mean(metric_results["scores"]) +# metric_dict = {"consistency/summaczs": (metric_results["scores"], corpus_score)} +# return metric_dict + + + + + +class Perplexity(BaseMetric): + def __init__( + self, + stride: int, + tokenizer_id: str, + model_type: str = "causal", + use_text_from_meta_data: bool = False, + ) -> None: + super().__init__() + self._tokenizer_id = tokenizer_id + self._model_type = model_type + self._stride = stride + self._use_text_from_meta_data = use_text_from_meta_data + + def get_device(self, model: PreTrainedModel): + try: + return model.transformer.first_device + except: + return model.device + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + if split_name == "train": + return {} + + if self._model_type != "causal": + raise NotImplementedError + + # we compute perplexity on reference texts + if self._use_text_from_meta_data: + reference_texts = [info["reference"] for info in meta_infos] + else: + reference_texts = [ref for refs in reference_texts for ref in refs] + tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_id) + encodings = tokenizer("\n\n".join(reference_texts), return_tensors="pt") + + device = self.get_device(model) + + nlls = [] + max_length = model.config.n_positions + for i in tqdm(range(0, encodings.input_ids.size(1), self._stride)): + begin_loc = max(i + self._stride - max_length, 0) + end_loc = min(i + self._stride, encodings.input_ids.size(1)) + trg_len = end_loc - i # may be different from stride on last loop + + # run on last device + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + with torch.no_grad(): + outputs = model(input_ids, labels=target_ids) + neg_log_likelihood = outputs[0] * trg_len + + nlls.append(neg_log_likelihood) + + return { + "fluency_metrics/perplexity": ( + None, + torch.exp(torch.stack(nlls).sum() / end_loc).item(), + ) + } + + + + + +class BLEUToTTo: + """ + Official version + """ + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]], + model: PreTrainedModel = None, + split_name: str = None, + ): + tables = [info["raw_table"] for info in meta_infos] + bleu_overall, bleu_overlap, bleu_non_overlap = compute_bleu( + generated_texts, tables + ) + + metric_results = { + "table_to_text/bleu_overall": (None, bleu_overall), + "table_to_text/bleu_overlap": (None, bleu_overlap), + "table_to_text/bleu_non_overlap": (None, bleu_non_overlap), + } + return metric_results + + +class RougeLMax(BaseMetric): + def __init__(self, **args) -> None: + super().__init__() + self._metric = rouge.Rouge(metrics=["rouge-l"], **args) + + def _rouge_max_over_ground_truths(self, prediction, ground_truths): + """ + Computes max of Rouge-L (https://github.com/allenai/unifiedqa/blob/bad6ef339db6286f0d8bd0661a2daeeb0f800f59/evaluation/evaluate_narrativeqa.py#L25) + """ + # load stemmer + self._metric.load_stemmer(self._metric.ensure_compatibility) + + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = self._metric.get_scores(prediction, [ground_truth]) + scores_for_ground_truths.append(score) + max_score = copy.deepcopy(score) + max_score = max([score["rouge-l"]["f"] for score in scores_for_ground_truths]) + return max_score + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ): + all_scores = [] + for gen_text, ref_texts in zip(generated_texts, reference_texts): + rouge_max_score = self._rouge_max_over_ground_truths(gen_text, ref_texts) + all_scores.append(rouge_max_score) + + metric_dict = {"lexical/rouge_l_max": (all_scores, np.mean(all_scores))} + return metric_dict + + +class SacreBLEUMetric(BaseMetric): + def __init__(self, **args) -> None: + super().__init__() + self._args = args + self._metric = load_metric("sacrebleu") + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + + metric_results = self._metric.compute( + predictions=generated_texts, references=reference_texts, **self._args + ) + bleu_score = metric_results["score"] / 100 + metric_dict = {"lexical/sacrebleu": (None, bleu_score)} + return metric_dict + + +class TERMetric(BaseMetric): + def __init__(self) -> None: + super().__init__() + self._metric = load_metric("ter") + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + + metric_results = self._metric.compute( + predictions=generated_texts, references=reference_texts + ) + score = metric_results["score"] / 100 + metric_dict = {"lexical/ter": (None, score)} + return metric_dict + + +class chrFmetric(BaseMetric): + def __init__(self) -> None: + super().__init__() + self._metric = load_metric("chrf") + + def compute( + self, + prompt_texts: List[str], + generated_texts: List[str], + reference_texts: List[List[str]], + meta_infos: List[Dict[str, Any]] = None, + model: PreTrainedModel = None, + split_name: str = None, + ) -> Tuple[List[float], float]: + + metric_results = self._metric.compute( + predictions=generated_texts, references=reference_texts + ) + score = metric_results["score"] / 100 + metric_dict = {"lexical/chrf": (None, score)} + return metric_dict + + + + +if __name__ == "__main__": + prompt_texts = [""] + gen_texts = ["Hello there general kenobi", "foo bar foobar"] + reference_texts = [["Hello there general kenobi"], ["foo bar foobar"]] + # metric = MeteorMetric() + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = RougeMetric() + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = SacreBLEUMetric(tokenize="intl") + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = TERMetric() + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = chrFmetric() + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = BERTScoreMetric(language="en") + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = BLEUMetric() + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = BLEURTMetric() + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # metric = DiversityMetrics() + # print(metric.compute(prompt_texts, gen_texts, reference_texts)) + + # document = """Jeff joined Microsoft in 1992 to lead corporate developer evangelism for Windows NT. He then served as a Group Program manager in Microsoft’s Internet Business Unit. In 1998, he led the creation of SharePoint Portal Server, which became one of Microsoft’s fastest-growing businesses, exceeding $2 billion in revenues. Jeff next served as Corporate Vice President for Program Management across Office 365 Services and Servers, which is the foundation of Microsoft’s enterprise cloud leadership. He then led Corporate Strategy supporting Satya Nadella and Amy Hood on Microsoft’s mobile-first/cloud-first transformation and acquisitions. Prior to joining Microsoft, Jeff was vice president for software development for an investment firm in New York. He leads Office shared experiences and core applications, as well as OneDrive and SharePoint consumer and business services in Office 365. Jeff holds a Master of Business Administration degree from Harvard Business School and a Bachelor of Science degree in information systems and finance from New York University.""" + # summary = "Jeff joined Microsoft in 1992 to lead the company's corporate evangelism. He then served as a Group Manager in Microsoft's Internet Business Unit. In 1998, Jeff led Sharepoint Portal Server, which became the company's fastest-growing business, surpassing $3 million in revenue. Jeff next leads corporate strategy for SharePoint and Servers which is the basis of Microsoft's cloud-first strategy. He leads corporate strategy for Satya Nadella and Amy Hood on Microsoft's mobile-first." + + # metric = SummaCZSMetric(granularity="sentence", + # use_ent=True, + # use_con=False) + # print(metric.compute([document], [summary], [])) + + # metric = SummaCConvMetric(granularity="sentence") + # print(metric.compute([document], [summary], [])) + + prompt_texts = ["1", "2"] + gen_texts = [ + "The dog is the boy's cat.", + "A boy is picking apples from trees and put them into bags.", + ] + reference_texts = [ + ["The dog is the boy's cat.", "The dog eats the cat of the boy."], + ["A boy is picking apples from trees."], + ] diff --git a/benchmark/torch/RL4LMs/utils/registry.py b/benchmark/torch/RL4LMs/utils/registry.py new file mode 100644 index 000000000..1630800b2 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/registry.py @@ -0,0 +1,189 @@ +from typing import Any, Dict, Type, Union + + +from benchmark.torch.RL4LMs.algorithms import RL4LMPPO +from benchmark.torch.RL4LMs.summarization import RL4LMsSummaAgent + +from .data_pool import TextGenPool, CNNDailyMail +# from rl4lms.envs.text_generation.alg_wrappers import wrap_onpolicy_alg + +from .metric_util import ( + BaseMetric, + BERTScoreMetric, + BLEUMetric, + BLEURTMetric, + BLEUToTTo, + DiversityMetrics, + LearnedRewardMetric, + MeteorMetric, + Perplexity, + RougeLMax, + RougeMetric, + SacreBLEUMetric, + TERMetric, + chrFmetric, +) +from benchmark.torch.RL4LMs.models import LMActorCriticPolicy + +from benchmark.torch.RL4LMs.models import Seq2SeqLMModel + +from .reward_util import ( + BERTScoreRewardFunction, + BLEURewardFunction, + BLEURTRewardFunction, + CommonGenPenaltyShapingFunction, + LearnedRewardFunction, + MeteorRewardFunction, + PARENTRewardFunction, + RewardFunction, + RougeCombined, + RougeLMaxRewardFunction, + RougeRewardFunction, + SacreBleu, +) + + + +class DataPoolRegistry: + _registry = { + "cnn_daily_mail": CNNDailyMail, + } + + @classmethod + def get(cls, datapool_id: str, kwargs: Dict[str, Any]) -> TextGenPool: + datapool_cls = cls._registry[datapool_id] + datapool = datapool_cls.prepare(**kwargs) + return datapool + + @classmethod + def add(cls, id: str, datapool_cls: Type[TextGenPool]): + DataPoolRegistry._registry[id] = datapool_cls + + +class RewardFunctionRegistry: + _registry = { + "learned_reward": LearnedRewardFunction, + "meteor": MeteorRewardFunction, + "rouge": RougeRewardFunction, + "bert_score": BERTScoreRewardFunction, + "bleu": BLEURewardFunction, + "bleurt": BLEURTRewardFunction, + "rouge_combined": RougeCombined, + "common_gen_repeat_penalty": CommonGenPenaltyShapingFunction, + "parent": PARENTRewardFunction, + "sacre_bleu": SacreBleu, + "rouge_l_max": RougeLMaxRewardFunction, + } + + @classmethod + def get(cls, reward_fn_id: str, kwargs: Dict[str, Any]) -> RewardFunction: + reward_cls = cls._registry[reward_fn_id] + reward_fn = reward_cls(**kwargs) + return reward_fn + + @classmethod + def add(cls, id: str, reward_fn_cls: Type[RewardFunction]): + RewardFunctionRegistry._registry[id] = reward_fn_cls + + +class MetricRegistry: + _registry = { + "learned_reward": LearnedRewardMetric, + "meteor": MeteorMetric, + "rouge": RougeMetric, + "bert_score": BERTScoreMetric, + "bleu": BLEUMetric, + "bleurt": BLEURTMetric, + "diversity": DiversityMetrics, + + "causal_perplexity": Perplexity, + + "bleu_totto": BLEUToTTo, + "rouge_l_max": RougeLMax, + "sacre_bleu": SacreBLEUMetric, + "ter": TERMetric, + "chrf": chrFmetric, + + } + + @classmethod + def get(cls, metric_id: str, kwargs: Dict[str, Any]) -> BaseMetric: + metric_cls = cls._registry[metric_id] + metric = metric_cls(**kwargs) + return metric + + @classmethod + def add(cls, id: str, metric_cls: Type[BaseMetric]): + MetricRegistry._registry[id] = metric_cls + + +class PolicyRegistry: + _registry = { + "seq2seq_lm_actor_critic_policy": Seq2SeqLMModel, + } + + @classmethod + def get(cls, policy_id: str) -> Type[LMActorCriticPolicy]: + policy_cls = cls._registry[policy_id] + return policy_cls + + @classmethod + def add(cls, id: str, policy_cls: Type[LMActorCriticPolicy]): + PolicyRegistry._registry[id] = policy_cls + + +class AlgorithmRegistry: + _registry = { + "ppo": RL4LMPPO, + } + + @classmethod + def get( + cls, alg_id: str + ): + try: + alg_cls = cls._registry[alg_id] + except KeyError: + raise NotImplementedError + return alg_cls + + @classmethod + def add( + cls, id: str, alg_cls + ): + AlgorithmRegistry._registry[id] = alg_cls + + +class WrapperRegistry: + _registry = { + "ppo": RL4LMsSummaAgent, + } + + @classmethod + def get(cls, alg_id: str): + try: + wrapper_def = cls._registry[alg_id] + except KeyError: + raise NotImplementedError + return wrapper_def + + @classmethod + def add(cls, id: str, wrapper_def): + WrapperRegistry._registry[id] = wrapper_def + + +class PostProcessorRegistry: + _registry = { + } + + @classmethod + def get(cls, post_processor_id: str): + try: + wrapper_def = cls._registry[post_processor_id] + except KeyError: + raise NotImplementedError + return wrapper_def + + @classmethod + def add(cls, id: str, post_processor_fn): + PostProcessorRegistry._registry[id] = post_processor_fn diff --git a/benchmark/torch/RL4LMs/utils/reward_util.py b/benchmark/torch/RL4LMs/utils/reward_util.py new file mode 100644 index 000000000..b62443a87 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/reward_util.py @@ -0,0 +1,446 @@ +from abc import ABC, abstractclassmethod + +import torch +from datasets import load_metric +from .data_wrapper import Observation +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from .metric_util import ( + MeteorMetric, + BERTScoreMetric, + BLEUMetric, + RougeLMax, +) +import numpy as np +from typing import List, Dict, Any + + +class RewardFunction(ABC): + @abstractclassmethod + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + """ + Callable for reward functions for text generation + + Args: + current_observation (Observation): previous observation (s) + action (int): action performed (a) at s + next_observation (Observation): observation after the action was performed (s') + done (bool): whether the episode is finished or not + meta_info (dict) - other information regarding textual sample + Returns: + float: scalar reward + """ + raise NotImplementedError + + +class BatchedRewardFunction(ABC): + """ + Computes rewards for several instances at once + """ + + @abstractclassmethod + def __call__( + self, + prompt_texts: List[str], + gen_texts: List[str], + ref_texts: List[List[str]], + dones: List[bool], + meta_infos: List[Dict[str, Any]] = None, + ) -> List[float]: + """ + An abstract class for batched reward functions for text generation + """ + raise NotImplementedError + + +### Automated reward functions ########################### + + +class CommonGenPenaltyShapingFunction(RewardFunction): + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + prompt_text = next_observation.prompt_or_input_text + prefix = "generate a sentence with: " + concept_n_grams = prompt_text.split(prefix)[1][:-1] + + if ( + concept_n_grams.lower() in next_observation.context_text.lower() + or prefix in next_observation.context_text.lower() + or "generate" in next_observation.context_text.lower() + or "sentence" in next_observation.context_text.lower() + ): + penalty_score = -1 + else: + penalty_score = 0 + return penalty_score + return 0 + + + + + +class MeteorRewardFunction(RewardFunction): + def __init__(self, shaping_fn: str = None) -> None: + super().__init__() + self._metric = MeteorMetric() + from rl4lms.envs.text_generation.registry import RewardFunctionRegistry + + self._shaping_fn = ( + RewardFunctionRegistry.get(shaping_fn, {}) + if shaping_fn is not None + else shaping_fn + ) + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + + # compute meteor at the end of episode + if done: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + metric_dict = self._metric.compute(None, predicted, references) + score = metric_dict["lexical/meteor"][1] + + if self._shaping_fn is not None: + aux_score = self._shaping_fn( + current_observation, action, next_observation, done, meta_info + ) + score = score + aux_score + return score + return 0 + + +class RougeRewardFunction(RewardFunction): + def __init__( + self, rouge_type: str, shaping_fn: str = None, use_single_ref: bool = True + ) -> None: + super().__init__() + self._metric = load_metric("rouge") + self._rouge_type = rouge_type + from rl4lms.envs.text_generation.registry import RewardFunctionRegistry + + self._shaping_fn = ( + RewardFunctionRegistry.get(shaping_fn, {}) + if shaping_fn is not None + else shaping_fn + ) + self._use_single_ref = use_single_ref + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + # TBD: considers only one reference for now + if self._use_single_ref: + references = [next_observation.target_or_reference_texts[0]] + else: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + + metric_results = self._metric.compute( + predictions=predicted, references=references, use_stemmer=True + ) + reward = metric_results[self._rouge_type].mid.fmeasure + if self._shaping_fn is not None: + aux_score = self._shaping_fn( + current_observation, action, next_observation, done, meta_info + ) + reward = reward + aux_score + return reward + return 0 + + +class RougeCombined(RewardFunction): + def __init__(self, shaping_fn: str = None) -> None: + super().__init__() + self._metric = load_metric("rouge") + from rl4lms.envs.text_generation.registry import RewardFunctionRegistry + + self._shaping_fn = ( + RewardFunctionRegistry.get(shaping_fn, {}) + if shaping_fn is not None + else shaping_fn + ) + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + # TBD: considers only one reference for now + references = [next_observation.target_or_reference_texts[0]] + predicted = [next_observation.context_text] + + metric_results = self._metric.compute( + predictions=predicted, references=references, use_stemmer=True + ) + + rouge_keys = ["rouge1", "rouge2", "rougeL"] + scores = [ + metric_results[rouge_type].mid.fmeasure for rouge_type in rouge_keys + ] + reward = np.mean(scores) + if self._shaping_fn is not None: + aux_score = self._shaping_fn( + current_observation, action, next_observation, done, meta_info + ) + reward = reward + aux_score + return reward + return 0 + + +class BERTScoreRewardFunction(RewardFunction): + def __init__(self, language: str = "en") -> None: + super().__init__() + self._metric = BERTScoreMetric(language) + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + metric_results = self._metric.compute(None, predicted, references) + bert_score = metric_results["semantic/bert_score"][1] + return bert_score + return 0 + + +class BLEURewardFunction(RewardFunction): + def __init__(self) -> None: + super().__init__() + self._metric = BLEUMetric() + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + metric_results = self._metric.compute(None, predicted, references) + bleu_score = metric_results["lexical/bleu"][1] + return bleu_score + return 0 + + +class SacreBleu(RewardFunction): + def __init__(self, **args) -> None: + super().__init__() + self._metric = load_metric("sacrebleu") + self._args = args + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + metric_results = self._metric.compute( + predictions=predicted, references=references, **self._args + ) + return metric_results["score"] / 100 + return 0 + + + + +############################################################################# + +########## Learned Reward Functions########################################## + + +class LearnedRewardFunction(RewardFunction): + def __init__( + self, model_name: str, label_ix: int, include_prompt_for_eval: bool = True + ) -> None: + super().__init__() + self._device = "cuda" if torch.cuda.is_available() else "cpu" + self._metric_tokenizer = AutoTokenizer.from_pretrained(model_name) + self._metric_tokenizer.truncation_side = "left" + self._metric_model = AutoModelForSequenceClassification.from_pretrained( + model_name + ).to(self._device) + self._label_ix = label_ix + self._include_prompt_for_eval = include_prompt_for_eval + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + generated_text = ( + current_observation.prompt_or_input_text + if self._include_prompt_for_eval + else "" + ) + generated_text += next_observation.context_text + + with torch.no_grad(): + encoded = self._metric_tokenizer( + generated_text, return_tensors="pt", truncation=True, padding=True + ) + outputs = self._metric_model( + input_ids=encoded.input_ids.to(self._device), + attention_mask=encoded.attention_mask.to(self._device), + ) + scores = torch.softmax(outputs.logits.flatten(), dim=0) + score = scores[self._label_ix].item() + return score + return 0 + + +class BLEURTRewardFunction(RewardFunction): + def __init__(self, checkpoint: str = None): + super().__init__() + self._metric = load_metric("bleurt", checkpoint=checkpoint) + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + metric_results = self._metric.compute( + predictions=predicted, references=references + ) + score = metric_results["scores"][0] + return score + return 0 + + +class PARENTRewardFunction(RewardFunction): + """ + PARENT F1 score as the reward + """ + + def __init__(self) -> None: + super().__init__() + self._metric = ParentToTTo() + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + generated_texts = [next_observation.context_text] + meta_infos = [meta_info] + scores = self._metric.compute(None, generated_texts, None, meta_infos) + reward = scores["table_to_text/parent_overall_f_score"][0][0] + return reward + return 0 + + +class RougeLMaxRewardFunction(RewardFunction): + def __init__(self, **args) -> None: + super().__init__() + self._metric = RougeLMax(**args) + + def __call__( + self, + current_observation: Observation, + action: int, + next_observation: Observation, + done: bool, + meta_info: Dict[str, Any] = None, + ) -> float: + if done: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + meta_infos = [meta_info] + scores = self._metric.compute(None, predicted, references, meta_infos) + reward = scores["lexical/rouge_l_max"][0][0] + return reward + return 0 + + + + +if __name__ == "__main__": + predictions = "hello there general kenobi" + references = ["hello there general kenobi", "hello there!!"] + observation = Observation( + None, None, None, None, None, predictions, references, None, None, None, None + ) + + reward_fn = MeteorRewardFunction() + print(reward_fn(None, None, observation, True)) + + reward_fn = chrF() + print(reward_fn(None, None, observation, True)) + + reward_fn = RougeCombined() + print(reward_fn(None, None, observation, True)) + + reward_fn = RougeRewardFunction(rouge_type="rouge1") + print(reward_fn(None, None, observation, True)) + + reward_fn = RougeRewardFunction(rouge_type="rouge2") + print(reward_fn(None, None, observation, True)) + + reward_fn = RougeRewardFunction(rouge_type="rougeL") + print(reward_fn(None, None, observation, True)) + + reward_fn = BERTScoreRewardFunction(language="en") + print(reward_fn(None, None, observation, True)) + + reward_fn = BLEURewardFunction() + print(reward_fn(None, None, observation, True)) + + reward_fn = BLEURTRewardFunction() + print(reward_fn(None, None, observation, True)) diff --git a/benchmark/torch/RL4LMs/utils/sample_util.py b/benchmark/torch/RL4LMs/utils/sample_util.py new file mode 100644 index 000000000..d403fd741 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/sample_util.py @@ -0,0 +1,40 @@ +from collections import deque +from typing import Any, List +import numpy as np + + +class PrioritySampler: + def __init__(self, max_size: int = None, priority_scale: float = 0.0): + """ + Creates a priority sampler + + Args: + max_size (int): maximum size of the queue + priority_scale (float): 0.0 is a pure uniform sampling, 1.0 is completely priority sampling + """ + self.max_size = max_size + self.items = deque(maxlen=self.max_size) + self.item_priorities = deque(maxlen=self.max_size) + self.priority_scale = priority_scale + + def add(self, item: Any, priority: float): + self.items.append(item) + self.item_priorities.append(priority) + + def sample(self, size: int) -> List[Any]: + min_sample_size = min(len(self.items), size) + scaled_item_priorities = np.array( + self.item_priorities) ** self.priority_scale + sample_probs = scaled_item_priorities / np.sum(scaled_item_priorities) + samples = np.random.choice( + a=self.items, p=sample_probs, size=min_sample_size) + return samples + + def update(self, item: Any, priority: float): + index = self.items.index(item) + del self.items[index] + del self.item_priorities[index] + self.add(item, priority) + + def get_all_samples(self) -> List[Any]: + return self.items diff --git a/benchmark/torch/RL4LMs/utils/tracker.py b/benchmark/torch/RL4LMs/utils/tracker.py new file mode 100644 index 000000000..5c48855b7 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/tracker.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import Dict, Any, List +import os +import json +import jsonlines +import pandas as pd +from transformers import AutoModel +import logging +import copy +import random + + +class Tracker: + def __init__(self, + base_path_to_store_results: str, + run_config: Dict[str, Any], + project_name: str, + experiment_name: str, + entity_name: str = None, + wandb_log: bool = False, + log_level: int = logging.DEBUG, + ): + self._log_level = log_level + self._base_path_to_store_results = base_path_to_store_results + self._config = run_config + self._experiment_name = experiment_name + self._project_name = project_name + self._entity_name = entity_name + self._wandb_log = wandb_log + self._init() + + def _init(self): + # create a folder + self._run_path = os.path.join( + self._base_path_to_store_results, + self._project_name, + self._experiment_name) + os.makedirs(self._run_path, exist_ok=True) + + # store also the config into it + config_path = os.path.join(self._run_path, "config.json") + with open(config_path, "w") as fp: + json.dump(self._config, fp) + + # init logger + log_path = os.path.join(self._run_path, "log.txt") + logging.basicConfig( + level=self._log_level, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.FileHandler(log_path) ] + ) + + + def log_predictions(self, epoch: int, + split_name: str, + predictions: List[Dict]): + # log them per epoch in a separate file as they can get huge + prediction_file_at_epoch = os.path.join( + self._run_path, f"epoch_{epoch}_{split_name}_split_predictions.json") + with open(prediction_file_at_epoch, "w") as fp: + json.dump(predictions, fp) + + # randomly display few predictions for logging + predictions_ = copy.deepcopy(predictions) + random.shuffle(predictions_) + logging.info(f"Split {split_name} predictions") + for pred in predictions_[:10]: + logging.info(pred) + + + def log_metrics(self, epoch: int, + split_name: str, + metrics_dict: Dict[str, float]): + # for each split, one file + metric_file_per_split = os.path.join( + self._run_path, f"{split_name}_split_metrics.jsonl") + metrics_dict_ = { + "epoch": epoch, + "metrics": metrics_dict + } + with jsonlines.open(metric_file_per_split, "a") as writer: + writer.write(metrics_dict_) + + # logger + logging.info(f"{split_name} metrics: {metrics_dict_}") + + def log_rollout_infos(self, rollout_info: Dict[str, float]): + logging.info(f"Rollout Info: {rollout_info}") + rollout_info_file = os.path.join( + self._run_path, "rollout_info.jsonl") + with jsonlines.open(rollout_info_file, mode="a") as writer: + writer.write(rollout_info) + + def log_training_infos(self, training_info: Dict[str, float]): + logging.info(f"Training Info: {training_info}") + training_info_file = os.path.join( + self._run_path, "training_info.jsonl") + with jsonlines.open(training_info_file, mode="a") as writer: + writer.write(training_info) + + def done(self): + pass + + def save_auto_model(self, model: AutoModel): + model_path = os.path.join(self._run_path, "model") + model.save_pretrained(model_path) + + @property + def checkpoint_base_path(self): + return os.path.join(self._run_path, "checkpoints") + + def log_info(self, msg: str): + logging.info(msg) + + +if __name__ == "__main__": + base_path = "/data/zhangsw/" + run_config = { + "param_1": 1, + "param_2": 2 + } + predictions = { + "1": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, + {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], + "2": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, + {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], + "3": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, + {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], + } + + metrics = { + "1": {"metric_1": 0.05, "metric_2": 0.1}, + "2": {"metric_1": 0.06, "metric_2": 0.2}, + "3": {"metric_1": 0.06, "metric_2": 0.3}, + } + + rollout_infos = [ + {"ep_len": 2, "ep_reward": 0.4}, + {"ep_len": 3, "ep_reward": 0.5}, + {"ep_len": 3, "ep_reward": 0.5}, + ] + + tracker = Tracker(base_path, run_config, "test_logs", "test_run", "T_1", False) + tracker.log_predictions(1, "val", predictions["1"]) + tracker.log_metrics(1, "val", metrics["1"]) + tracker.log_predictions(2, "val", predictions["2"]) + tracker.log_metrics(2, "val", metrics["2"]) + tracker.log_predictions(3, "val", predictions["3"]) + tracker.log_metrics(3, "val", metrics["3"]) + tracker.log_rollout_infos(rollout_infos[0]) + tracker.log_rollout_infos(rollout_infos[1]) + tracker.log_rollout_infos(rollout_infos[2]) + tracker.done() \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/type_wrapper.py b/benchmark/torch/RL4LMs/utils/type_wrapper.py new file mode 100644 index 000000000..17f81ddd8 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/type_wrapper.py @@ -0,0 +1,7 @@ +from typing import Any, Dict, Optional, List, Union, Callable +import torch + + +# refer to stable_baselines3.common.type_aliases +TensorDict = Dict[Union[str, int], torch.Tensor] +Schedule = Callable[[float], float] diff --git a/benchmark/torch/RL4LMs/utils/warm_start.py b/benchmark/torch/RL4LMs/utils/warm_start.py new file mode 100644 index 000000000..efa12ba14 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/warm_start.py @@ -0,0 +1,147 @@ +import os +from typing import Any, Dict + +import torch + +# from rl4lms.envs.text_generation.logging_utils import Tracker +# from rl4lms.envs.text_generation.policy.base_policy import LMActorCriticPolicy + + +class ActorCriticWarmStartMixin: + def get_state_dict(self) -> Dict[str, Any]: + state_dict = { + "policy_model": self._policy_model.state_dict(), + "value_model": self._value_model.state_dict(), + "value_head": self._value_head.state_dict(), + "optimizer": self.optimizer.state_dict() + } + return state_dict + + def load_from_dict(self, state_dict: dict = None): + if state_dict is not None: + self._policy_model.load_state_dict(state_dict["policy_model"]) + self._value_model.load_state_dict(state_dict["value_model"]) + self._value_head.load_state_dict(state_dict["value_head"]) + self.optimizer.load_state_dict(state_dict["optimizer"]) + + + +class OnPolicyWarmStartMixin: + def get_state_dict(self) -> Dict[str, Any]: + # just the kl controller state is sufficient for onpolicy algs + state_dict = { + "kl_controller_state": self._kl_controller.get_state_dict(), + } + return state_dict + + def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: + if state_dict is not None: + self._kl_controller.load_from_state_dict( + state_dict["kl_controller_state"]) + +# ################## Policy Warm Start Mixins####################################### +# +# +# class ActorOnlyWarmStartMixin: +# def get_state_dict(self) -> Dict[str, Any]: +# state_dict = { +# "policy_model": self._policy_model.state_dict(), +# "optimizer": self.optimizer.state_dict() +# } +# return state_dict +# +# def load_from_dict(self, state_dict: dict = None): +# if state_dict is not None: +# self._policy_model.load_state_dict(state_dict["policy_model"]) +# self.optimizer.load_state_dict(state_dict["optimizer"]) +# +# +# +# +# +# +# +# ################## Algorithm Warm Start Mixins####################################### + +# +# +# class OffPolicyWarmStartMixin: +# def get_state_dict(self) -> Dict[str, Any]: +# # TBD: just buffer is sufficient? or is there something else? +# state_dict = { +# "replay_buffer": self.replay_buffer.get_state_dict(), +# } +# return state_dict +# +# def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: +# if state_dict is not None: +# self.replay_buffer.load_from_state_dict( +# state_dict["replay_buffer"]) +# +# +# ################## Trainer Warm Start Mixins####################################### +# class TrainerWarmStartMixin: +# def _get_recent_ckpt_path(self, tracker: Tracker): +# try: +# checkpoints = os.listdir(tracker.checkpoint_base_path) +# except: +# os.makedirs(tracker.checkpoint_base_path) +# checkpoints = os.listdir(tracker.checkpoint_base_path) +# +# if len(checkpoints) == 0: +# return None, None +# +# sorted_ckpts = sorted(checkpoints, reverse=True, +# key=lambda ckpt: int(ckpt.split("_")[1])) +# recent_ckpt = sorted_ckpts[0] +# recent_ckpt_id = int(recent_ckpt.split("_")[1]) +# +# recent_ckpt_path = os.path.join( +# tracker.checkpoint_base_path, f"checkpoint_{recent_ckpt_id}") +# return recent_ckpt_path, recent_ckpt_id +# +# def load_trainer_state(self, tracker: Tracker): +# recent_ckpt_path, _ = self._get_recent_ckpt_path(tracker) +# state_dict = None +# try: +# if recent_ckpt_path is not None: +# state_dict = torch.load( +# recent_ckpt_path, map_location=torch.device("cuda")) +# tracker.log_info("Model checkpoint found - Warm starting") +# self._policy_state_dict = state_dict["policy_state"] +# self._alg_state_dict = state_dict["alg_state"] +# self._trainer_state = state_dict["trainer_state"] +# +# tracker.log_info( +# f"Loaded the current trainer state from: {self._trainer_state}") +# else: +# self._policy_state_dict = None +# self._alg_state_dict = None +# self._trainer_state = { +# "current_iter": 0, +# } +# except Exception as e: +# tracker.log_info(f"Exception while doing warm start {e}") +# tracker.log_info( +# f"Checkpoint may be corrupted...skipping warm start") +# self._policy_state_dict = None +# self._alg_state_dict = None +# self._trainer_state = { +# "current_iter": 0, +# } +# +# def save_trainer_state(self, tracker: Tracker, +# policy: LMActorCriticPolicy, +# trainer_state: Dict[str, Any]): +# full_state = { +# "alg_state": self._alg.get_state_dict(), +# "policy_state": policy.get_state_dict(), +# "trainer_state": trainer_state +# } +# _, recent_ckpt_id = self._get_recent_ckpt_path(tracker) +# +# # hot fix - just to save only the last checkpoint (overwrite) +# new_ckpt_id = 0 if recent_ckpt_id is None else recent_ckpt_id + 1 +# new_ckpt_path = os.path.join( +# tracker.checkpoint_base_path, f"checkpoint_{new_ckpt_id}") +# torch.save(full_state, new_ckpt_path, pickle_protocol=4) From e706ed4b1b346e589ec835f84057e3728093af9d Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Thu, 2 Mar 2023 21:27:05 +0800 Subject: [PATCH 02/34] benchmark of RL4LMs v0.0 --- benchmark/torch/RL4LMs/algorithms/__init__.py | 2 +- benchmark/torch/RL4LMs/algorithms/ppo.py | 145 ++++++ .../torch/RL4LMs/algorithms/rl4lm_ppo.py | 5 - .../RL4LMs/configs/summarization/t5_ppo.yml | 2 +- benchmark/torch/RL4LMs/env/__init__.py | 3 +- benchmark/torch/RL4LMs/env/vec_env.py | 203 ++++++++ benchmark/torch/RL4LMs/metrics/__init__.py | 16 + .../RL4LMs/{utils => metrics}/metric_util.py | 0 benchmark/torch/RL4LMs/models/__init__.py | 2 +- benchmark/torch/RL4LMs/models/base_model.py | 259 +--------- .../torch/RL4LMs/models/seq2seq_model.py | 8 +- .../torch/RL4LMs/{utils => }/registry.py | 24 +- .../summarization/rl4lms_summa_agent.py | 481 ++++-------------- .../summarization/rl4lms_summa_model.py | 7 - benchmark/torch/RL4LMs/train.py | 5 +- benchmark/torch/RL4LMs/trainers.py | 432 ++++++++++++++-- benchmark/torch/RL4LMs/utils/__init__.py | 11 +- benchmark/torch/RL4LMs/utils/data_wrapper.py | 1 - .../torch/RL4LMs/utils/evaluation_util.py | 16 +- .../utils/huggingface_generation_util.py | 1 + benchmark/torch/RL4LMs/utils/reward_util.py | 60 +-- benchmark/torch/RL4LMs/utils/warm_start.py | 134 ++--- 22 files changed, 997 insertions(+), 820 deletions(-) create mode 100644 benchmark/torch/RL4LMs/algorithms/ppo.py delete mode 100644 benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py create mode 100644 benchmark/torch/RL4LMs/env/vec_env.py create mode 100644 benchmark/torch/RL4LMs/metrics/__init__.py rename benchmark/torch/RL4LMs/{utils => metrics}/metric_util.py (100%) rename benchmark/torch/RL4LMs/{utils => }/registry.py (87%) delete mode 100644 benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py diff --git a/benchmark/torch/RL4LMs/algorithms/__init__.py b/benchmark/torch/RL4LMs/algorithms/__init__.py index 8d9429824..8bacfd707 100644 --- a/benchmark/torch/RL4LMs/algorithms/__init__.py +++ b/benchmark/torch/RL4LMs/algorithms/__init__.py @@ -1 +1 @@ -from .rl4lm_ppo import RL4LMPPO \ No newline at end of file +from .ppo import RL4LMPPO \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/algorithms/ppo.py b/benchmark/torch/RL4LMs/algorithms/ppo.py new file mode 100644 index 000000000..7af0c9114 --- /dev/null +++ b/benchmark/torch/RL4LMs/algorithms/ppo.py @@ -0,0 +1,145 @@ +import parl +from benchmark.torch.RL4LMs.utils import Tracker +from benchmark.torch.RL4LMs.utils import Schedule +from typing import Union, Optional, Dict, Any +import torch +from gym import spaces +from benchmark.torch.RL4LMs.utils import EvaluateActionsOutput +from torch.nn import functional as F + + +from parl.algorithms.torch import PPO + +class RL4LMPPO(parl.Algorithm): + def __init__(self, + model: parl.Model, + tracker: Tracker, + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: int = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + target_kl: Optional[float] = None, + seed: Optional[int] = None, + device: Union[torch.device, str] = "auto", + _init_setup_model: bool = True, + ): + super(RL4LMPPO, self).__init__(model=model) + self.tracker = tracker + self.learning_rate = learning_rate + self.n_steps = n_steps + self.batch_size = batch_size + self.n_epochs = n_epochs + self.gamma = gamma + self.gae_lambda = gae_lambda + self.clip_range = clip_range + self.normalize_advantage = normalize_advantage + self.ent_coef = ent_coef + self.vf_coef = vf_coef + self.max_grad_norm = max_grad_norm + self.target_kl = target_kl + self.seed = seed + self.device = device + + def learn(self, rollout_buffer, log_info): + entropy_losses = log_info["entropy_losses"] + pg_losses = log_info["entropy_losses"] + value_losses = log_info["value_losses"] + clip_fractions = log_info["clip_fractions"] + approx_kl_divs = log_info["approx_kl_divs"] + continue_training = True + # Do a complete pass on the rollout buffer + for batch_ix, rollout_data in enumerate(list(rollout_buffer.get(self.batch_size))): + # self.verify_rollout_data(rollout_data) + + actions = rollout_data.actions + if isinstance(self.model.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + + evaluation_output: EvaluateActionsOutput = self.model.evaluate_actions( + rollout_data.observations, actions) + values, log_prob, entropy = evaluation_output.values, evaluation_output.log_prob, evaluation_output.entropy + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages.mean() + ) / (advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = torch.exp(log_prob - rollout_data.old_log_prob) + + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * \ + torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range) + policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean() + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = torch.mean( + (torch.abs(ratio - 1) > self.clip_range).float()).item() + clip_fractions.append(clip_fraction) + + # No clipping + values_pred = values + + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -torch.mean(-log_prob) + else: + entropy_loss = -torch.mean(entropy) + + entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with torch.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = torch.mean( + (torch.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + break + + # Optimization step + self.model.optimizer.zero_grad() + loss.backward() + # Clip grad norm + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm) + self.model.optimizer.step() + + return continue_training, loss + + + def sample(self, obs): + pass + + def predict(self, obs): + pass + + def value(self, obs): + pass + + + diff --git a/benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py b/benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py deleted file mode 100644 index ee2592f1a..000000000 --- a/benchmark/torch/RL4LMs/algorithms/rl4lm_ppo.py +++ /dev/null @@ -1,5 +0,0 @@ -from parl.algorithms.torch.ppo import PPO - - -class RL4LMPPO(PPO): - pass \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml index de73bbb3a..a907057d0 100644 --- a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml @@ -45,7 +45,7 @@ alg: kl_div: coeff: 0.001 target_kl: 0.2 - policy: + model: id: seq2seq_lm_actor_critic_policy args: model_name: t5-base diff --git a/benchmark/torch/RL4LMs/env/__init__.py b/benchmark/torch/RL4LMs/env/__init__.py index 40764a3b1..39f83816f 100644 --- a/benchmark/torch/RL4LMs/env/__init__.py +++ b/benchmark/torch/RL4LMs/env/__init__.py @@ -1 +1,2 @@ -from .text_gen_env import TextGenEnv \ No newline at end of file +from .text_gen_env import TextGenEnv +from .vec_env import LocalParallelVecEnv, make_vec_env \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/env/vec_env.py b/benchmark/torch/RL4LMs/env/vec_env.py new file mode 100644 index 000000000..bd94490aa --- /dev/null +++ b/benchmark/torch/RL4LMs/env/vec_env.py @@ -0,0 +1,203 @@ +import numpy as np +import cloudpickle +import gym +from collections import OrderedDict +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union, Dict +import multiprocessing as mp + + +class CloudpickleWrapper: + def __init__(self, var): + self.var = var + + def __getstate__(self): + return cloudpickle.dumps(self.var) + + def __setstate__(self, var) -> None: + self.var = cloudpickle.loads(var) + + +def _flatten_obs(obs, space: gym.spaces.Space): + assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" + assert len(obs) > 0, "need observations from at least one environment" + + if isinstance(space, gym.spaces.Dict): + assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" + return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + elif isinstance(space, gym.spaces.Tuple): + assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" + obs_len = len(space.spaces) + return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) + else: + return np.stack(obs) + + +def _worker( + remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper +) -> None: + # Import here to avoid a circular import + + parent_remote.close() + env = env_fn_wrapper.var() + while True: + try: + cmd, data = remote.recv() + if cmd == "step": + observation, reward, done, info = env.step(data) + if done: + # save final observation where user can get it, then reset + info["terminal_observation"] = observation + observation = env.reset() + remote.send((observation, reward, done, info)) + elif cmd == "seed": + remote.send(env.seed(data)) + elif cmd == "reset": + observation = env.reset() + remote.send(observation) + elif cmd == "render": + remote.send(env.render(data)) + elif cmd == "close": + env.close() + remote.close() + break + elif cmd == "get_spaces": + remote.send((env.observation_space, env.action_space)) + elif cmd == "env_method": + method = getattr(env, data[0]) + remote.send(method(*data[1], **data[2])) + elif cmd == "get_attr": + remote.send(getattr(env, data)) + elif cmd == "set_attr": + remote.send(setattr(env, data[0], data[1])) + else: + raise NotImplementedError(f"`{cmd}` is not implemented in the worker") + except EOFError: + break + + +def make_vec_env( + env_id: Union[str, Type[gym.Env]], + vec_env_cls, + n_envs: int = 1, + seed: Optional[int] = None, + start_index: int = 0, + env_kwargs: Optional[Dict[str, Any]] = None, +): + def make_env(rank): + def _init(): + env = env_id(**env_kwargs) + if seed is not None: + env.seed(seed + rank) + env.action_space.seed(seed + rank) + return env + return _init + + return vec_env_cls([make_env(i + start_index) for i in range(n_envs)]) + + +class LocalParallelVecEnv: + + def __init__(self, env_fns, start_method = None): + self.waiting = False + self.closed = False + n_envs = len(env_fns) + + if start_method is None: + # Fork is not a thread safe method (see issue #217) + # but is more user friendly (does not require to wrap the code in + # a `if __name__ == "__main__":`) + forkserver_available = "forkserver" in mp.get_all_start_methods() + start_method = "forkserver" if forkserver_available else "spawn" + ctx = mp.get_context(start_method) + + self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) + self.processes = [] + for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): + args = (work_remote, remote, CloudpickleWrapper(env_fn)) + # daemon=True: if the main process crashes, we should not cause things to hang + process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error + process.start() + self.processes.append(process) + work_remote.close() + + self.remotes[0].send(("get_spaces", None)) + observation_space, action_space = self.remotes[0].recv() + self.num_envs = len(env_fns) + self.observation_space = observation_space + self.action_space = action_space + + def step_async(self, actions: np.ndarray) -> None: + for remote, action in zip(self.remotes, actions): + remote.send(("step", action)) + self.waiting = True + + def step_wait(self): + results = [remote.recv() for remote in self.remotes] + self.waiting = False + obs, rews, dones, infos = zip(*results) + return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos + + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + if seed is None: + seed = np.random.randint(0, 2**32 - 1) + for idx, remote in enumerate(self.remotes): + remote.send(("seed", seed + idx)) + return [remote.recv() for remote in self.remotes] + + def reset(self): + for remote in self.remotes: + remote.send(("reset", None)) + obs = [remote.recv() for remote in self.remotes] + return _flatten_obs(obs, self.observation_space) + + def close(self) -> None: + if self.closed: + return + if self.waiting: + for remote in self.remotes: + remote.recv() + for remote in self.remotes: + remote.send(("close", None)) + for process in self.processes: + process.join() + self.closed = True + + def set_attr(self, attr_name: str, value: Any, indices = None) -> None: + """Set attribute inside vectorized environments (see base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("set_attr", (attr_name, value))) + for remote in target_remotes: + remote.recv() + + def env_method(self, method_name: str, *method_args, indices = None, **method_kwargs) -> List[Any]: + """Call instance methods of vectorized environments.""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("env_method", (method_name, method_args, method_kwargs))) + return [remote.recv() for remote in target_remotes] + + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices = None) -> List[bool]: + """Check if worker environments are wrapped with a given wrapper""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("is_wrapped", wrapper_class)) + return [remote.recv() for remote in target_remotes] + + def _get_target_remotes(self, indices) -> List[Any]: + if indices is None: + indices = range(self.num_envs) + elif isinstance(indices, int): + indices = [indices] + return [self.remotes[i] for i in indices] + + def step(self, actions: np.ndarray): + """ + Step the environments with the given action + + :param actions: the action + :return: observation, reward, done, information + """ + self.step_async(actions) + return self.step_wait() \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/metrics/__init__.py b/benchmark/torch/RL4LMs/metrics/__init__.py new file mode 100644 index 000000000..30fe430fa --- /dev/null +++ b/benchmark/torch/RL4LMs/metrics/__init__.py @@ -0,0 +1,16 @@ +from.metric_util import ( + BaseMetric, + BERTScoreMetric, + BLEUMetric, + BLEURTMetric, + BLEUToTTo, + DiversityMetrics, + LearnedRewardMetric, + MeteorMetric, + Perplexity, + RougeLMax, + RougeMetric, + SacreBLEUMetric, + TERMetric, + chrFmetric, +) \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/metric_util.py b/benchmark/torch/RL4LMs/metrics/metric_util.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/metric_util.py rename to benchmark/torch/RL4LMs/metrics/metric_util.py diff --git a/benchmark/torch/RL4LMs/models/__init__.py b/benchmark/torch/RL4LMs/models/__init__.py index 0509d06e7..3a53cbfc4 100644 --- a/benchmark/torch/RL4LMs/models/__init__.py +++ b/benchmark/torch/RL4LMs/models/__init__.py @@ -1,2 +1,2 @@ -from .base_model import BasePolicy, LMActorCriticPolicy +from .base_model import BaseModel, LMActorCriticModel from .seq2seq_model import Seq2SeqLMModel \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/models/base_model.py b/benchmark/torch/RL4LMs/models/base_model.py index cca3053f3..5325182df 100644 --- a/benchmark/torch/RL4LMs/models/base_model.py +++ b/benchmark/torch/RL4LMs/models/base_model.py @@ -1,7 +1,5 @@ -from abc import abstractmethod, ABC +from abc import abstractmethod from copy import deepcopy -from dataclasses import dataclass -from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch @@ -15,8 +13,10 @@ import gym import numpy as np +import parl + +TensorDict = Dict[Union[str, int], torch.Tensor] from benchmark.torch.RL4LMs.utils import ( - Schedule, TensorDict, CategoricalDistribution, @@ -26,252 +26,33 @@ # refer to stable_baselines3.common.policies -class BaseModel(nn.Module, ABC): - """ - The base model object: makes predictions in response to observations. - - In the case of policies, the prediction is an action. In the case of critics, it is the - estimated value of the observation. - - :param observation_space: The observation space of the environment - :param action_space: The action space of the environment - :param features_extractor_class: Features extractor to use. - :param features_extractor_kwargs: Keyword arguments - to pass to the features extractor. - :param features_extractor: Network to extract features - (a CNN when using images, a nn.Flatten() layer otherwise) - :param normalize_images: Whether to normalize images or not, - dividing by 255.0 (True by default) - :param optimizer_class: The optimizer to use, - ``torch.optim.Adam`` by default - :param optimizer_kwargs: Additional keyword arguments, - excluding the learning rate, to pass to the optimizer - """ - - def __init__( - self, +class BaseModel(parl.Model): + def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - # features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - features_extractor: Optional[nn.Module] = None, - normalize_images: bool = True, optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - ): + optimizer_kwargs: Optional[Dict[str, Any]] = None,): super().__init__() - if optimizer_kwargs is None: optimizer_kwargs = {} - if features_extractor_kwargs is None: - features_extractor_kwargs = {} - self.observation_space = observation_space self.action_space = action_space - self.features_extractor = features_extractor - self.normalize_images = normalize_images self.optimizer_class = optimizer_class self.optimizer_kwargs = optimizer_kwargs - self.optimizer = None # type: Optional[torch.optim.Optimizer] - - # self.features_extractor_class = features_extractor_class - self.features_extractor_kwargs = features_extractor_kwargs + self.optimizer = None @abstractmethod def forward(self, *args, **kwargs): pass - # def _update_features_extractor( - # self, - # net_kwargs: Dict[str, Any], - # features_extractor: Optional[BaseFeaturesExtractor] = None, - # ) -> Dict[str, Any]: - # """ - # Update the network keyword arguments and create a new features extractor object if needed. - # If a ``features_extractor`` object is passed, then it will be shared. - # - # :param net_kwargs: the base network keyword arguments, without the ones - # related to features extractor - # :param features_extractor: a features extractor object. - # If None, a new object will be created. - # :return: The updated keyword arguments - # """ - # net_kwargs = net_kwargs.copy() - # if features_extractor is None: - # # The features extractor is not shared, create a new one - # features_extractor = self.make_features_extractor() - # net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) - # return net_kwargs - # - # def make_features_extractor(self) -> BaseFeaturesExtractor: - # """Helper method to create a features extractor.""" - # return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - # - # def extract_features(self, obs: torch.Tensor) -> torch.Tensor: - # """ - # Preprocess the observation if needed and extract features. - # - # :param obs: - # :return: - # """ - # assert self.features_extractor is not None, "No features extractor was set" - # preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) - # return self.features_extractor(preprocessed_obs) - - def _get_constructor_parameters(self) -> Dict[str, Any]: - """ - Get data that need to be saved in order to re-create the model when loading it from disk. - - :return: The dictionary to pass to the as kwargs constructor when reconstruction this model. - """ - return dict( - observation_space=self.observation_space, - action_space=self.action_space, - # Passed to the constructor by child class - # squash_output=self.squash_output, - # features_extractor=self.features_extractor - normalize_images=self.normalize_images, - ) - - # @property - # def device(self) -> torch.device: - # """Infer which device this policy lives on by inspecting its parameters. - # If it has no parameters, the 'cpu' device is used as a fallback. - # - # :return:""" - # for param in self.parameters(): - # return param.device - # return get_device("cpu") - - def save(self, path: str) -> None: - """ - Save model to a given location. - - :param path: - """ - torch.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) - - # @classmethod - # def load(cls, path: str, device: Union[torch.device, str] = "auto") -> "BaseModel": - # """ - # Load model from patorch. - # - # :param path: - # :param device: Device on which the policy should be loaded. - # :return: - # """ - # device = get_device(device) - # saved_variables = torch.load(path, map_location=device) - # - # # Allow to load policy saved with older version of SB3 - # if "sde_net_arch" in saved_variables["data"]: - # warnings.warn( - # "sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.", - # DeprecationWarning, - # ) - # del saved_variables["data"]["sde_net_arch"] - # - # # Create policy object - # model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable - # # Load weights - # model.load_state_dict(saved_variables["state_dict"]) - # model.to(device) - # return model - - def load_from_vector(self, vector: np.ndarray) -> None: - """ - Load parameters from a 1D vector. - - :param vector: - """ - torch.nn.utils.vector_to_parameters(torch.FloatTensor(vector).to(self.device), self.parameters()) - - def parameters_to_vector(self) -> np.ndarray: - """ - Convert the parameters to a 1D vector. - - :return: - """ - return torch.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() - - def set_training_mode(self, mode: bool) -> None: - """ - Put the policy in either training or evaluation mode. - - This affects certain modules, such as batch normalisation and dropout. - - :param mode: if true, set to training mode, else set to evaluation mode - """ - self.train(mode) - # - # def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[torch.Tensor, bool]: - # """ - # Convert an input observation to a PyTorch tensor that can be fed to a model. - # Includes sugar-coating to handle different observations (e.g. normalizing images). - # - # :param observation: the input observation - # :return: The observation as PyTorch tensor - # and whether the observation is vectorized or not - # """ - # vectorized_env = False - # if isinstance(observation, dict): - # # need to copy the dict as the dict in VecFrameStack will become a torch tensor - # observation = copy.deepcopy(observation) - # for key, obs in observation.items(): - # obs_space = self.observation_space.spaces[key] - # if is_image_space(obs_space): - # obs_ = maybe_transpose(obs, obs_space) - # else: - # obs_ = np.array(obs) - # vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) - # # Add batch dimension if needed - # observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape) - # - # elif is_image_space(self.observation_space): - # # Handle the different cases for images - # # as PyTorch use channel first format - # observation = maybe_transpose(observation, self.observation_space) - # - # else: - # observation = np.array(observation) - # - # if not isinstance(observation, dict): - # # Dict obs need to be handled separately - # vectorized_env = is_vectorized_observation(observation, self.observation_space) - # # Add batch dimension if needed - # observation = observation.reshape((-1,) + self.observation_space.shape) - # - # observation = obs_as_tensor(observation, self.device) - # return observation, vectorized_env - - -class BasePolicy(BaseModel): - """The base policy object. - - Parameters are mostly the same as `BaseModel`; additions are documented below. - - :param args: positional arguments passed through to `BaseModel`. - :param kwargs: keyword arguments passed through to `BaseModel`. - :param squash_output: For continuous actions, whether the output is squashed - or not using a ``tanh()`` function. - """ - - def __init__(self, *args, squash_output: bool = False, **kwargs): - super().__init__(*args, **kwargs) - self._squash_output = squash_output - @staticmethod def _dummy_schedule(progress_remaining: float) -> float: """(float) Useful for pickling policy.""" del progress_remaining return 0.0 - @property - def squash_output(self) -> bool: - """(bool) Getter for squash_output.""" - return self._squash_output @staticmethod def init_weights(module: nn.Module, gain: float = 1) -> None: @@ -295,6 +76,22 @@ def _predict(self, observation: torch.Tensor, deterministic: bool = False) -> to :param deterministic: Whether to use stochastic or deterministic actions :return: Taken action according to the policy """ + def _get_constructor_parameters(self) -> Dict[str, Any]: + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + ) + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: + """ + torch.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) + + def set_training_mode(self, mode: bool) -> None: + self.train(mode) def predict( self, @@ -367,16 +164,14 @@ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: low, high = self.action_space.low, self.action_space.high return low + (0.5 * (scaled_action + 1.0) * (high - low)) -class LMActorCriticPolicy(BasePolicy): +class LMActorCriticModel(BaseModel): def __init__( self, observation_space: DictSpace, action_space: Discrete, - lr_schedule: Schedule, model_name: str, optimizer_kwargs: Dict[str, Any] = {}, weight_decay: float = 1e-6, - use_sde: bool = None, apply_model_parallel: bool = True, optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, generation_kwargs: Dict[str, Any] = {}, @@ -387,11 +182,9 @@ def __init__( Args: observation_space (DictSpace): Observation space action_space (Discrete): Action space - lr_schedule (Schedule): Learning rate schedule model_name (str): name of the causal or seq2seq model from transformers library optimizer_kwargs (Dict[str, Any], optional): optimizer kwargs. Defaults to {}. weight_decay (float, optional): weight decay. Defaults to 1e-6. - use_sde (bool, optional): Use state-dependent exploration. Defaults to None. (Unused parameter from stable-baselines3) apply_model_parallel (bool, optional): whether to apply model parallel. Defaults to True. optimizer_class (torch.optim.Optimizer, optional): Optimizer class. Defaults to torch.optim.AdamW. generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}. @@ -433,7 +226,7 @@ def forward(self, *args, **kwargs): # dummy just to comply with base policy pass - @staticmethod + def _predict( self, observation: Dict[str, torch.tensor], deterministic: bool = False ) -> torch.Tensor: diff --git a/benchmark/torch/RL4LMs/models/seq2seq_model.py b/benchmark/torch/RL4LMs/models/seq2seq_model.py index ba9cb1a74..447b99494 100644 --- a/benchmark/torch/RL4LMs/models/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/models/seq2seq_model.py @@ -18,19 +18,17 @@ PolicyType, EvaluateActionsOutput, GenerationOutputs, ) -from base_model import LMActorCriticPolicy +from .base_model import LMActorCriticModel -class Seq2SeqLMModel(LMActorCriticPolicy, ActorCriticWarmStartMixin): +class Seq2SeqLMModel(LMActorCriticModel, ActorCriticWarmStartMixin): def __init__( self, observation_space: DictSpace, action_space: Discrete, - lr_schedule: Schedule, model_name: str, optimizer_kwargs: Dict[str, Any] = {}, weight_decay: float = 1e-6, - use_sde: bool = None, apply_model_parallel: bool = True, optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, generation_kwargs: Dict[str, Any] = {}, @@ -40,11 +38,9 @@ def __init__( super().__init__( observation_space, action_space, - lr_schedule, model_name, optimizer_kwargs, weight_decay, - use_sde, apply_model_parallel, optimizer_class, generation_kwargs, diff --git a/benchmark/torch/RL4LMs/utils/registry.py b/benchmark/torch/RL4LMs/registry.py similarity index 87% rename from benchmark/torch/RL4LMs/utils/registry.py rename to benchmark/torch/RL4LMs/registry.py index 1630800b2..1a468515f 100644 --- a/benchmark/torch/RL4LMs/utils/registry.py +++ b/benchmark/torch/RL4LMs/registry.py @@ -4,10 +4,10 @@ from benchmark.torch.RL4LMs.algorithms import RL4LMPPO from benchmark.torch.RL4LMs.summarization import RL4LMsSummaAgent -from .data_pool import TextGenPool, CNNDailyMail +from benchmark.torch.RL4LMs.utils import TextGenPool, CNNDailyMail # from rl4lms.envs.text_generation.alg_wrappers import wrap_onpolicy_alg -from .metric_util import ( +from benchmark.torch.RL4LMs.metrics import ( BaseMetric, BERTScoreMetric, BLEUMetric, @@ -23,18 +23,17 @@ TERMetric, chrFmetric, ) -from benchmark.torch.RL4LMs.models import LMActorCriticPolicy +from benchmark.torch.RL4LMs.models import LMActorCriticModel from benchmark.torch.RL4LMs.models import Seq2SeqLMModel -from .reward_util import ( +from benchmark.torch.RL4LMs.utils import ( BERTScoreRewardFunction, BLEURewardFunction, BLEURTRewardFunction, CommonGenPenaltyShapingFunction, LearnedRewardFunction, MeteorRewardFunction, - PARENTRewardFunction, RewardFunction, RougeCombined, RougeLMaxRewardFunction, @@ -70,7 +69,6 @@ class RewardFunctionRegistry: "bleurt": BLEURTRewardFunction, "rouge_combined": RougeCombined, "common_gen_repeat_penalty": CommonGenPenaltyShapingFunction, - "parent": PARENTRewardFunction, "sacre_bleu": SacreBleu, "rouge_l_max": RougeLMaxRewardFunction, } @@ -117,19 +115,15 @@ def add(cls, id: str, metric_cls: Type[BaseMetric]): MetricRegistry._registry[id] = metric_cls -class PolicyRegistry: +class ModelRegistry: _registry = { - "seq2seq_lm_actor_critic_policy": Seq2SeqLMModel, + "seq2seq_lm_actor_critic_model": Seq2SeqLMModel, } @classmethod - def get(cls, policy_id: str) -> Type[LMActorCriticPolicy]: - policy_cls = cls._registry[policy_id] - return policy_cls - - @classmethod - def add(cls, id: str, policy_cls: Type[LMActorCriticPolicy]): - PolicyRegistry._registry[id] = policy_cls + def get(cls, model_id: str) -> Type[LMActorCriticModel]: + model_cls = cls._registry[model_id] + return model_cls class AlgorithmRegistry: diff --git a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py index a829300ac..d07200c02 100644 --- a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py +++ b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py @@ -1,66 +1,10 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import parl -import torch import numpy as np -from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Type, Union -import numpy as np +from typing import List import torch -from benchmark.torch.RL4LMs.utils import DictRolloutBuffer, RolloutBuffer, TransitionInfo, TensorDict,\ - BatchedRewardFunction, RewardFunction, PolicyOutput, RefPolicyOutput, ValueOutput, \ - MaskableDictRolloutBuffer, OnPolicyWarmStartMixin, KLController, Tracker - -from transformers import PreTrainedTokenizer - - - -def obs_as_tensor( - obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: torch.device -) -> Union[torch.Tensor, TensorDict]: - """ - Moves the observation to the given device. - - :param obs: - :param device: PyTorch device - :return: PyTorch tensor of the observation on a desired device. - """ - if isinstance(obs, np.ndarray): - return torch.as_tensor(obs).to(device) - elif isinstance(obs, dict): - return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} - else: - raise Exception(f"Unrecognized type of observation {type(obs)}") - - - - - -def unpack_observations(obs_tensor, n_envs: int): - """ - Unpacks vectorized dict observations into separate dict observations - """ - unpacked_obs = [] - keys = obs_tensor.keys() - for env_ix in range(n_envs): - obs_dict = {} - for key in keys: - obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() - unpacked_obs.append(obs_dict) - return unpacked_obs +from benchmark.torch.RL4LMs.utils import TransitionInfo,\ + RewardFunction, Tracker def compute_batched_rewards( @@ -96,340 +40,109 @@ def compute_batched_rewards( ) -def wrap_onpolicy_alg( - alg_class, - alg_kwargs: Dict[str, Any], - kl_coeff: float, - tracker: Tracker, - target_kl: float = None, - norm_reward: bool = False, -): - class OnPolicyAlgText(alg_class, OnPolicyWarmStartMixin): - def __init__( - self, - alg_kwargs: Dict[str, Any], - kl_coeff: float, - tracker: Tracker, - target_kl: float = None, - norm_reward: bool = False, - ): - alg_kwargs["tracker"] = tracker - super().__init__(**alg_kwargs) - self._kl_controller = KLController(kl_coeff, target_kl) - self.tracker = tracker - self._norm_reward = norm_reward - # flattened rollout buffer - self.rollout_buffer = MaskableDictRolloutBuffer( - self.n_steps * self.env.num_envs, - self.observation_space, - self.action_space, - device=self.device, - gamma=self.gamma, - gae_lambda=self.gae_lambda, - n_envs=1, - ) - self.reward_fn = self.env.get_attr("reward_function", 0)[0] - - def get_policy_kwargs( - self, - obs: TensorDict, - action: torch.tensor, - past_state: Dict[str, torch.tensor], - action_mask: torch.tensor, - ): - - policy_kwargs = { - "obs": obs, - "actions": action, - "past_model_kwargs": past_state, - } - if action_mask is not None: - policy_kwargs["action_masks"] = action_mask - return policy_kwargs - - def generate_batch( - self, - rollout_buffer: DictRolloutBuffer, - tokenizer: PreTrainedTokenizer, - max_steps: int, - rollout_info: Dict[str, Any], - ): - # if rollout buffer is already full, do not continue - if rollout_buffer.full: - return - - # start parallel episodes - current_obs = self.env.reset() - episode_starts = np.ones((self.env.num_envs,), dtype=bool) - - # generate text using the model - obs_tensor = obs_as_tensor(current_obs, self.device) - generation_inputs = self.policy.get_inputs_for_generation(obs_tensor) - gen_output = self.policy.generate( - input_ids=generation_inputs.inputs, - attention_mask=generation_inputs.attention_masks, - tokenizer=tokenizer, - ) - - # process them one step at a time to collect rollout info - episode_wise_transitions = [[] for _ in range(self.env.num_envs)] - ep_terminated = np.zeros((self.env.num_envs,), dtype=bool) - value_past_state = None - ref_past_state = None - policy_past_state = None - masks = ( - gen_output.action_masks - if gen_output.action_masks is not None - else [None] * len(gen_output.step_wise_logprobs) - ) - - for actions_tensor, _, action_mask in zip( - gen_output.step_wise_actions, gen_output.step_wise_logprobs, masks - ): - # if all episodes are done, just break and do not continue - if np.all(ep_terminated): - break - - # evaluate actions with actions from rollout - with torch.no_grad(): - obs_tensor = obs_as_tensor(current_obs, self.device) - - # get log probs (TBD: generalize this a bit) - policy_kwargs = self.get_policy_kwargs( - obs_tensor, actions_tensor, policy_past_state, action_mask - ) - - policy_outputs: PolicyOutput = self.policy.forward_policy( - **policy_kwargs - ) - raw_log_probs, log_probs, policy_past_state = ( - policy_outputs.raw_log_probs, - policy_outputs.log_probs, - policy_outputs.past_model_kwargs, - ) - - # sanity check - assert torch.all( - torch.isfinite(log_probs) - ), "Infinite values in log probs" - - # sanity check - assert torch.all( - torch.isfinite(raw_log_probs) - ), "Infinite values in log probs" - - # get values - value_outputs: ValueOutput = self.policy.forward_value( - obs_tensor, value_past_state - ) - values, value_past_state = ( - value_outputs.values, - value_outputs.past_model_kwargs, - ) - - # get reference log probs - ref_policy_outputs: RefPolicyOutput = ( - self.policy.get_log_probs_ref_model( - obs_tensor, actions_tensor, ref_past_state - ) - ) - ref_log_probs, ref_past_state = ( - ref_policy_outputs.log_probs, - ref_policy_outputs.past_model_kwargs, - ) - - # sanity check - assert torch.all( - torch.isfinite(ref_log_probs) - ), "Infinite values in log probs" - - # compute KL rewards - kl_div = raw_log_probs - ref_log_probs - kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div - - # step into env to get rewards - actions = actions_tensor.cpu().numpy() - new_obs, rewards, dones, infos = self.env.step(actions) - - self.num_timesteps += self.env.num_envs - - # compute total rewards - total_rewards = rewards + kl_rewards.cpu().numpy() - - # unpack individual observations - unpacked_obs = unpack_observations(obs_tensor, self.env.num_envs) - - # store episode wise transitions separately - for env_ix in range(self.env.num_envs): - # only if not terminated already - if not ep_terminated[env_ix]: - transtion = TransitionInfo( - observation=unpacked_obs[env_ix], - action=actions[env_ix], - task_reward=rewards[env_ix], - total_reward=total_rewards[env_ix], - kl_div=kl_div.cpu().numpy()[env_ix], - episode_start=episode_starts[env_ix], - value=values[env_ix].cpu(), - log_prob=log_probs[env_ix].cpu(), - done=dones[env_ix], - ref_log_prob=ref_log_probs[env_ix].cpu(), - kl_reward=kl_rewards.cpu().numpy()[env_ix], - action_mask=action_mask[env_ix].cpu().numpy() - if action_mask is not None - else None, - info=infos[env_ix], - ) - - episode_wise_transitions[env_ix].append(transtion) - - # mark this episode to terminated if done occurs once - if dones[env_ix]: - ep_terminated[env_ix] = True - - episode_starts = np.zeros((self.env.num_envs,), dtype=bool) - current_obs = new_obs - - # now we flush all episode wise info to the 1-D buffer - rollout_info = self._add_to_buffer( - rollout_buffer, episode_wise_transitions, rollout_info - ) - return rollout_info - - def _add_to_buffer( - self, rollout_buffer, episode_wise_transitions, rollout_info - ): - # if the reward function is batchable, we override the rewards here - if isinstance(self.reward_fn, BatchedRewardFunction): - compute_batched_rewards(episode_wise_transitions, self.reward_fn) - - advantages_computed = False - for ep_ix, transitions in enumerate(episode_wise_transitions): - ep_length = len(transitions) - total_reward = 0.0 - total_kl_reward = 0.0 - for transition_ix, transition in enumerate(transitions): - total_reward += transition.task_reward - total_kl_reward += transition.kl_reward - rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) - rollout_info["rollout_info/log_prob"].append(transition.log_prob) - rollout_info["rollout_info/ref_log_prob"].append( - transition.ref_log_prob - ) - rollout_info["rollout_info/values"].append(transition.value.numpy()) - - if not rollout_buffer.full: - rollout_buffer.add( - transition.observation, - transition.action, - transition.total_reward, - transition.episode_start, - transition.value, - transition.log_prob, - action_masks=transition.action_mask, - ) - - # if the buffer is full, compute advantages - if rollout_buffer.full and not advantages_computed: - - # normalize the rewards - if self._norm_reward: - mean = rollout_buffer.rewards.mean() - std = rollout_buffer.rewards.std() - rollout_buffer.rewards = (rollout_buffer.rewards - mean) / ( - std + 1e-8 - ) - - # we fetch the last value for the last time step - # values come from the next transitions's values - next_values = ( - transitions[transition_ix + 1].value - if (transition_ix + 1) < ep_length - else torch.tensor([0.0]) - ) - - rollout_buffer.compute_returns_and_advantage( - last_values=next_values, dones=transition.done - ) - advantages_computed = True - - rollout_info["rollout_info/ep_rew"].append(total_reward) - rollout_info["rollout_info/ep_lens"].append(ep_length) - rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) - return rollout_info - - def collect_rollouts( - self, - env, - rollout_buffer: RolloutBuffer, - n_rollout_steps: int, - ) -> bool: - # max episode steps - max_steps = env.unwrapped.get_attr("max_steps", [0])[0] - - # get tokenizer - tokenizer = env.unwrapped.get_attr("tokenizer", [0]) - tokenizer = tokenizer[0] - - # Switch to eval mode - self.policy.set_training_mode(False) - - # reset rollout buffer and stats - rollout_buffer.reset() - - # start the rollout process - rollout_info = { - "rollout_info/ep_rew": [], - "rollout_info/kl_div_mean": [], - "rollout_info/ep_lens": [], - "rollout_info/ep_kl_rew": [], - "rollout_info/log_prob": [], - "rollout_info/ref_log_prob": [], - "rollout_info/values": [], - } - while not rollout_buffer.full: - # generate batch of rollouts - rollout_info = self.generate_batch( - rollout_buffer, tokenizer, max_steps, rollout_info - ) - - # aggregate rollout info - aggregated_rollout_info = {} - for key, values in rollout_info.items(): - aggregated_rollout_info[key] = np.mean(values).item() - aggregated_rollout_info[f"{key}_std"] = np.std(values).item() - aggregated_rollout_info[ - "rollout_info/kl_coeff" - ] = self._kl_controller.kl_coeff - - if self.tracker is not None: - self.tracker.log_rollout_infos(aggregated_rollout_info) - - # adapt the KL coeff - self._kl_controller.step( - torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) - ) - return True - - # instantiate the wrapped alg - alg = OnPolicyAlgText(alg_kwargs, kl_coeff, tracker, target_kl, norm_reward) - return alg - - +def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: + """ + Computes fraction of variance that ypred explains about y. + Returns 1 - Var[y-ypred] / Var[y] + interpretation: + ev=0 => might as well have predicted zero + ev=1 => perfect prediction + ev<0 => worse than just predicting zero + """ + assert y_true.ndim == 1 and y_pred.ndim == 1 + var_y = np.var(y_true) + return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y class RL4LMsSummaAgent(parl.Agent): - def __init__(self, algorithm, config): + def __init__(self, + algorithm, + alg_config, + tracker: Tracker, + norm_reward: bool = False, + ): super(RL4LMsSummaAgent, self).__init__(algorithm) self.dataset = None - self.config = config - - def learn(self, *args, **kwargs): - pass + self.config = alg_config + self.n_epochs = alg_config["n_epochs"] + self._tracker = tracker + self._norm_reward = norm_reward + self._n_updates = 0 + + + + + def learn(self, rollout_buffer): + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + approx_kl_divs = [] + log_info = { + "entropy_losses": entropy_losses, + "pg_losses": entropy_losses, + "value_losses": value_losses, + "clip_fractions": clip_fractions, + "approx_kl_divs": approx_kl_divs + } + + continue_training = True + loss = torch.tensor(0.0) + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + continue_training, loss = self.alg.learn(rollout_buffer=rollout_buffer, + log_info=log_info) + if not continue_training: + print( + f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}") + break + + self._n_updates += self.n_epochs + explained_var = explained_variance( + rollout_buffer.values.flatten(), rollout_buffer.returns.flatten()) + + # Logs + print("train/entropy_loss", np.mean(entropy_losses)) + print("train/policy_gradient_loss", np.mean(pg_losses)) + print("train/value_loss", np.mean(value_losses)) + print("train/approx_kl", np.mean(approx_kl_divs)) + print("train/clip_fraction", np.mean(clip_fractions)) + print("train/loss", loss.item()) + print("train/explained_variance", explained_var) + # self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + # self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + # self.logger.record("train/value_loss", np.mean(value_losses)) + # self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + # self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + # self.logger.record("train/loss", loss.item()) + # self.logger.record("train/explained_variance", explained_var) + if hasattr(self.alg.model, "log_std"): + # self.logger.record( + # "train/std", torch.exp(self.policy.log_std).mean().item()) + print("train/std", torch.exp(self.alg.model.log_std).mean().item()) + + # self.logger.record("train/n_updates", + # self._n_updates, exclude="tensorboard") + # self.logger.record("train/clip_range", clip_range) + print("train/n_updates", self._n_updates) + print("train/clip_range", self.alg.clip_range) + + train_info = { + "ppo/entropy_loss": np.mean(entropy_losses).item(), + "ppo/policy_gradient_loss": np.mean(pg_losses).item(), + "ppo/value_loss": np.mean(value_losses).item(), + "ppo/approx_kl": np.mean(approx_kl_divs).item(), + } + + self._tracker.log_training_infos(train_info) + # for k, v in train_info.items(): + # print(f"{k}: {v}") def predict(self, *args, **kwargs): pass def sample(self, *args, **kwargs): pass + + diff --git a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py deleted file mode 100644 index 7bcc5588a..000000000 --- a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_model.py +++ /dev/null @@ -1,7 +0,0 @@ -import parl -import torch -import torch.nn as nn - - -class RL4LMsSummaModel(parl.Model): - pass \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 8ee888815..7ab7a2563 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -67,13 +67,16 @@ def main(config): help="Base path to store experiment results", default=os.getcwd(), ) + parser.add_argument( + "--entity_name", type=str, help="entity name", default=None + ) args = parser.parse_args() # load the config file with open(args.config_path, "r") as fp: config = yaml.safe_load(fp) - recursive_dict_update(config, args) + recursive_dict_update(config, vars(args)) main(config) diff --git a/benchmark/torch/RL4LMs/trainers.py b/benchmark/torch/RL4LMs/trainers.py index 78bd390ac..7db9dbf02 100644 --- a/benchmark/torch/RL4LMs/trainers.py +++ b/benchmark/torch/RL4LMs/trainers.py @@ -3,19 +3,14 @@ from typing import Any, Dict, List import numpy as np -from benchmark.torch.RL4LMs.utils import Sample + +from benchmark.torch.RL4LMs.utils import Sample, Tracker, RewardFunction,\ + evaluate_on_samples, TrainerWarmStartMixin,\ + KLController, RolloutBuffer, DictRolloutBuffer, MaskableDictRolloutBuffer,\ + TransitionInfo, TensorDict, RefPolicyOutput, ValueOutput, PolicyOutput +from benchmark.torch.RL4LMs.registry import DataPoolRegistry, MetricRegistry, RewardFunctionRegistry, \ + ModelRegistry, AlgorithmRegistry from benchmark.torch.RL4LMs.env import TextGenEnv -from rl4lms.envs.text_generation.evaluation_utils import evaluate_on_samples -from rl4lms.envs.text_generation.logging_utils import Tracker -from rl4lms.envs.text_generation.registry import (DataPoolRegistry, - MetricRegistry, - RewardFunctionRegistry, - PolicyRegistry, - AlgorithmRegistry, - WrapperRegistry) -from rl4lms.envs.text_generation.reward import RewardFunction -from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.vec_env import SubprocVecEnv from transformers import (AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, @@ -23,11 +18,11 @@ TrainingArguments, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq) - -from rl4lms.envs.text_generation.warm_start import TrainerWarmStartMixin - - - +from benchmark.torch.RL4LMs.env import LocalParallelVecEnv, make_vec_env +from transformers import PreTrainedTokenizer +from benchmark.torch.RL4LMs.summarization import RL4LMsSummaAgent +from benchmark.torch.RL4LMs.algorithms import RL4LMPPO +import torch def build_tokenizer(tokenizer_config: Dict[str, Any]): tokenizer = AutoTokenizer.from_pretrained( @@ -85,40 +80,57 @@ def build_env(env_config: Dict[str, Any], "samples": train_samples, } env_kwargs = {**env_kwargs, **env_config.get("args", {})} - env = make_vec_env(TextGenEnv, + envs = make_vec_env(TextGenEnv, n_envs=env_config.get( "n_envs", 1), - vec_env_cls=SubprocVecEnv, + vec_env_cls=LocalParallelVecEnv, env_kwargs=env_kwargs) - return env - - -def build_alg(alg_config: Dict[str, Any], - env: TextGenEnv, - tracker: Tracker, - policy_state: Dict[str, Any], - alg_state: Dict[str, Any]): - # TBD - move these to a registry once the experimentation is done - # Also switch to Sb3 algos when possible with minimal code adaptations - policy_config = alg_config["policy"] - policy_cls = PolicyRegistry.get(policy_config["id"]) + return envs + +def build_agent(alg_config: Dict[str, Any], + env: LocalParallelVecEnv, + tracker: Tracker, + model_state: Dict[str, Any] = None, # TODO: save model checkpoint + alg_state: Dict[str, Any] = None # TODO: save alg checkpoint + ): + model_config = alg_config["model"] + model_cls = ModelRegistry.get(model_config["id"]) alg_cls = AlgorithmRegistry.get(alg_config["id"]) - policy_args = policy_config["args"] - policy_args["state_dict"] = policy_state - alg_kwargs = { - "policy": policy_cls, - "env": env, - "policy_kwargs": policy_args, - } - alg_kwargs = {**alg_kwargs, **alg_config.get("args")} - wrapper = WrapperRegistry.get(alg_config["id"]) - alg = wrapper(alg_cls, alg_kwargs, - alg_config["kl_div"]["coeff"], tracker, - alg_config["kl_div"].get("target_kl", None), - alg_config["kl_div"].get("norm_reward", False)) - alg.load_from_dict(alg_state) - return alg + model_args = model_config["args"] + model_args["state_dict"] = model_state + + rl4lms_model = model_cls( + observation_space = env.observation_space, + action_space= env.action_space, + **model_args + ) + + rl4lm_alg_cls = alg_cls( + model=rl4lms_model, + **alg_config.get("args") + ) + + rl4lm_agent = RL4LMsSummaAgent(rl4lm_alg_cls, alg_config, tracker) + return rl4lm_agent + + +def dict_to_tensor(obs, device): + return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + + +def unpack_observations(obs_tensor, n_envs: int): + """ + Unpacks vectorized dict observations into separate dict observations + """ + unpacked_obs = [] + keys = obs_tensor.keys() + for env_ix in range(n_envs): + obs_dict = {} + for key in keys: + obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() + unpacked_obs.append(obs_dict) + return unpacked_obs class OnPolicyTrainer(TrainerWarmStartMixin): @@ -144,9 +156,18 @@ def __init__(self, self._train_eval_config = train_eval_config self._tracker = tracker self._experiment_name = experiment_name + self._agent = None + self._env = None + self.num_timesteps = None + self._kl_controller = None + self.device = torch.device("cuda" if torch.cuda. + is_available() else "cpu") + self._norm_reward = False + self._setup() def _setup(self): + # load trainer state from available previous checkpoint if available self.load_trainer_state(self._tracker) @@ -158,17 +179,32 @@ def _setup(self): self._samples_by_split = build_datapool(self._datapool_config) self._env = build_env(self._env_config, self._reward_fn, self._tokenizer, self._samples_by_split["train"]) - self._alg = build_alg(self._on_policy_alg_config, - self._env, self._tracker, - self._policy_state_dict, - self._alg_state_dict) + + + self._agent = build_agent(self._on_policy_alg_config, + self._env, self._tracker) + + self._rollout_buffer = MaskableDictRolloutBuffer( + buffer_size=self._agent.alg.n_steps * self._env.num_envs, + observation_space=self._agent.alg.model.observation_space, + action_space=self._agent.alg.model.action_space, + device=self.device, + gamma=self._agent.alg.gamma, + gae_lambda=self._agent.alg.gae_lambda, + n_envs=1, + ) + + self._kl_controller = KLController( + self._on_policy_alg_config["kl_div"]["coeff"], + self._on_policy_alg_config["kl_div"].get("norm_reward", False)) # extract train params self._max_episode_length = self._env_config["args"]["max_episode_length"] self._max_prompt_length = self._env_config["args"]["max_prompt_length"] self._eval_batch_size = self._train_eval_config["eval_batch_size"] self._n_iters = int(self._train_eval_config["n_iters"]) - self._n_steps_per_iter = self._env.num_envs * self._alg.n_steps + self._n_steps_per_iter = self._env.num_envs * self._agent.alg.n_steps + self._num_timesteps = 0 # gen kwargs for evaluation (if it is different from rollout gen kwargs) self._eval_gen_kwargs = self._train_eval_config.get( @@ -177,7 +213,7 @@ def _setup(self): def _evaluate_on_datapools(self, epoch: int, splits: List[str] = ["val", "test"]): for split in splits: - evaluate_on_samples(policy=self._alg.policy, + evaluate_on_samples(policy=self._agent.alg.model, tokenizer=self._tokenizer, samples=self._samples_by_split[split], batch_size=self._eval_batch_size, @@ -198,8 +234,13 @@ def train_and_eval(self): # current state self._trainer_state["current_iter"] = epoch - # inner rollout and learn loop for on-policy algorithm - self._alg.learn(self._n_steps_per_iter) + self._num_timesteps = 0 + + while self._num_timesteps < self._n_steps_per_iter: + self.collect_rollouts(self._env, self._rollout_buffer) + # inner rollout and learn loop for on-policy algorithm + # self._agent.learn(self._n_steps_per_iter) + self._agent.learn(self._rollout_buffer) # save the policy checkpoint if (epoch + 1) % self._train_eval_config.get("save_every", 20) == 0: @@ -216,4 +257,285 @@ def train_and_eval(self): # save model here - we save only the language model if self._tracker is not None: self._tracker.save_auto_model( - self._alg.policy.get_language_model()) \ No newline at end of file + self._alg.policy.get_language_model()) + + + def get_policy_kwargs( + self, + obs: TensorDict, + action: torch.tensor, + past_state: Dict[str, torch.tensor], + action_mask: torch.tensor, + ): + + policy_kwargs = { + "obs": obs, + "actions": action, + "past_model_kwargs": past_state, + } + if action_mask is not None: + policy_kwargs["action_masks"] = action_mask + return policy_kwargs + + def generate_batch( + self, + rollout_buffer, + tokenizer: PreTrainedTokenizer, + max_steps: int, + rollout_info: Dict[str, Any], + ): + # if rollout buffer is already full, do not continue + if rollout_buffer.full: + return + + # start parallel episodes + current_obs = self._env.reset() + episode_starts = np.ones((self._env.num_envs,), dtype=bool) + + # generate text using the model + obs_tensor = dict_to_tensor(current_obs, self.device) + generation_inputs = self._agent.model.get_inputs_for_generation(obs_tensor) + gen_output = self._agent.model.generate( + input_ids=generation_inputs.inputs, + attention_mask=generation_inputs.attention_masks, + tokenizer=tokenizer, + ) + + # process them one step at a time to collect rollout info + episode_wise_transitions = [[] for _ in range(self._env.num_envs)] + ep_terminated = np.zeros((self._env.num_envs,), dtype=bool) + value_past_state = None + ref_past_state = None + policy_past_state = None + masks = ( + gen_output.action_masks + if gen_output.action_masks is not None + else [None] * len(gen_output.step_wise_logprobs) + ) + + for actions_tensor, _, action_mask in zip( + gen_output.step_wise_actions, gen_output.step_wise_logprobs, masks + ): + # if all episodes are done, just break and do not continue + if np.all(ep_terminated): + break + + # evaluate actions with actions from rollout + with torch.no_grad(): + obs_tensor = dict_to_tensor(current_obs, self.device) + + # get log probs (TBD: generalize this a bit) + policy_kwargs = self.get_policy_kwargs( + obs_tensor, actions_tensor, policy_past_state, action_mask + ) + + policy_outputs: PolicyOutput = self.policy.forward_policy( + **policy_kwargs + ) + raw_log_probs, log_probs, policy_past_state = ( + policy_outputs.raw_log_probs, + policy_outputs.log_probs, + policy_outputs.past_model_kwargs, + ) + + # sanity check + assert torch.all( + torch.isfinite(log_probs) + ), "Infinite values in log probs" + + # sanity check + assert torch.all( + torch.isfinite(raw_log_probs) + ), "Infinite values in log probs" + + # get values + value_outputs: ValueOutput = self.policy.forward_value( + obs_tensor, value_past_state + ) + values, value_past_state = ( + value_outputs.values, + value_outputs.past_model_kwargs, + ) + + # get reference log probs + ref_policy_outputs: RefPolicyOutput = ( + self.policy.get_log_probs_ref_model( + obs_tensor, actions_tensor, ref_past_state + ) + ) + ref_log_probs, ref_past_state = ( + ref_policy_outputs.log_probs, + ref_policy_outputs.past_model_kwargs, + ) + + # sanity check + assert torch.all( + torch.isfinite(ref_log_probs) + ), "Infinite values in log probs" + + # compute KL rewards + kl_div = raw_log_probs - ref_log_probs + kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div + + # step into env to get rewards + actions = actions_tensor.cpu().numpy() + new_obs, rewards, dones, infos = self._env.step(actions) + + self._num_timesteps += self._env.num_envs + + # compute total rewards + total_rewards = rewards + kl_rewards.cpu().numpy() + + # unpack individual observations + unpacked_obs = unpack_observations(obs_tensor, self._env.num_envs) + + # store episode wise transitions separately + for env_ix in range(self._env.num_envs): + # only if not terminated already + if not ep_terminated[env_ix]: + transtion = TransitionInfo( + observation=unpacked_obs[env_ix], + action=actions[env_ix], + task_reward=rewards[env_ix], + total_reward=total_rewards[env_ix], + kl_div=kl_div.cpu().numpy()[env_ix], + episode_start=episode_starts[env_ix], + value=values[env_ix].cpu(), + log_prob=log_probs[env_ix].cpu(), + done=dones[env_ix], + ref_log_prob=ref_log_probs[env_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[env_ix], + action_mask=action_mask[env_ix].cpu().numpy() + if action_mask is not None + else None, + info=infos[env_ix], + ) + + episode_wise_transitions[env_ix].append(transtion) + + # mark this episode to terminated if done occurs once + if dones[env_ix]: + ep_terminated[env_ix] = True + + episode_starts = np.zeros((self._env.num_envs,), dtype=bool) + current_obs = new_obs + + # now we flush all episode wise info to the 1-D buffer + rollout_info = self._add_to_buffer( + rollout_buffer, episode_wise_transitions, rollout_info + ) + return rollout_info + + def _add_to_buffer( + self, rollout_buffer, episode_wise_transitions, rollout_info + ): + # if the reward function is batchable, we override the rewards here + # if isinstance(self.reward_fn, BatchedRewardFunction): + # compute_batched_rewards(episode_wise_transitions, self.reward_fn) + + advantages_computed = False + for ep_ix, transitions in enumerate(episode_wise_transitions): + ep_length = len(transitions) + total_reward = 0.0 + total_kl_reward = 0.0 + for transition_ix, transition in enumerate(transitions): + total_reward += transition.task_reward + total_kl_reward += transition.kl_reward + rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) + rollout_info["rollout_info/log_prob"].append(transition.log_prob) + rollout_info["rollout_info/ref_log_prob"].append( + transition.ref_log_prob + ) + rollout_info["rollout_info/values"].append(transition.value.numpy()) + + if not rollout_buffer.full: + rollout_buffer.add( + transition.observation, + transition.action, + transition.total_reward, + transition.episode_start, + transition.value, + transition.log_prob, + action_masks=transition.action_mask, + ) + + # if the buffer is full, compute advantages + if rollout_buffer.full and not advantages_computed: + + # normalize the rewards + if self._norm_reward: + mean = rollout_buffer.rewards.mean() + std = rollout_buffer.rewards.std() + rollout_buffer.rewards = (rollout_buffer.rewards - mean) / ( + std + 1e-8 + ) + + # we fetch the last value for the last time step + # values come from the next transitions's values + next_values = ( + transitions[transition_ix + 1].value + if (transition_ix + 1) < ep_length + else torch.tensor([0.0]) + ) + + rollout_buffer.compute_returns_and_advantage( + last_values=next_values, dones=transition.done + ) + advantages_computed = True + + rollout_info["rollout_info/ep_rew"].append(total_reward) + rollout_info["rollout_info/ep_lens"].append(ep_length) + rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) + return rollout_info + + def collect_rollouts( + self, + env, + rollout_buffer: RolloutBuffer, + ) -> bool: + # max episode steps + max_steps = env.unwrapped.get_attr("max_steps", [0])[0] + + # get tokenizer + tokenizer = env.unwrapped.get_attr("tokenizer", [0]) + tokenizer = tokenizer[0] + + # Switch to eval mode + self._agent.alg.model.set_training_mode(False) + + # reset rollout buffer and stats + rollout_buffer.reset() + + # start the rollout process + rollout_info = { + "rollout_info/ep_rew": [], + "rollout_info/kl_div_mean": [], + "rollout_info/ep_lens": [], + "rollout_info/ep_kl_rew": [], + "rollout_info/log_prob": [], + "rollout_info/ref_log_prob": [], + "rollout_info/values": [], + } + while not rollout_buffer.full: + # generate batch of rollouts + rollout_info = self.generate_batch( + rollout_buffer, tokenizer, max_steps, rollout_info + ) + + # aggregate rollout info + aggregated_rollout_info = {} + for key, values in rollout_info.items(): + aggregated_rollout_info[key] = np.mean(values).item() + aggregated_rollout_info[f"{key}_std"] = np.std(values).item() + aggregated_rollout_info[ + "rollout_info/kl_coeff" + ] = self._kl_controller.kl_coeff + + # if self.tracker is not None: + # self.tracker.log_rollout_infos(aggregated_rollout_info) + + # adapt the KL coeff + self._kl_controller.step( + torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) + ) + return True \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py index ec9908986..9e2ea3014 100644 --- a/benchmark/torch/RL4LMs/utils/__init__.py +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -5,14 +5,12 @@ from .huggingface_generation_util import override_generation_routines -from .warm_start import ActorCriticWarmStartMixin, OnPolicyWarmStartMixin +from .warm_start import ActorCriticWarmStartMixin, OnPolicyWarmStartMixin, TrainerWarmStartMixin from .type_wrapper import TensorDict, Schedule from .distribution_wrapper import CategoricalDistribution -from .reward_util import RewardFunction, BatchedRewardFunction - from .sample_util import PrioritySampler from .buffer import DictRolloutBuffer, RolloutBuffer,\ @@ -22,3 +20,10 @@ from .tracker import Tracker +from .evaluation_util import evaluate_on_samples + +from .data_pool import TextGenPool, CNNDailyMail + +from .reward_util import RewardFunction, RougeRewardFunction, RougeLMaxRewardFunction, \ + BatchedRewardFunction, BERTScoreRewardFunction, BLEURewardFunction, BLEURTRewardFunction, MeteorRewardFunction,\ + LearnedRewardFunction, SacreBleu, CommonGenPenaltyShapingFunction, RougeCombined diff --git a/benchmark/torch/RL4LMs/utils/data_wrapper.py b/benchmark/torch/RL4LMs/utils/data_wrapper.py index 4917dd49f..a85cd291c 100644 --- a/benchmark/torch/RL4LMs/utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/utils/data_wrapper.py @@ -4,7 +4,6 @@ from transformers import AutoTokenizer from copy import deepcopy from .type_wrapper import TensorDict -import torch from typing import NamedTuple import torch import numpy as np diff --git a/benchmark/torch/RL4LMs/utils/evaluation_util.py b/benchmark/torch/RL4LMs/utils/evaluation_util.py index 5bb317d71..c9f7319d6 100644 --- a/benchmark/torch/RL4LMs/utils/evaluation_util.py +++ b/benchmark/torch/RL4LMs/utils/evaluation_util.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List -from benchmark.torch.RL4LMs.models import BasePolicy from tqdm import tqdm from transformers import AutoTokenizer @@ -18,7 +17,7 @@ def get_batch(samples: List[Sample], batch_size: int): def evaluate_on_samples( - policy: BasePolicy, + policy, tokenizer: AutoTokenizer, samples: List[Sample], batch_size: int, @@ -100,16 +99,15 @@ def evaluate_on_samples( sample_predictions_dict.append(sample_prediction) - # TODO: change tracker to parl logging - # if tracker is not None: - # # log the entire predictions - # tracker.log_predictions(epoch, split_name, sample_predictions_dict) - # # log the corpus level scores - # tracker.log_metrics(epoch, split_name, corpus_level_metrics) + if tracker is not None: + # log the entire predictions + tracker.log_predictions(epoch, split_name, sample_predictions_dict) + # log the corpus level scores + tracker.log_metrics(epoch, split_name, corpus_level_metrics) def generate_text( - policy: BasePolicy, + policy, tokenizer: AutoTokenizer, samples: List[Sample], max_prompt_length: int, diff --git a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py index a83cd2284..b54644216 100644 --- a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py @@ -3478,6 +3478,7 @@ def top_k_top_p_filtering( return logits + def override_generation_routines(cls): bases = list(cls.__bases__) for base_ix in range(len(bases)): diff --git a/benchmark/torch/RL4LMs/utils/reward_util.py b/benchmark/torch/RL4LMs/utils/reward_util.py index b62443a87..7ab5da1bc 100644 --- a/benchmark/torch/RL4LMs/utils/reward_util.py +++ b/benchmark/torch/RL4LMs/utils/reward_util.py @@ -4,7 +4,7 @@ from datasets import load_metric from .data_wrapper import Observation from transformers import AutoModelForSequenceClassification, AutoTokenizer -from .metric_util import ( +from benchmark.torch.RL4LMs.metrics import ( MeteorMetric, BERTScoreMetric, BLEUMetric, @@ -96,7 +96,7 @@ class MeteorRewardFunction(RewardFunction): def __init__(self, shaping_fn: str = None) -> None: super().__init__() self._metric = MeteorMetric() - from rl4lms.envs.text_generation.registry import RewardFunctionRegistry + from benchmark.torch.RL4LMs.registry import RewardFunctionRegistry self._shaping_fn = ( RewardFunctionRegistry.get(shaping_fn, {}) @@ -136,7 +136,7 @@ def __init__( super().__init__() self._metric = load_metric("rouge") self._rouge_type = rouge_type - from rl4lms.envs.text_generation.registry import RewardFunctionRegistry + from benchmark.torch.RL4LMs.registry import RewardFunctionRegistry self._shaping_fn = ( RewardFunctionRegistry.get(shaping_fn, {}) @@ -178,7 +178,7 @@ class RougeCombined(RewardFunction): def __init__(self, shaping_fn: str = None) -> None: super().__init__() self._metric = load_metric("rouge") - from rl4lms.envs.text_generation.registry import RewardFunctionRegistry + from benchmark.torch.RL4LMs.registry import RewardFunctionRegistry self._shaping_fn = ( RewardFunctionRegistry.get(shaping_fn, {}) @@ -360,30 +360,30 @@ def __call__( return 0 -class PARENTRewardFunction(RewardFunction): - """ - PARENT F1 score as the reward - """ - - def __init__(self) -> None: - super().__init__() - self._metric = ParentToTTo() - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - generated_texts = [next_observation.context_text] - meta_infos = [meta_info] - scores = self._metric.compute(None, generated_texts, None, meta_infos) - reward = scores["table_to_text/parent_overall_f_score"][0][0] - return reward - return 0 +# class PARENTRewardFunction(RewardFunction): +# """ +# PARENT F1 score as the reward +# """ +# +# def __init__(self) -> None: +# super().__init__() +# self._metric = ParentToTTo() +# +# def __call__( +# self, +# current_observation: Observation, +# action: int, +# next_observation: Observation, +# done: bool, +# meta_info: Dict[str, Any] = None, +# ) -> float: +# if done: +# generated_texts = [next_observation.context_text] +# meta_infos = [meta_info] +# scores = self._metric.compute(None, generated_texts, None, meta_infos) +# reward = scores["table_to_text/parent_overall_f_score"][0][0] +# return reward +# return 0 class RougeLMaxRewardFunction(RewardFunction): @@ -421,8 +421,8 @@ def __call__( reward_fn = MeteorRewardFunction() print(reward_fn(None, None, observation, True)) - reward_fn = chrF() - print(reward_fn(None, None, observation, True)) + # reward_fn = chrF() + # print(reward_fn(None, None, observation, True)) reward_fn = RougeCombined() print(reward_fn(None, None, observation, True)) diff --git a/benchmark/torch/RL4LMs/utils/warm_start.py b/benchmark/torch/RL4LMs/utils/warm_start.py index efa12ba14..d5c557d0e 100644 --- a/benchmark/torch/RL4LMs/utils/warm_start.py +++ b/benchmark/torch/RL4LMs/utils/warm_start.py @@ -1,6 +1,6 @@ import os from typing import Any, Dict - +from .tracker import Tracker import torch # from rl4lms.envs.text_generation.logging_utils import Tracker @@ -79,69 +79,69 @@ def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: # state_dict["replay_buffer"]) # # -# ################## Trainer Warm Start Mixins####################################### -# class TrainerWarmStartMixin: -# def _get_recent_ckpt_path(self, tracker: Tracker): -# try: -# checkpoints = os.listdir(tracker.checkpoint_base_path) -# except: -# os.makedirs(tracker.checkpoint_base_path) -# checkpoints = os.listdir(tracker.checkpoint_base_path) -# -# if len(checkpoints) == 0: -# return None, None -# -# sorted_ckpts = sorted(checkpoints, reverse=True, -# key=lambda ckpt: int(ckpt.split("_")[1])) -# recent_ckpt = sorted_ckpts[0] -# recent_ckpt_id = int(recent_ckpt.split("_")[1]) -# -# recent_ckpt_path = os.path.join( -# tracker.checkpoint_base_path, f"checkpoint_{recent_ckpt_id}") -# return recent_ckpt_path, recent_ckpt_id -# -# def load_trainer_state(self, tracker: Tracker): -# recent_ckpt_path, _ = self._get_recent_ckpt_path(tracker) -# state_dict = None -# try: -# if recent_ckpt_path is not None: -# state_dict = torch.load( -# recent_ckpt_path, map_location=torch.device("cuda")) -# tracker.log_info("Model checkpoint found - Warm starting") -# self._policy_state_dict = state_dict["policy_state"] -# self._alg_state_dict = state_dict["alg_state"] -# self._trainer_state = state_dict["trainer_state"] -# -# tracker.log_info( -# f"Loaded the current trainer state from: {self._trainer_state}") -# else: -# self._policy_state_dict = None -# self._alg_state_dict = None -# self._trainer_state = { -# "current_iter": 0, -# } -# except Exception as e: -# tracker.log_info(f"Exception while doing warm start {e}") -# tracker.log_info( -# f"Checkpoint may be corrupted...skipping warm start") -# self._policy_state_dict = None -# self._alg_state_dict = None -# self._trainer_state = { -# "current_iter": 0, -# } -# -# def save_trainer_state(self, tracker: Tracker, -# policy: LMActorCriticPolicy, -# trainer_state: Dict[str, Any]): -# full_state = { -# "alg_state": self._alg.get_state_dict(), -# "policy_state": policy.get_state_dict(), -# "trainer_state": trainer_state -# } -# _, recent_ckpt_id = self._get_recent_ckpt_path(tracker) -# -# # hot fix - just to save only the last checkpoint (overwrite) -# new_ckpt_id = 0 if recent_ckpt_id is None else recent_ckpt_id + 1 -# new_ckpt_path = os.path.join( -# tracker.checkpoint_base_path, f"checkpoint_{new_ckpt_id}") -# torch.save(full_state, new_ckpt_path, pickle_protocol=4) +################## Trainer Warm Start Mixins####################################### +class TrainerWarmStartMixin: + def _get_recent_ckpt_path(self, tracker: Tracker): + try: + checkpoints = os.listdir(tracker.checkpoint_base_path) + except: + os.makedirs(tracker.checkpoint_base_path) + checkpoints = os.listdir(tracker.checkpoint_base_path) + + if len(checkpoints) == 0: + return None, None + + sorted_ckpts = sorted(checkpoints, reverse=True, + key=lambda ckpt: int(ckpt.split("_")[1])) + recent_ckpt = sorted_ckpts[0] + recent_ckpt_id = int(recent_ckpt.split("_")[1]) + + recent_ckpt_path = os.path.join( + tracker.checkpoint_base_path, f"checkpoint_{recent_ckpt_id}") + return recent_ckpt_path, recent_ckpt_id + + def load_trainer_state(self, tracker: Tracker): + recent_ckpt_path, _ = self._get_recent_ckpt_path(tracker) + state_dict = None + try: + if recent_ckpt_path is not None: + state_dict = torch.load( + recent_ckpt_path, map_location=torch.device("cuda")) + tracker.log_info("Model checkpoint found - Warm starting") + self._policy_state_dict = state_dict["policy_state"] + self._alg_state_dict = state_dict["alg_state"] + self._trainer_state = state_dict["trainer_state"] + + tracker.log_info( + f"Loaded the current trainer state from: {self._trainer_state}") + else: + self._policy_state_dict = None + self._alg_state_dict = None + self._trainer_state = { + "current_iter": 0, + } + except Exception as e: + tracker.log_info(f"Exception while doing warm start {e}") + tracker.log_info( + f"Checkpoint may be corrupted...skipping warm start") + self._policy_state_dict = None + self._alg_state_dict = None + self._trainer_state = { + "current_iter": 0, + } + + def save_trainer_state(self, tracker: Tracker, + policy, + trainer_state: Dict[str, Any]): + full_state = { + "alg_state": self._agent.alg.get_state_dict(), + "policy_state": policy.get_state_dict(), + "trainer_state": trainer_state + } + _, recent_ckpt_id = self._get_recent_ckpt_path(tracker) + + # hot fix - just to save only the last checkpoint (overwrite) + new_ckpt_id = 0 if recent_ckpt_id is None else recent_ckpt_id + 1 + new_ckpt_path = os.path.join( + tracker.checkpoint_base_path, f"checkpoint_{new_ckpt_id}") + torch.save(full_state, new_ckpt_path, pickle_protocol=4) From f816028ada98bac5e5086df42fb727cca045acde Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 6 Mar 2023 10:23:42 +0800 Subject: [PATCH 03/34] benchmark of RL4LMs v0.1 --- benchmark/torch/RL4LMs/algorithms/ppo.py | 5 +- .../RL4LMs/configs/summarization/t5_ppo.yml | 6 +- benchmark/torch/RL4LMs/env/vec_env.py | 7 + benchmark/torch/RL4LMs/models/base_model.py | 7 +- .../torch/RL4LMs/models/seq2seq_model.py | 4 +- benchmark/torch/RL4LMs/registry.py | 4 + .../summarization/rl4lms_summa_agent.py | 45 +++-- benchmark/torch/RL4LMs/train.py | 31 ++-- benchmark/torch/RL4LMs/trainers.py | 69 +++++--- benchmark/torch/RL4LMs/utils/__init__.py | 4 +- .../torch/RL4LMs/utils/evaluation_util.py | 18 +- benchmark/torch/RL4LMs/utils/tracker.py | 11 -- benchmark/torch/RL4LMs/utils/warm_start.py | 157 +++++++----------- 13 files changed, 175 insertions(+), 193 deletions(-) diff --git a/benchmark/torch/RL4LMs/algorithms/ppo.py b/benchmark/torch/RL4LMs/algorithms/ppo.py index 7af0c9114..fde623c1b 100644 --- a/benchmark/torch/RL4LMs/algorithms/ppo.py +++ b/benchmark/torch/RL4LMs/algorithms/ppo.py @@ -1,5 +1,4 @@ import parl -from benchmark.torch.RL4LMs.utils import Tracker from benchmark.torch.RL4LMs.utils import Schedule from typing import Union, Optional, Dict, Any import torch @@ -13,7 +12,6 @@ class RL4LMPPO(parl.Algorithm): def __init__(self, model: parl.Model, - tracker: Tracker, learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, batch_size: int = 64, @@ -31,7 +29,6 @@ def __init__(self, _init_setup_model: bool = True, ): super(RL4LMPPO, self).__init__(model=model) - self.tracker = tracker self.learning_rate = learning_rate self.n_steps = n_steps self.batch_size = batch_size @@ -49,7 +46,7 @@ def __init__(self, def learn(self, rollout_buffer, log_info): entropy_losses = log_info["entropy_losses"] - pg_losses = log_info["entropy_losses"] + pg_losses = log_info["pg_losses"] value_losses = log_info["value_losses"] clip_fractions = log_info["clip_fractions"] approx_kl_divs = log_info["approx_kl_divs"] diff --git a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml index a907057d0..2707b4431 100644 --- a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml @@ -37,8 +37,8 @@ alg: #####CHNAGE FOR DEBUG######## n_steps: 5 #####CHANGE FOR DEBUG######## - batch_size: 16 - verbose: 1 + batch_size: 32 +# verbose: 1 learning_rate: 0.000002 n_epochs: 5 ent_coef: 0.0 @@ -46,7 +46,7 @@ alg: coeff: 0.001 target_kl: 0.2 model: - id: seq2seq_lm_actor_critic_policy + id: seq2seq_lm_actor_critic_model args: model_name: t5-base apply_model_parallel: True diff --git a/benchmark/torch/RL4LMs/env/vec_env.py b/benchmark/torch/RL4LMs/env/vec_env.py index bd94490aa..717a80fd9 100644 --- a/benchmark/torch/RL4LMs/env/vec_env.py +++ b/benchmark/torch/RL4LMs/env/vec_env.py @@ -163,6 +163,13 @@ def close(self) -> None: process.join() self.closed = True + def get_attr(self, attr_name: str, indices) -> List[Any]: + """Return attribute from vectorized environment (see base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("get_attr", attr_name)) + return [remote.recv() for remote in target_remotes] + def set_attr(self, attr_name: str, value: Any, indices = None) -> None: """Set attribute inside vectorized environments (see base class).""" target_remotes = self._get_target_remotes(indices) diff --git a/benchmark/torch/RL4LMs/models/base_model.py b/benchmark/torch/RL4LMs/models/base_model.py index 5325182df..9d1b2c768 100644 --- a/benchmark/torch/RL4LMs/models/base_model.py +++ b/benchmark/torch/RL4LMs/models/base_model.py @@ -31,7 +31,8 @@ def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None,): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + device=None): super().__init__() if optimizer_kwargs is None: optimizer_kwargs = {} @@ -42,6 +43,7 @@ def __init__(self, self.optimizer_class = optimizer_class self.optimizer_kwargs = optimizer_kwargs self.optimizer = None + self.device = device @abstractmethod def forward(self, *args, **kwargs): @@ -176,6 +178,7 @@ def __init__( optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, generation_kwargs: Dict[str, Any] = {}, prompt_truncation_side: str = "left", + device=None ): """ @@ -190,7 +193,7 @@ def __init__( generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}. prompt_truncation_side (str, optional): truncation side for prompt text. Defaults to "left". """ - super().__init__(observation_space, action_space) + super().__init__(observation_space, action_space, device=device) self._action_space = action_space self._apply_model_parallel = apply_model_parallel self._build_model_heads(model_name) diff --git a/benchmark/torch/RL4LMs/models/seq2seq_model.py b/benchmark/torch/RL4LMs/models/seq2seq_model.py index 447b99494..a2a39c65a 100644 --- a/benchmark/torch/RL4LMs/models/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/models/seq2seq_model.py @@ -34,6 +34,7 @@ def __init__( generation_kwargs: Dict[str, Any] = {}, prompt_truncation_side: str = "left", state_dict: Dict[str, Any] = None, + device: torch.DeviceObjType = None, ): super().__init__( observation_space, @@ -45,8 +46,9 @@ def __init__( optimizer_class, generation_kwargs, prompt_truncation_side, + device=device ) - self.load_from_dict(state_dict) + # self.load_from_dict(state_dict) def _build_model_heads(self, model_name: str): self._policy_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) diff --git a/benchmark/torch/RL4LMs/registry.py b/benchmark/torch/RL4LMs/registry.py index 1a468515f..3456abe71 100644 --- a/benchmark/torch/RL4LMs/registry.py +++ b/benchmark/torch/RL4LMs/registry.py @@ -6,6 +6,7 @@ from benchmark.torch.RL4LMs.utils import TextGenPool, CNNDailyMail # from rl4lms.envs.text_generation.alg_wrappers import wrap_onpolicy_alg +from parl.utils import logger from benchmark.torch.RL4LMs.metrics import ( BaseMetric, @@ -50,6 +51,7 @@ class DataPoolRegistry: @classmethod def get(cls, datapool_id: str, kwargs: Dict[str, Any]) -> TextGenPool: + logger.info(f"loading split of dataset: {datapool_id} -- {kwargs['split']}") datapool_cls = cls._registry[datapool_id] datapool = datapool_cls.prepare(**kwargs) return datapool @@ -75,6 +77,7 @@ class RewardFunctionRegistry: @classmethod def get(cls, reward_fn_id: str, kwargs: Dict[str, Any]) -> RewardFunction: + logger.info(f"loading reward function: {reward_fn_id}") reward_cls = cls._registry[reward_fn_id] reward_fn = reward_cls(**kwargs) return reward_fn @@ -106,6 +109,7 @@ class MetricRegistry: @classmethod def get(cls, metric_id: str, kwargs: Dict[str, Any]) -> BaseMetric: + logger.info(f"loading metric: {metric_id}") metric_cls = cls._registry[metric_id] metric = metric_cls(**kwargs) return metric diff --git a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py index d07200c02..952196b1c 100644 --- a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py +++ b/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py @@ -4,7 +4,8 @@ from typing import List import torch from benchmark.torch.RL4LMs.utils import TransitionInfo,\ - RewardFunction, Tracker + RewardFunction +from parl.utils import logger def compute_batched_rewards( @@ -59,14 +60,12 @@ class RL4LMsSummaAgent(parl.Agent): def __init__(self, algorithm, alg_config, - tracker: Tracker, norm_reward: bool = False, ): super(RL4LMsSummaAgent, self).__init__(algorithm) self.dataset = None self.config = alg_config - self.n_epochs = alg_config["n_epochs"] - self._tracker = tracker + self.n_epochs = alg_config["args"]["n_epochs"] self._norm_reward = norm_reward self._n_updates = 0 @@ -80,7 +79,7 @@ def learn(self, rollout_buffer): approx_kl_divs = [] log_info = { "entropy_losses": entropy_losses, - "pg_losses": entropy_losses, + "pg_losses": pg_losses, "value_losses": value_losses, "clip_fractions": clip_fractions, "approx_kl_divs": approx_kl_divs @@ -103,39 +102,37 @@ def learn(self, rollout_buffer): rollout_buffer.values.flatten(), rollout_buffer.returns.flatten()) # Logs - print("train/entropy_loss", np.mean(entropy_losses)) - print("train/policy_gradient_loss", np.mean(pg_losses)) - print("train/value_loss", np.mean(value_losses)) - print("train/approx_kl", np.mean(approx_kl_divs)) - print("train/clip_fraction", np.mean(clip_fractions)) - print("train/loss", loss.item()) - print("train/explained_variance", explained_var) - # self.logger.record("train/entropy_loss", np.mean(entropy_losses)) - # self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) - # self.logger.record("train/value_loss", np.mean(value_losses)) - # self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) - # self.logger.record("train/clip_fraction", np.mean(clip_fractions)) - # self.logger.record("train/loss", loss.item()) - # self.logger.record("train/explained_variance", explained_var) + train_info = { + "train/entropy_loss": np.mean(entropy_losses), + "train/policy_gradient_loss": np.mean(pg_losses), + "train/value_loss": np.mean(value_losses), + "train/approx_kl": np.mean(approx_kl_divs), + "train/clip_fraction": np.mean(clip_fractions), + "train/loss": loss.item(), + "train/explained_variance": explained_var + } + if hasattr(self.alg.model, "log_std"): # self.logger.record( # "train/std", torch.exp(self.policy.log_std).mean().item()) - print("train/std", torch.exp(self.alg.model.log_std).mean().item()) + train_info["train/std"] = torch.exp(self.alg.model.log_std).mean().item() # self.logger.record("train/n_updates", # self._n_updates, exclude="tensorboard") # self.logger.record("train/clip_range", clip_range) - print("train/n_updates", self._n_updates) - print("train/clip_range", self.alg.clip_range) + train_info["train/n_updates"] = self._n_updates + train_info["train/clip_range"] = self.alg.clip_range - train_info = { + logger.info(train_info) + + ppo_train_info = { "ppo/entropy_loss": np.mean(entropy_losses).item(), "ppo/policy_gradient_loss": np.mean(pg_losses).item(), "ppo/value_loss": np.mean(value_losses).item(), "ppo/approx_kl": np.mean(approx_kl_divs).item(), } - self._tracker.log_training_infos(train_info) + logger.info(ppo_train_info) # for k, v in train_info.items(): # print(f"{k}: {v}") diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 7ab7a2563..6217348f6 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -1,10 +1,11 @@ import os +import sys from argparse import ArgumentParser - +import datetime import yaml import collections from trainers import OnPolicyTrainer -from utils import Tracker +from parl.utils import logger def recursive_dict_update(d, u): @@ -19,14 +20,14 @@ def recursive_dict_update(d, u): def main(config): # load tracker - tracker = Tracker( - config["base_path_to_store_results"], - config, - config["project_name"], - config["experiment_name"], - config["entity_name"], - False, - ) + # tracker = Tracker( + # config["base_path_to_store_results"], + # config, + # config["project_name"], + # config["experiment_name"], + # config["entity_name"], + # False, + # ) # instantiate the trainer here # TODO: currently only complete ppo @@ -38,7 +39,6 @@ def main(config): env_config=config["env"], on_policy_alg_config=config["alg"], train_eval_config=config["train_evaluation"], - tracker=tracker, ) else: raise NotImplementedError @@ -68,7 +68,7 @@ def main(config): default=os.getcwd(), ) parser.add_argument( - "--entity_name", type=str, help="entity name", default=None + "--entity_name", type=str, help="entity name", default="summarization" ) args = parser.parse_args() @@ -77,6 +77,11 @@ def main(config): config = yaml.safe_load(fp) recursive_dict_update(config, vars(args)) - + log_dir = f"./{args.project_name}/{args.experiment_name}/{args.entity_name}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + logger.set_dir(log_dir) + config["logging_dir"] = log_dir + config["sys_arg"] = sys.argv + logger.info(config) + logger.set_level("DEBUG") main(config) diff --git a/benchmark/torch/RL4LMs/trainers.py b/benchmark/torch/RL4LMs/trainers.py index 7db9dbf02..585aa4a4b 100644 --- a/benchmark/torch/RL4LMs/trainers.py +++ b/benchmark/torch/RL4LMs/trainers.py @@ -1,11 +1,12 @@ import os +import time from functools import partial from typing import Any, Dict, List import numpy as np -from benchmark.torch.RL4LMs.utils import Sample, Tracker, RewardFunction,\ - evaluate_on_samples, TrainerWarmStartMixin,\ +from benchmark.torch.RL4LMs.utils import Sample, RewardFunction,\ + evaluate_on_samples,\ KLController, RolloutBuffer, DictRolloutBuffer, MaskableDictRolloutBuffer,\ TransitionInfo, TensorDict, RefPolicyOutput, ValueOutput, PolicyOutput from benchmark.torch.RL4LMs.registry import DataPoolRegistry, MetricRegistry, RewardFunctionRegistry, \ @@ -23,8 +24,10 @@ from benchmark.torch.RL4LMs.summarization import RL4LMsSummaAgent from benchmark.torch.RL4LMs.algorithms import RL4LMPPO import torch +from parl.utils import logger def build_tokenizer(tokenizer_config: Dict[str, Any]): + logger.info(f"loading tokenizer of [{tokenizer_config['model_name']}] model") tokenizer = AutoTokenizer.from_pretrained( tokenizer_config["model_name"]) if tokenizer.pad_token is None and tokenizer_config.get("pad_token_as_eos_token", True): @@ -89,8 +92,8 @@ def build_env(env_config: Dict[str, Any], def build_agent(alg_config: Dict[str, Any], env: LocalParallelVecEnv, - tracker: Tracker, model_state: Dict[str, Any] = None, # TODO: save model checkpoint + device = None, alg_state: Dict[str, Any] = None # TODO: save alg checkpoint ): model_config = alg_config["model"] @@ -103,15 +106,17 @@ def build_agent(alg_config: Dict[str, Any], rl4lms_model = model_cls( observation_space = env.observation_space, action_space= env.action_space, + device=device, **model_args ) rl4lm_alg_cls = alg_cls( model=rl4lms_model, + device=device, **alg_config.get("args") ) - rl4lm_agent = RL4LMsSummaAgent(rl4lm_alg_cls, alg_config, tracker) + rl4lm_agent = RL4LMsSummaAgent(rl4lm_alg_cls, alg_config) return rl4lm_agent @@ -133,7 +138,7 @@ def unpack_observations(obs_tensor, n_envs: int): return unpacked_obs -class OnPolicyTrainer(TrainerWarmStartMixin): +class OnPolicyTrainer(): """ A generic trainer for training LMs with onpolicy algorithms from SB3 """ @@ -145,7 +150,6 @@ def __init__(self, env_config: Dict[str, Any], on_policy_alg_config: Dict[str, Any], train_eval_config: Dict[str, Any], - tracker: Tracker = None, experiment_name: str = '' ): self._tokenizer_config = tokenizer_config @@ -154,7 +158,6 @@ def __init__(self, self._env_config = env_config self._on_policy_alg_config = on_policy_alg_config self._train_eval_config = train_eval_config - self._tracker = tracker self._experiment_name = experiment_name self._agent = None self._env = None @@ -169,7 +172,7 @@ def __init__(self, def _setup(self): # load trainer state from available previous checkpoint if available - self.load_trainer_state(self._tracker) + # self.load_trainer_state(self._tracker) # build components self._tokenizer = build_tokenizer(self._tokenizer_config) @@ -182,7 +185,7 @@ def _setup(self): self._agent = build_agent(self._on_policy_alg_config, - self._env, self._tracker) + self._env, device=self.device) self._rollout_buffer = MaskableDictRolloutBuffer( buffer_size=self._agent.alg.n_steps * self._env.num_envs, @@ -196,7 +199,7 @@ def _setup(self): self._kl_controller = KLController( self._on_policy_alg_config["kl_div"]["coeff"], - self._on_policy_alg_config["kl_div"].get("norm_reward", False)) + self._on_policy_alg_config["kl_div"].get("target_kl", None)) # extract train params self._max_episode_length = self._env_config["args"]["max_episode_length"] @@ -221,18 +224,22 @@ def _evaluate_on_datapools(self, epoch: int, metrics=self._metrics, epoch=epoch, split_name=split, - tracker=self._tracker, gen_kwargs=self._eval_gen_kwargs) def train_and_eval(self): # evaluate on val and test set before fine-tuning once - iter_start = self._trainer_state["current_iter"] + # iter_start = self._trainer_state["current_iter"] + iter_start = 0 self._evaluate_on_datapools(epoch=iter_start) # train for given number of iters for epoch in range(iter_start, self._n_iters): + print("========== BEGIN ==========") + print(f"outer epoch: {epoch} / {self._n_iters - 1}") + print("========== BEGIN ==========") + outer_start_time = time.time() # current state - self._trainer_state["current_iter"] = epoch + # self._trainer_state["current_iter"] = epoch self._num_timesteps = 0 @@ -243,21 +250,29 @@ def train_and_eval(self): self._agent.learn(self._rollout_buffer) # save the policy checkpoint - if (epoch + 1) % self._train_eval_config.get("save_every", 20) == 0: - self.save_trainer_state( - self._tracker, self._alg.policy, self._trainer_state) + # if (epoch + 1) % self._train_eval_config.get("save_every", 20) == 0: + # self.save_trainer_state( + # self._tracker, self._alg.policy, self._trainer_state) # evaluate on val set in the given intervals if (epoch + 1) % self._train_eval_config["eval_every"] == 0: self._evaluate_on_datapools(epoch=epoch, splits=["val"]) + outer_end_time = time.time() + print("========== END ==========") + print(f"outer epoch: {epoch} / {self._n_iters - 1}") + print(f"time used: {outer_end_time - outer_start_time} second(s), left time:" + f" {1.0 * (outer_end_time - outer_start_time) * (self._n_iters - epoch - 1) / 60 / 60} hour(s)") + print("========== END ==========") + + # finally evaluate on val and test samples self._evaluate_on_datapools(epoch=epoch) - # save model here - we save only the language model - if self._tracker is not None: - self._tracker.save_auto_model( - self._alg.policy.get_language_model()) + # # save model here - we save only the language model + # if self._tracker is not None: + # self._tracker.save_auto_model( + # self._alg.policy.get_language_model()) def get_policy_kwargs( @@ -294,8 +309,8 @@ def generate_batch( # generate text using the model obs_tensor = dict_to_tensor(current_obs, self.device) - generation_inputs = self._agent.model.get_inputs_for_generation(obs_tensor) - gen_output = self._agent.model.generate( + generation_inputs = self._agent.alg.model.get_inputs_for_generation(obs_tensor) + gen_output = self._agent.alg.model.generate( input_ids=generation_inputs.inputs, attention_mask=generation_inputs.attention_masks, tokenizer=tokenizer, @@ -329,7 +344,7 @@ def generate_batch( obs_tensor, actions_tensor, policy_past_state, action_mask ) - policy_outputs: PolicyOutput = self.policy.forward_policy( + policy_outputs: PolicyOutput = self._agent.alg.model.forward_policy( **policy_kwargs ) raw_log_probs, log_probs, policy_past_state = ( @@ -349,7 +364,7 @@ def generate_batch( ), "Infinite values in log probs" # get values - value_outputs: ValueOutput = self.policy.forward_value( + value_outputs: ValueOutput = self._agent.alg.model.forward_value( obs_tensor, value_past_state ) values, value_past_state = ( @@ -359,7 +374,7 @@ def generate_batch( # get reference log probs ref_policy_outputs: RefPolicyOutput = ( - self.policy.get_log_probs_ref_model( + self._agent.alg.model.get_log_probs_ref_model( obs_tensor, actions_tensor, ref_past_state ) ) @@ -494,10 +509,10 @@ def collect_rollouts( rollout_buffer: RolloutBuffer, ) -> bool: # max episode steps - max_steps = env.unwrapped.get_attr("max_steps", [0])[0] + max_steps = env.get_attr("max_steps", [0])[0] # get tokenizer - tokenizer = env.unwrapped.get_attr("tokenizer", [0]) + tokenizer = env.get_attr("tokenizer", [0]) tokenizer = tokenizer[0] # Switch to eval mode diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py index 9e2ea3014..1905032f0 100644 --- a/benchmark/torch/RL4LMs/utils/__init__.py +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -5,7 +5,7 @@ from .huggingface_generation_util import override_generation_routines -from .warm_start import ActorCriticWarmStartMixin, OnPolicyWarmStartMixin, TrainerWarmStartMixin +from .warm_start import ActorCriticWarmStartMixin, OnPolicyWarmStartMixin from .type_wrapper import TensorDict, Schedule @@ -18,8 +18,6 @@ from .kl_controller import KLController -from .tracker import Tracker - from .evaluation_util import evaluate_on_samples from .data_pool import TextGenPool, CNNDailyMail diff --git a/benchmark/torch/RL4LMs/utils/evaluation_util.py b/benchmark/torch/RL4LMs/utils/evaluation_util.py index c9f7319d6..0a86b0f09 100644 --- a/benchmark/torch/RL4LMs/utils/evaluation_util.py +++ b/benchmark/torch/RL4LMs/utils/evaluation_util.py @@ -4,7 +4,8 @@ from transformers import AutoTokenizer from . import Sample -from .metric_util import BaseMetric +from benchmark.torch.RL4LMs.metrics import BaseMetric +from parl.utils import logger def get_batch(samples: List[Sample], batch_size: int): @@ -25,8 +26,6 @@ def evaluate_on_samples( metrics: List[BaseMetric], epoch: int, split_name: str, - # tracker: Tracker = None, - tracker = None, # TODO: change tracker to parl logging dt_control_token: str = "", gen_kwargs: Dict[str, Any] = None, ): @@ -99,11 +98,14 @@ def evaluate_on_samples( sample_predictions_dict.append(sample_prediction) - if tracker is not None: - # log the entire predictions - tracker.log_predictions(epoch, split_name, sample_predictions_dict) - # log the corpus level scores - tracker.log_metrics(epoch, split_name, corpus_level_metrics) + + metrics_dict_ = { + "epoch": epoch, + "metrics": corpus_level_metrics + } + + # logger + logger.info(f"{split_name} metrics: {metrics_dict_}") def generate_text( diff --git a/benchmark/torch/RL4LMs/utils/tracker.py b/benchmark/torch/RL4LMs/utils/tracker.py index 5c48855b7..203aa94ec 100644 --- a/benchmark/torch/RL4LMs/utils/tracker.py +++ b/benchmark/torch/RL4LMs/utils/tracker.py @@ -141,14 +141,3 @@ def log_info(self, msg: str): {"ep_len": 3, "ep_reward": 0.5}, ] - tracker = Tracker(base_path, run_config, "test_logs", "test_run", "T_1", False) - tracker.log_predictions(1, "val", predictions["1"]) - tracker.log_metrics(1, "val", metrics["1"]) - tracker.log_predictions(2, "val", predictions["2"]) - tracker.log_metrics(2, "val", metrics["2"]) - tracker.log_predictions(3, "val", predictions["3"]) - tracker.log_metrics(3, "val", metrics["3"]) - tracker.log_rollout_infos(rollout_infos[0]) - tracker.log_rollout_infos(rollout_infos[1]) - tracker.log_rollout_infos(rollout_infos[2]) - tracker.done() \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/warm_start.py b/benchmark/torch/RL4LMs/utils/warm_start.py index d5c557d0e..fa700230f 100644 --- a/benchmark/torch/RL4LMs/utils/warm_start.py +++ b/benchmark/torch/RL4LMs/utils/warm_start.py @@ -1,6 +1,5 @@ import os from typing import Any, Dict -from .tracker import Tracker import torch # from rl4lms.envs.text_generation.logging_utils import Tracker @@ -39,109 +38,73 @@ def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: self._kl_controller.load_from_state_dict( state_dict["kl_controller_state"]) -# ################## Policy Warm Start Mixins####################################### -# -# -# class ActorOnlyWarmStartMixin: -# def get_state_dict(self) -> Dict[str, Any]: -# state_dict = { -# "policy_model": self._policy_model.state_dict(), -# "optimizer": self.optimizer.state_dict() -# } -# return state_dict -# -# def load_from_dict(self, state_dict: dict = None): -# if state_dict is not None: -# self._policy_model.load_state_dict(state_dict["policy_model"]) -# self.optimizer.load_state_dict(state_dict["optimizer"]) -# + + # # +################## Trainer Warm Start Mixins####################################### +# class TrainerWarmStartMixin: +# def _get_recent_ckpt_path(self, tracker: Tracker): +# try: +# checkpoints = os.listdir(tracker.checkpoint_base_path) +# except: +# os.makedirs(tracker.checkpoint_base_path) +# checkpoints = os.listdir(tracker.checkpoint_base_path) # +# if len(checkpoints) == 0: +# return None, None # +# sorted_ckpts = sorted(checkpoints, reverse=True, +# key=lambda ckpt: int(ckpt.split("_")[1])) +# recent_ckpt = sorted_ckpts[0] +# recent_ckpt_id = int(recent_ckpt.split("_")[1]) # +# recent_ckpt_path = os.path.join( +# tracker.checkpoint_base_path, f"checkpoint_{recent_ckpt_id}") +# return recent_ckpt_path, recent_ckpt_id # -# ################## Algorithm Warm Start Mixins####################################### - +# def load_trainer_state(self, tracker: Tracker): +# recent_ckpt_path, _ = self._get_recent_ckpt_path(tracker) +# state_dict = None +# try: +# if recent_ckpt_path is not None: +# state_dict = torch.load( +# recent_ckpt_path, map_location=torch.device("cuda")) +# tracker.log_info("Model checkpoint found - Warm starting") +# self._policy_state_dict = state_dict["policy_state"] +# self._alg_state_dict = state_dict["alg_state"] +# self._trainer_state = state_dict["trainer_state"] # +# tracker.log_info( +# f"Loaded the current trainer state from: {self._trainer_state}") +# else: +# self._policy_state_dict = None +# self._alg_state_dict = None +# self._trainer_state = { +# "current_iter": 0, +# } +# except Exception as e: +# tracker.log_info(f"Exception while doing warm start {e}") +# tracker.log_info( +# f"Checkpoint may be corrupted...skipping warm start") +# self._policy_state_dict = None +# self._alg_state_dict = None +# self._trainer_state = { +# "current_iter": 0, +# } # -# class OffPolicyWarmStartMixin: -# def get_state_dict(self) -> Dict[str, Any]: -# # TBD: just buffer is sufficient? or is there something else? -# state_dict = { -# "replay_buffer": self.replay_buffer.get_state_dict(), +# def save_trainer_state(self, tracker: Tracker, +# policy, +# trainer_state: Dict[str, Any]): +# full_state = { +# "alg_state": self._agent.alg.get_state_dict(), +# "policy_state": policy.get_state_dict(), +# "trainer_state": trainer_state # } -# return state_dict -# -# def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: -# if state_dict is not None: -# self.replay_buffer.load_from_state_dict( -# state_dict["replay_buffer"]) +# _, recent_ckpt_id = self._get_recent_ckpt_path(tracker) # -# -################## Trainer Warm Start Mixins####################################### -class TrainerWarmStartMixin: - def _get_recent_ckpt_path(self, tracker: Tracker): - try: - checkpoints = os.listdir(tracker.checkpoint_base_path) - except: - os.makedirs(tracker.checkpoint_base_path) - checkpoints = os.listdir(tracker.checkpoint_base_path) - - if len(checkpoints) == 0: - return None, None - - sorted_ckpts = sorted(checkpoints, reverse=True, - key=lambda ckpt: int(ckpt.split("_")[1])) - recent_ckpt = sorted_ckpts[0] - recent_ckpt_id = int(recent_ckpt.split("_")[1]) - - recent_ckpt_path = os.path.join( - tracker.checkpoint_base_path, f"checkpoint_{recent_ckpt_id}") - return recent_ckpt_path, recent_ckpt_id - - def load_trainer_state(self, tracker: Tracker): - recent_ckpt_path, _ = self._get_recent_ckpt_path(tracker) - state_dict = None - try: - if recent_ckpt_path is not None: - state_dict = torch.load( - recent_ckpt_path, map_location=torch.device("cuda")) - tracker.log_info("Model checkpoint found - Warm starting") - self._policy_state_dict = state_dict["policy_state"] - self._alg_state_dict = state_dict["alg_state"] - self._trainer_state = state_dict["trainer_state"] - - tracker.log_info( - f"Loaded the current trainer state from: {self._trainer_state}") - else: - self._policy_state_dict = None - self._alg_state_dict = None - self._trainer_state = { - "current_iter": 0, - } - except Exception as e: - tracker.log_info(f"Exception while doing warm start {e}") - tracker.log_info( - f"Checkpoint may be corrupted...skipping warm start") - self._policy_state_dict = None - self._alg_state_dict = None - self._trainer_state = { - "current_iter": 0, - } - - def save_trainer_state(self, tracker: Tracker, - policy, - trainer_state: Dict[str, Any]): - full_state = { - "alg_state": self._agent.alg.get_state_dict(), - "policy_state": policy.get_state_dict(), - "trainer_state": trainer_state - } - _, recent_ckpt_id = self._get_recent_ckpt_path(tracker) - - # hot fix - just to save only the last checkpoint (overwrite) - new_ckpt_id = 0 if recent_ckpt_id is None else recent_ckpt_id + 1 - new_ckpt_path = os.path.join( - tracker.checkpoint_base_path, f"checkpoint_{new_ckpt_id}") - torch.save(full_state, new_ckpt_path, pickle_protocol=4) +# # hot fix - just to save only the last checkpoint (overwrite) +# new_ckpt_id = 0 if recent_ckpt_id is None else recent_ckpt_id + 1 +# new_ckpt_path = os.path.join( +# tracker.checkpoint_base_path, f"checkpoint_{new_ckpt_id}") +# torch.save(full_state, new_ckpt_path, pickle_protocol=4) From 029373445c26e8f67fac5c5f4b4df331d5641a58 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 6 Mar 2023 15:21:16 +0800 Subject: [PATCH 04/34] fix pg reward bug, remove no use warmstartup --- benchmark/torch/RL4LMs/agents/__init__.py | 1 + .../rl4lms_summa_agent.py | 94 +- benchmark/torch/RL4LMs/algorithms/ppo.py | 50 +- .../RL4LMs/configs/summarization/t5_ppo.yml | 8 +- .../torch/RL4LMs/models/seq2seq_model.py | 6 +- benchmark/torch/RL4LMs/registry.py | 2 +- .../torch/RL4LMs/summarization/__init__.py | 1 - benchmark/torch/RL4LMs/train.py | 10 - benchmark/torch/RL4LMs/trainers.py | 32 +- benchmark/torch/RL4LMs/utils/__init__.py | 2 - .../utils/huggingface_generation_util.py | 1964 +---------------- benchmark/torch/RL4LMs/utils/kl_controller.py | 11 - benchmark/torch/RL4LMs/utils/tracker.py | 143 -- benchmark/torch/RL4LMs/utils/warm_start.py | 110 - 14 files changed, 175 insertions(+), 2259 deletions(-) rename benchmark/torch/RL4LMs/{summarization => agents}/rl4lms_summa_agent.py (70%) delete mode 100644 benchmark/torch/RL4LMs/summarization/__init__.py delete mode 100644 benchmark/torch/RL4LMs/utils/tracker.py delete mode 100644 benchmark/torch/RL4LMs/utils/warm_start.py diff --git a/benchmark/torch/RL4LMs/agents/__init__.py b/benchmark/torch/RL4LMs/agents/__init__.py index e69de29bb..b3361187c 100644 --- a/benchmark/torch/RL4LMs/agents/__init__.py +++ b/benchmark/torch/RL4LMs/agents/__init__.py @@ -0,0 +1 @@ +from .rl4lms_summa_agent import RL4LMsSummaAgent \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py b/benchmark/torch/RL4LMs/agents/rl4lms_summa_agent.py similarity index 70% rename from benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py rename to benchmark/torch/RL4LMs/agents/rl4lms_summa_agent.py index 952196b1c..e873c54c4 100644 --- a/benchmark/torch/RL4LMs/summarization/rl4lms_summa_agent.py +++ b/benchmark/torch/RL4LMs/agents/rl4lms_summa_agent.py @@ -3,44 +3,8 @@ from typing import List import torch -from benchmark.torch.RL4LMs.utils import TransitionInfo,\ - RewardFunction from parl.utils import logger - -def compute_batched_rewards( - episode_wise_transitions: List[List[TransitionInfo]], reward_fn: RewardFunction -): - # first collect all the prompts, ref and gen texts - prompts = [] - reference_texts = [] - generated_texts = [] - is_dones = [] - indices = [] - meta_infos = [] - for env_ix, transitions in enumerate(episode_wise_transitions): - for trans_ix, transition in enumerate(transitions): - done = transition.done - info = transition.info - prompts.append(info["prompt_text"]) - reference_texts.append(info["reference_text"]) - generated_texts.append(info["output"]) - is_dones.append(done) - meta_infos.append(info["meta_info"]) - indices.append((env_ix, trans_ix)) - - # compute rewards all at once - rewards = reward_fn(prompts, generated_texts, reference_texts, is_dones, meta_infos) - # rewards = rewards.numpy().flatten() - - # override the rewards in transitions - for (env_ix, trans_ix), reward in zip(indices, rewards): - episode_wise_transitions[env_ix][trans_ix].task_reward = reward - episode_wise_transitions[env_ix][trans_ix].total_reward = ( - reward + episode_wise_transitions[env_ix][trans_ix].kl_reward - ) - - def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: """ Computes fraction of variance that ypred explains about y. @@ -69,9 +33,6 @@ def __init__(self, self._norm_reward = norm_reward self._n_updates = 0 - - - def learn(self, rollout_buffer): entropy_losses = [] pg_losses, value_losses = [], [] @@ -85,7 +46,6 @@ def learn(self, rollout_buffer): "approx_kl_divs": approx_kl_divs } - continue_training = True loss = torch.tensor(0.0) # train for n_epochs epochs @@ -133,8 +93,11 @@ def learn(self, rollout_buffer): } logger.info(ppo_train_info) - # for k, v in train_info.items(): - # print(f"{k}: {v}") + + + def get_inputs_for_generation(self, obs_tensor): + return self.alg.model.get_inputs_for_generation(obs_tensor) + def predict(self, *args, **kwargs): pass @@ -142,4 +105,51 @@ def predict(self, *args, **kwargs): def sample(self, *args, **kwargs): pass + def forward_value( + self, + obs, + past_model_kwargs = None, + ): + return self.alg.forward_value(obs, past_model_kwargs) + + def forward_policy( + self, + obs, + actions, + past_model_kwargs = None, + ): + return self.alg.forward_policy( + obs = obs, + actions = actions, + past_model_kwargs = past_model_kwargs, + ) + + + def get_log_probs_ref_model( + self, + obs, + action, + model_kwarpast_model_kwargsgs = None, + ): + return self.alg.get_log_probs_ref_model(obs, action, model_kwarpast_model_kwargsgs) + + def generate( + self, + tokenizer, + texts = None, + max_prompt_length = None, + input_ids = None, + attention_mask = None, + gen_kwargs = None, + ): + return self.alg.generate( + input_ids=input_ids, + attention_mask=attention_mask, + tokenizer=tokenizer, + texts=texts, + max_prompt_length=max_prompt_length, + gen_kwargs=gen_kwargs + ) + def eval_mode(self): + self.alg.eval_mode() diff --git a/benchmark/torch/RL4LMs/algorithms/ppo.py b/benchmark/torch/RL4LMs/algorithms/ppo.py index fde623c1b..060da4fc9 100644 --- a/benchmark/torch/RL4LMs/algorithms/ppo.py +++ b/benchmark/torch/RL4LMs/algorithms/ppo.py @@ -138,5 +138,51 @@ def predict(self, obs): def value(self, obs): pass - - + def forward_value( + self, + obs, + past_model_kwargs = None, + ): + return self.model.forward_value(obs, past_model_kwargs) + + def forward_policy( + self, + obs, + actions: torch.tensor, + past_model_kwargs = None, + ): + return self.model.forward_policy( + obs = obs, + actions = actions, + past_model_kwargs = past_model_kwargs, + ) + + + def get_log_probs_ref_model( + self, + obs, + action, + model_kwarpast_model_kwargsgs = None, + ): + return self.model.get_log_probs_ref_model(obs, action, model_kwarpast_model_kwargsgs) + + def generate( + self, + tokenizer, + texts = None, + max_prompt_length = None, + input_ids = None, + attention_mask = None, + gen_kwargs = None, + ): + return self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + tokenizer=tokenizer, + texts = texts, + max_prompt_length = max_prompt_length, + gen_kwargs = gen_kwargs + ) + + def eval_mode(self): + self.model.eval() \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml index 2707b4431..783d6d246 100644 --- a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml @@ -20,8 +20,8 @@ datapool: env: ## CHANGE FOR DEBUG ## -# n_envs: 10 - n_envs: 2 + n_envs: 10 +# n_envs: 2 ## CHANGE FOR DEBUG ## args: max_prompt_length: 512 @@ -33,9 +33,9 @@ env: alg: id: ppo args: -# n_steps: 512 + n_steps: 512 #####CHNAGE FOR DEBUG######## - n_steps: 5 +# n_steps: 5 #####CHANGE FOR DEBUG######## batch_size: 32 # verbose: 1 diff --git a/benchmark/torch/RL4LMs/models/seq2seq_model.py b/benchmark/torch/RL4LMs/models/seq2seq_model.py index a2a39c65a..055e57d9d 100644 --- a/benchmark/torch/RL4LMs/models/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/models/seq2seq_model.py @@ -10,9 +10,7 @@ from benchmark.torch.RL4LMs.utils import ( override_generation_routines, - ActorCriticWarmStartMixin, - - TensorDict, Schedule, + TensorDict, GenerationInputs, PolicyOutput, RefPolicyOutput, ValueOutput, PolicyType, EvaluateActionsOutput, GenerationOutputs, @@ -21,7 +19,7 @@ from .base_model import LMActorCriticModel -class Seq2SeqLMModel(LMActorCriticModel, ActorCriticWarmStartMixin): +class Seq2SeqLMModel(LMActorCriticModel): def __init__( self, observation_space: DictSpace, diff --git a/benchmark/torch/RL4LMs/registry.py b/benchmark/torch/RL4LMs/registry.py index 3456abe71..3577fffae 100644 --- a/benchmark/torch/RL4LMs/registry.py +++ b/benchmark/torch/RL4LMs/registry.py @@ -2,7 +2,7 @@ from benchmark.torch.RL4LMs.algorithms import RL4LMPPO -from benchmark.torch.RL4LMs.summarization import RL4LMsSummaAgent +from benchmark.torch.RL4LMs.agents import RL4LMsSummaAgent from benchmark.torch.RL4LMs.utils import TextGenPool, CNNDailyMail # from rl4lms.envs.text_generation.alg_wrappers import wrap_onpolicy_alg diff --git a/benchmark/torch/RL4LMs/summarization/__init__.py b/benchmark/torch/RL4LMs/summarization/__init__.py deleted file mode 100644 index dcf74dbe4..000000000 --- a/benchmark/torch/RL4LMs/summarization/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .rl4lms_summa_agent import RL4LMsSummaAgent diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 6217348f6..df81a8dc3 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -19,16 +19,6 @@ def recursive_dict_update(d, u): def main(config): - # load tracker - # tracker = Tracker( - # config["base_path_to_store_results"], - # config, - # config["project_name"], - # config["experiment_name"], - # config["entity_name"], - # False, - # ) - # instantiate the trainer here # TODO: currently only complete ppo if "ppo" == config["alg"]["id"]: diff --git a/benchmark/torch/RL4LMs/trainers.py b/benchmark/torch/RL4LMs/trainers.py index 585aa4a4b..5dbd7bde8 100644 --- a/benchmark/torch/RL4LMs/trainers.py +++ b/benchmark/torch/RL4LMs/trainers.py @@ -1,28 +1,19 @@ -import os import time -from functools import partial from typing import Any, Dict, List import numpy as np from benchmark.torch.RL4LMs.utils import Sample, RewardFunction,\ evaluate_on_samples,\ - KLController, RolloutBuffer, DictRolloutBuffer, MaskableDictRolloutBuffer,\ + KLController, RolloutBuffer, MaskableDictRolloutBuffer,\ TransitionInfo, TensorDict, RefPolicyOutput, ValueOutput, PolicyOutput from benchmark.torch.RL4LMs.registry import DataPoolRegistry, MetricRegistry, RewardFunctionRegistry, \ ModelRegistry, AlgorithmRegistry from benchmark.torch.RL4LMs.env import TextGenEnv -from transformers import (AutoTokenizer, - AutoModelForCausalLM, - AutoModelForSeq2SeqLM, - Trainer, - TrainingArguments, - DataCollatorForLanguageModeling, - DataCollatorForSeq2Seq) +from transformers import AutoTokenizer from benchmark.torch.RL4LMs.env import LocalParallelVecEnv, make_vec_env from transformers import PreTrainedTokenizer -from benchmark.torch.RL4LMs.summarization import RL4LMsSummaAgent -from benchmark.torch.RL4LMs.algorithms import RL4LMPPO +from benchmark.torch.RL4LMs.agents import RL4LMsSummaAgent import torch from parl.utils import logger @@ -189,8 +180,8 @@ def _setup(self): self._rollout_buffer = MaskableDictRolloutBuffer( buffer_size=self._agent.alg.n_steps * self._env.num_envs, - observation_space=self._agent.alg.model.observation_space, - action_space=self._agent.alg.model.action_space, + observation_space=self._env.observation_space, + action_space=self._env.action_space, device=self.device, gamma=self._agent.alg.gamma, gae_lambda=self._agent.alg.gae_lambda, @@ -309,8 +300,8 @@ def generate_batch( # generate text using the model obs_tensor = dict_to_tensor(current_obs, self.device) - generation_inputs = self._agent.alg.model.get_inputs_for_generation(obs_tensor) - gen_output = self._agent.alg.model.generate( + generation_inputs = self._agent.get_inputs_for_generation(obs_tensor) + gen_output = self._agent.generate( input_ids=generation_inputs.inputs, attention_mask=generation_inputs.attention_masks, tokenizer=tokenizer, @@ -344,7 +335,7 @@ def generate_batch( obs_tensor, actions_tensor, policy_past_state, action_mask ) - policy_outputs: PolicyOutput = self._agent.alg.model.forward_policy( + policy_outputs: PolicyOutput = self._agent.forward_policy( **policy_kwargs ) raw_log_probs, log_probs, policy_past_state = ( @@ -364,7 +355,7 @@ def generate_batch( ), "Infinite values in log probs" # get values - value_outputs: ValueOutput = self._agent.alg.model.forward_value( + value_outputs: ValueOutput = self._agent.forward_value( obs_tensor, value_past_state ) values, value_past_state = ( @@ -374,7 +365,7 @@ def generate_batch( # get reference log probs ref_policy_outputs: RefPolicyOutput = ( - self._agent.alg.model.get_log_probs_ref_model( + self._agent.get_log_probs_ref_model( obs_tensor, actions_tensor, ref_past_state ) ) @@ -516,7 +507,8 @@ def collect_rollouts( tokenizer = tokenizer[0] # Switch to eval mode - self._agent.alg.model.set_training_mode(False) + # self._agent.alg.model.set_training_mode(False) + self._agent.eval_mode() # reset rollout buffer and stats rollout_buffer.reset() diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py index 1905032f0..fe791503d 100644 --- a/benchmark/torch/RL4LMs/utils/__init__.py +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -5,8 +5,6 @@ from .huggingface_generation_util import override_generation_routines -from .warm_start import ActorCriticWarmStartMixin, OnPolicyWarmStartMixin - from .type_wrapper import TensorDict, Schedule from .distribution_wrapper import CategoricalDistribution diff --git a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py index b54644216..421510446 100644 --- a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py @@ -25,7 +25,6 @@ from torch import nn from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint -from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, ExponentialDecayLengthPenalty, @@ -1288,26 +1287,7 @@ def generate( ) # 9. go into different generation modes - if is_greedy_gen_mode: - if num_return_sequences > 1: - raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." - ) - - # 10. run greedy search - return self.greedy_search( - input_ids, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_sample_gen_mode: + if is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams @@ -1335,217 +1315,16 @@ def generate( **model_kwargs, ) - elif is_beam_gen_mode: - if num_return_sequences > num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`.") - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now.") - - # 10. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - ) - # 11. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) - # 12. run beam search - return self.beam_search( - input_ids, - beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_beam_sample_gen_mode: - # 10. prepare logits warper - logits_warper = self._get_logits_warper( - top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams - ) - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now.") - # 11. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size * num_return_sequences, - num_beams=num_beams, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - ) - - # 12. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, - expand_size=num_beams * num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 13. run beam sample - return self.beam_sample( - input_ids, - beam_scorer, - logits_processor=logits_processor, - logits_warper=logits_warper, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_group_beam_gen_mode: - if num_return_sequences > num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`.") - - if num_beams % num_beam_groups != 0: - raise ValueError( - "`num_beams` should be divisible by `num_beam_groups` for group beam search.") - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now.") - - # 10. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - max_length=stopping_criteria.max_length, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, - ) - # 11. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) - # 12. run beam search - return self.group_beam_search( - input_ids, - beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_constraint_gen_mode: - if num_return_sequences > num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`.") - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now.") - - if num_beams <= 1: - raise ValueError( - "`num_beams` needs to be greater than 1 for constrained genertation.") - - if do_sample: - raise ValueError( - "`do_sample` needs to be false for constrained generation.") - - if num_beam_groups is not None and num_beam_groups > 1: - raise ValueError( - "`num_beam_groups` not supported yet for constrained generation.") - - final_constraints = [] - if constraints is not None: - final_constraints = constraints - - if force_words_ids is not None: - - def typeerror(): - raise ValueError( - "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" - f"of positive integers, but is {force_words_ids}." - ) + else: + raise NotImplementedError - if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: - typeerror() - - for word_ids in force_words_ids: - if isinstance(word_ids[0], list): - if not isinstance(word_ids, list) or len(word_ids) == 0: - typeerror() - if any(not isinstance(token_ids, list) for token_ids in word_ids): - typeerror() - if any( - any((not isinstance(token_id, int) or token_id < 0) - for token_id in token_ids) - for token_ids in word_ids - ): - typeerror() - - constraint = DisjunctiveConstraint(word_ids) - else: - if not isinstance(word_ids, list) or len(word_ids) == 0: - typeerror() - if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): - typeerror() - - constraint = PhrasalConstraint(word_ids) - final_constraints.append(constraint) - - # 10. prepare beam search scorer - constrained_beam_scorer = ConstrainedBeamSearchScorer( - constraints=final_constraints, - batch_size=batch_size, - num_beams=num_beams, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - ) - # 11. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) - # 12. run beam search - return self.constrained_beam_search( - input_ids, - constrained_beam_scorer=constrained_beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - def greedy_search( + def sample( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, @@ -1555,10 +1334,10 @@ def greedy_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, - ) -> Union[GreedySearchOutput, torch.LongTensor]: + ) -> Union[SampleOutput, torch.LongTensor]: r""" - Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be - used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: @@ -1570,7 +1349,10 @@ def greedy_search( stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. @@ -1591,14 +1373,14 @@ def greedy_search( synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. Return: - [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] - or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if + [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: @@ -1609,9 +1391,12 @@ def greedy_search( ... AutoModelForCausalLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) + >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") @@ -1619,34 +1404,48 @@ def greedy_search( >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token >>> model.config.pad_token_id = model.config.eos_token_id - >>> input_prompt = "It might be possible to" + >>> input_prompt = "Today is a beautiful day, and" >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ - ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), + ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), ... ] ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.greedy_search( - ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ```""" + # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria( stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores @@ -1679,6 +1478,7 @@ def greedy_search( cur_len = input_ids.shape[-1] this_peer_finished = False # used by synced_gpus only + # auto-regressive generation while True: if synced_gpus: @@ -1708,12 +1508,19 @@ def greedy_search( cur_len = cur_len + 1 continue # don't waste resources running the code we don't need + next_token_logits_raw = outputs.logits[:, -1, :].clone() next_token_logits = outputs.logits[:, -1, :] + # pre-process distribution + next_token_scores = logits_processor( + input_ids, next_token_logits, model_inputs=model_inputs) + next_token_scores = logits_warper( + input_ids, next_token_scores) + # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: - scores += (next_token_logits,) + scores += ((next_token_logits_raw, next_token_scores),) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( @@ -1729,12 +1536,9 @@ def greedy_search( else (outputs.hidden_states,) ) - # pre-process distribution - next_tokens_scores = logits_processor( - input_ids, next_token_logits, model_inputs) - - # argmax - next_tokens = torch.argmax(next_tokens_scores, dim=-1) + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token if eos_token_id is not None: @@ -1765,7 +1569,7 @@ def greedy_search( if return_dict_in_generate: if self.config.is_encoder_decoder: - return GreedySearchEncoderDecoderOutput( + return SampleEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, @@ -1775,7 +1579,7 @@ def greedy_search( decoder_hidden_states=decoder_hidden_states, ) else: - return GreedySearchDecoderOnlyOutput( + return SampleDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, @@ -1784,1664 +1588,6 @@ def greedy_search( else: return input_ids - def sample( - self, - input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ) -> Union[SampleOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... TopKLogitsWarper, - ... TemperatureLogitsWarper, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token - >>> model.config.pad_token_id = model.config.eos_token_id - - >>> input_prompt = "Today is a beautiful day, and" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - >>> # instantiate logits processors - >>> logits_warper = LogitsProcessorList( - ... [ - ... TopKLogitsWarper(50), - ... TemperatureLogitsWarper(0.7), - ... ] - ... ) - - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - - >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT - >>> outputs = model.sample( - ... input_ids, - ... logits_processor=logits_processor, - ... logits_warper=logits_warper, - ... stopping_criteria=stopping_criteria, - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] - ```""" - - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length) - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if ( - return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get( - "attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get( - "hidden_states") if output_hidden_states else None - ) - - # keep track of which sequences are already finished - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - cur_len = input_ids.shape[-1] - - this_peer_finished = False # used by synced_gpus only - # auto-regressive generation - while True: - - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs) - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits_raw = outputs.logits[:, -1, :].clone() - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor( - input_ids, next_token_logits, model_inputs=model_inputs) - next_token_scores = logits_warper( - input_ids, next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += ((next_token_logits_raw, next_token_scores),) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( - outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + \ - pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - cur_len = cur_len + 1 - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul( - (next_tokens != eos_token_id).long()) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return SampleEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return SampleDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return input_ids - - def beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ) -> Union[BeamSearchOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search decoding** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length) - if len(stopping_criteria) == 0: - warnings.warn( - "You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if ( - return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if ( - return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get( - "attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get( - "hidden_states") if output_hidden_states else None - ) - - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits = outputs.logits[:, -1, :] - next_token_logits_raw = next_token_logits.clone() - - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor( - input_ids, next_token_scores, model_inputs=model_inputs) - next_token_scores = next_token_scores_processed + \ - beam_scores[:, None].expand_as(next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_logits_raw,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( - outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = torch_int_div(next_tokens, vocab_size) - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], beam_idx) - - if return_dict_in_generate and output_scores: - beam_indices = tuple( - (beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - else: - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - (beam_indices[i * num_beams: i * num_beams + - num_return_sequences] for i in range(batch_size)) - ) - beam_indices = sum(beam_indices, ()) - - step_wise_raw_logits = self.compute_beam_search_raw_logits( - sequence_outputs["sequences"].clone(), - scores, - beam_indices, - eos_token_id) - - if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=step_wise_raw_logits, # raw logits - beam_indices=beam_indices, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSearchDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - - def beam_sample( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ) -> Union[BeamSampleOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search multinomial - sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... TopKLogitsWarper, - ... TemperatureLogitsWarper, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... max_length=model.config.max_length, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] - ... ) - >>> # instantiate logits processors - >>> logits_warper = LogitsProcessorList( - ... [ - ... TopKLogitsWarper(50), - ... TemperatureLogitsWarper(0.7), - ... ] - ... ) - - >>> outputs = model.beam_sample( - ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if ( - return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if ( - return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get( - "attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get( - "hidden_states") if output_hidden_states else None - ) - - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits_raw = outputs.logits[:, -1, :] - - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits_raw, cur_len=cur_len) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor( - input_ids, next_token_logits, model_inputs=model_inputs) - next_token_scores = next_token_scores_processed + \ - beam_scores[:, None].expand_as(next_token_scores) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - # return raw scores instead of post-processed - scores += ((next_token_logits_raw, next_token_scores),) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( - outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size) - - probs = nn.functional.softmax(next_token_scores, dim=-1) - - next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) - next_token_scores = torch.gather( - next_token_scores, -1, next_tokens) - - next_token_scores, _indices = torch.sort( - next_token_scores, descending=True, dim=1) - next_tokens = torch.gather(next_tokens, -1, _indices) - - next_indices = torch_int_div(next_tokens, vocab_size) - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], beam_idx) - - if return_dict_in_generate and output_scores: - beam_indices = tuple( - (beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - else: - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - (beam_indices[i * num_beams: i * num_beams + - num_return_sequences] for i in range(batch_size)) - ) - beam_indices = sum(beam_indices, ()) - - if self.config.is_encoder_decoder: - return BeamSampleEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSampleDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - - def group_beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ): - r""" - Generates sequences of token ids for models with a language modeling head using **diverse beam search - decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - - model_kwargs: - Additional model specific kwargs that will be forwarded to the `forward` function of the model. If - model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSearchDecoderOnlyOutput`] if [`~generation_utils.BeamSearchDecoderOnlyOutput`] if - `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a - [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... HammingDiversityLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run diverse beam search using 6 beams - >>> num_beams = 6 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... max_length=model.config.max_length, - ... num_beams=num_beams, - ... device=model.device, - ... num_beam_groups=3, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model.group_beam_search( - ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - num_beam_groups = beam_scorer.num_beam_groups - num_sub_beams = num_beams // num_beam_groups - device = input_ids.device - - batch_beam_size, cur_len = input_ids.shape - - if return_dict_in_generate and output_scores: - beam_indices = [tuple(() for _ in range( - num_sub_beams * batch_size)) for _ in range(num_beam_groups)] - else: - beam_indices = None - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if ( - return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get( - "attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get( - "hidden_states") if output_hidden_states else None - ) - - beam_scores = torch.full( - (batch_size, num_beams), -1e9, dtype=torch.float, device=device) - # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in - # the same group don't produce same tokens everytime. - beam_scores[:, ::num_sub_beams] = 0 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - # predicted tokens in cur_len step - current_tokens = torch.zeros( - batch_size * num_beams, dtype=input_ids.dtype, device=device) - - # indices which will form the beams in the next time step - reordering_indices = torch.zeros( - batch_size * num_beams, dtype=torch.long, device=device) - - # do one decoder step on all beams of all sentences in batch - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs) - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - if output_scores: - processed_score = torch.zeros_like(outputs.logits[:, -1, :]) - - for beam_group_idx in range(num_beam_groups): - group_start_idx = beam_group_idx * num_sub_beams - group_end_idx = min(group_start_idx + num_sub_beams, num_beams) - group_size = group_end_idx - group_start_idx - - # indices of beams of current group among all sentences in batch - batch_group_indices = [] - - for batch_idx in range(batch_size): - batch_group_indices.extend( - [batch_idx * num_beams + - idx for idx in range(group_start_idx, group_end_idx)] - ) - group_input_ids = input_ids[batch_group_indices] - - # select outputs of beams of current group only - next_token_logits_raw = outputs.logits[batch_group_indices, -1, :] - - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits_raw, cur_len=cur_len) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * group_size, vocab_size) - vocab_size = next_token_scores.shape[-1] - - next_token_scores_processed = logits_processor( - group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx, model_inputs=model_inputs - ) - next_token_scores = next_token_scores_processed + \ - beam_scores[batch_group_indices].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as( - next_token_scores_processed) - - if output_scores: - processed_score[batch_group_indices] = next_token_logits_raw - - # reshape for beam search - next_token_scores = next_token_scores.view( - batch_size, group_size * vocab_size) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True - ) - - next_indices = torch_int_div(next_tokens, vocab_size) - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - group_input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - if return_dict_in_generate and output_scores: - beam_indices[beam_group_idx] = tuple( - beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) - ) - - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat( - [group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - current_tokens[batch_group_indices] = group_input_ids[:, -1] - - # (beam_idx // group_size) -> batch_idx - # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = ( - num_beams * - torch_int_div(beam_idx, group_size) + - group_start_idx + (beam_idx % group_size) - ) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (processed_score,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( - outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - input_ids = torch.cat( - [input_ids, current_tokens.unsqueeze(-1)], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], reordering_indices) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - else: - beam_indices = sum(beam_indices, ()) - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - (beam_indices[i * num_beams: i * num_beams + - num_return_sequences] for i in range(batch_size)) - ) - beam_indices = sum(beam_indices, ()) - - if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSearchDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - - def constrained_beam_search( - self, - input_ids: torch.LongTensor, - constrained_beam_scorer: ConstrainedBeamSearchScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = None, - **model_kwargs, - ) -> Union[BeamSearchOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **constrained beam search - decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - constrained_beam_scorer (`ConstrainedBeamSearchScorer`): - A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation, while satisfying a list of positive constraints. For more information, the - documentation of [`ConstrainedBeamSearchScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... ConstrainedBeamSearchScorer, - ... PhrasalConstraint, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> constraint_str = "Sie" - >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token - >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] - - - >>> # instantiate beam scorer - >>> beam_scorer = ConstrainedBeamSearchScorer( - ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model.constrained_beam_search( - ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt sind Sie?'] - ```""" - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length) - if len(stopping_criteria) == 0: - warnings.warn( - "You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if ( - return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get( - "attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get( - "hidden_states") if output_hidden_states else None - ) - - batch_size = len(constrained_beam_scorer._beam_hyps) - num_beams = constrained_beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits_raw = outputs.logits[:, -1, :] - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits_raw, cur_len=cur_len) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor( - input_ids, next_token_scores, model_inputs=model_inputs) - - scores_for_all_vocab = next_token_scores_processed.clone() - - next_token_scores = next_token_scores_processed + \ - beam_scores[:, None].expand_as(next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += ((next_token_logits_raw, next_token_scores),) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( - outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = (next_tokens / vocab_size).long() - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = constrained_beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - scores_for_all_vocab, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], beam_idx) - - # increase cur_len - cur_len = cur_len + 1 - - if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = constrained_beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSearchDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] def top_k_top_p_filtering( diff --git a/benchmark/torch/RL4LMs/utils/kl_controller.py b/benchmark/torch/RL4LMs/utils/kl_controller.py index ad2d3a7ab..377d196aa 100644 --- a/benchmark/torch/RL4LMs/utils/kl_controller.py +++ b/benchmark/torch/RL4LMs/utils/kl_controller.py @@ -19,14 +19,3 @@ def step(self, kl_div: torch.tensor): @property def kl_coeff(self): return self._kl_coeff - - def get_state_dict(self) -> Dict[str, Any]: - state = { - "target_kl": self._target_kl, - "current_kl_coeff": self._kl_coeff - } - return state - - def load_from_state_dict(self, state_dict: Dict[str, Any]): - self._kl_coeff = state_dict["current_kl_coeff"] - self._target_kl = state_dict["target_kl"] \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/tracker.py b/benchmark/torch/RL4LMs/utils/tracker.py deleted file mode 100644 index 203aa94ec..000000000 --- a/benchmark/torch/RL4LMs/utils/tracker.py +++ /dev/null @@ -1,143 +0,0 @@ -from collections import defaultdict -from typing import Dict, Any, List -import os -import json -import jsonlines -import pandas as pd -from transformers import AutoModel -import logging -import copy -import random - - -class Tracker: - def __init__(self, - base_path_to_store_results: str, - run_config: Dict[str, Any], - project_name: str, - experiment_name: str, - entity_name: str = None, - wandb_log: bool = False, - log_level: int = logging.DEBUG, - ): - self._log_level = log_level - self._base_path_to_store_results = base_path_to_store_results - self._config = run_config - self._experiment_name = experiment_name - self._project_name = project_name - self._entity_name = entity_name - self._wandb_log = wandb_log - self._init() - - def _init(self): - # create a folder - self._run_path = os.path.join( - self._base_path_to_store_results, - self._project_name, - self._experiment_name) - os.makedirs(self._run_path, exist_ok=True) - - # store also the config into it - config_path = os.path.join(self._run_path, "config.json") - with open(config_path, "w") as fp: - json.dump(self._config, fp) - - # init logger - log_path = os.path.join(self._run_path, "log.txt") - logging.basicConfig( - level=self._log_level, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[ - logging.FileHandler(log_path) ] - ) - - - def log_predictions(self, epoch: int, - split_name: str, - predictions: List[Dict]): - # log them per epoch in a separate file as they can get huge - prediction_file_at_epoch = os.path.join( - self._run_path, f"epoch_{epoch}_{split_name}_split_predictions.json") - with open(prediction_file_at_epoch, "w") as fp: - json.dump(predictions, fp) - - # randomly display few predictions for logging - predictions_ = copy.deepcopy(predictions) - random.shuffle(predictions_) - logging.info(f"Split {split_name} predictions") - for pred in predictions_[:10]: - logging.info(pred) - - - def log_metrics(self, epoch: int, - split_name: str, - metrics_dict: Dict[str, float]): - # for each split, one file - metric_file_per_split = os.path.join( - self._run_path, f"{split_name}_split_metrics.jsonl") - metrics_dict_ = { - "epoch": epoch, - "metrics": metrics_dict - } - with jsonlines.open(metric_file_per_split, "a") as writer: - writer.write(metrics_dict_) - - # logger - logging.info(f"{split_name} metrics: {metrics_dict_}") - - def log_rollout_infos(self, rollout_info: Dict[str, float]): - logging.info(f"Rollout Info: {rollout_info}") - rollout_info_file = os.path.join( - self._run_path, "rollout_info.jsonl") - with jsonlines.open(rollout_info_file, mode="a") as writer: - writer.write(rollout_info) - - def log_training_infos(self, training_info: Dict[str, float]): - logging.info(f"Training Info: {training_info}") - training_info_file = os.path.join( - self._run_path, "training_info.jsonl") - with jsonlines.open(training_info_file, mode="a") as writer: - writer.write(training_info) - - def done(self): - pass - - def save_auto_model(self, model: AutoModel): - model_path = os.path.join(self._run_path, "model") - model.save_pretrained(model_path) - - @property - def checkpoint_base_path(self): - return os.path.join(self._run_path, "checkpoints") - - def log_info(self, msg: str): - logging.info(msg) - - -if __name__ == "__main__": - base_path = "/data/zhangsw/" - run_config = { - "param_1": 1, - "param_2": 2 - } - predictions = { - "1": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, - {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], - "2": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, - {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], - "3": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, - {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], - } - - metrics = { - "1": {"metric_1": 0.05, "metric_2": 0.1}, - "2": {"metric_1": 0.06, "metric_2": 0.2}, - "3": {"metric_1": 0.06, "metric_2": 0.3}, - } - - rollout_infos = [ - {"ep_len": 2, "ep_reward": 0.4}, - {"ep_len": 3, "ep_reward": 0.5}, - {"ep_len": 3, "ep_reward": 0.5}, - ] - diff --git a/benchmark/torch/RL4LMs/utils/warm_start.py b/benchmark/torch/RL4LMs/utils/warm_start.py deleted file mode 100644 index fa700230f..000000000 --- a/benchmark/torch/RL4LMs/utils/warm_start.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -from typing import Any, Dict -import torch - -# from rl4lms.envs.text_generation.logging_utils import Tracker -# from rl4lms.envs.text_generation.policy.base_policy import LMActorCriticPolicy - - -class ActorCriticWarmStartMixin: - def get_state_dict(self) -> Dict[str, Any]: - state_dict = { - "policy_model": self._policy_model.state_dict(), - "value_model": self._value_model.state_dict(), - "value_head": self._value_head.state_dict(), - "optimizer": self.optimizer.state_dict() - } - return state_dict - - def load_from_dict(self, state_dict: dict = None): - if state_dict is not None: - self._policy_model.load_state_dict(state_dict["policy_model"]) - self._value_model.load_state_dict(state_dict["value_model"]) - self._value_head.load_state_dict(state_dict["value_head"]) - self.optimizer.load_state_dict(state_dict["optimizer"]) - - - -class OnPolicyWarmStartMixin: - def get_state_dict(self) -> Dict[str, Any]: - # just the kl controller state is sufficient for onpolicy algs - state_dict = { - "kl_controller_state": self._kl_controller.get_state_dict(), - } - return state_dict - - def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: - if state_dict is not None: - self._kl_controller.load_from_state_dict( - state_dict["kl_controller_state"]) - - - -# -# -################## Trainer Warm Start Mixins####################################### -# class TrainerWarmStartMixin: -# def _get_recent_ckpt_path(self, tracker: Tracker): -# try: -# checkpoints = os.listdir(tracker.checkpoint_base_path) -# except: -# os.makedirs(tracker.checkpoint_base_path) -# checkpoints = os.listdir(tracker.checkpoint_base_path) -# -# if len(checkpoints) == 0: -# return None, None -# -# sorted_ckpts = sorted(checkpoints, reverse=True, -# key=lambda ckpt: int(ckpt.split("_")[1])) -# recent_ckpt = sorted_ckpts[0] -# recent_ckpt_id = int(recent_ckpt.split("_")[1]) -# -# recent_ckpt_path = os.path.join( -# tracker.checkpoint_base_path, f"checkpoint_{recent_ckpt_id}") -# return recent_ckpt_path, recent_ckpt_id -# -# def load_trainer_state(self, tracker: Tracker): -# recent_ckpt_path, _ = self._get_recent_ckpt_path(tracker) -# state_dict = None -# try: -# if recent_ckpt_path is not None: -# state_dict = torch.load( -# recent_ckpt_path, map_location=torch.device("cuda")) -# tracker.log_info("Model checkpoint found - Warm starting") -# self._policy_state_dict = state_dict["policy_state"] -# self._alg_state_dict = state_dict["alg_state"] -# self._trainer_state = state_dict["trainer_state"] -# -# tracker.log_info( -# f"Loaded the current trainer state from: {self._trainer_state}") -# else: -# self._policy_state_dict = None -# self._alg_state_dict = None -# self._trainer_state = { -# "current_iter": 0, -# } -# except Exception as e: -# tracker.log_info(f"Exception while doing warm start {e}") -# tracker.log_info( -# f"Checkpoint may be corrupted...skipping warm start") -# self._policy_state_dict = None -# self._alg_state_dict = None -# self._trainer_state = { -# "current_iter": 0, -# } -# -# def save_trainer_state(self, tracker: Tracker, -# policy, -# trainer_state: Dict[str, Any]): -# full_state = { -# "alg_state": self._agent.alg.get_state_dict(), -# "policy_state": policy.get_state_dict(), -# "trainer_state": trainer_state -# } -# _, recent_ckpt_id = self._get_recent_ckpt_path(tracker) -# -# # hot fix - just to save only the last checkpoint (overwrite) -# new_ckpt_id = 0 if recent_ckpt_id is None else recent_ckpt_id + 1 -# new_ckpt_path = os.path.join( -# tracker.checkpoint_base_path, f"checkpoint_{new_ckpt_id}") -# torch.save(full_state, new_ckpt_path, pickle_protocol=4) From 02efdd91189d21f156154972cf7d081b7712eb4b Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 6 Mar 2023 18:01:52 +0800 Subject: [PATCH 05/34] merge models and buffers, add README.md --- benchmark/torch/RL4LMs/README.md | 19 + benchmark/torch/RL4LMs/agents/__init__.py | 2 +- ...{rl4lms_summa_agent.py => rl4lms_agent.py} | 4 +- .../RL4LMs/configs/summarization/t5_ppo.yml | 10 +- benchmark/torch/RL4LMs/models/__init__.py | 1 - benchmark/torch/RL4LMs/models/base_model.py | 431 ------------- .../torch/RL4LMs/models/seq2seq_model.py | 183 +++++- benchmark/torch/RL4LMs/registry.py | 31 +- benchmark/torch/RL4LMs/trainers.py | 164 ++--- benchmark/torch/RL4LMs/utils/__init__.py | 3 +- benchmark/torch/RL4LMs/utils/buffer.py | 593 +++--------------- benchmark/torch/RL4LMs/utils/data_wrapper.py | 29 +- 12 files changed, 392 insertions(+), 1078 deletions(-) create mode 100644 benchmark/torch/RL4LMs/README.md rename benchmark/torch/RL4LMs/agents/{rl4lms_summa_agent.py => rl4lms_agent.py} (98%) delete mode 100644 benchmark/torch/RL4LMs/models/base_model.py diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md new file mode 100644 index 000000000..2112925ab --- /dev/null +++ b/benchmark/torch/RL4LMs/README.md @@ -0,0 +1,19 @@ +## Reproduce (Reconfiguration) Summarization in RL4LMs using PARL + +> Paper: [Is Reinforcement Learning (Not) for Natural Language Processing: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization](https://arxiv.org/abs/2210.01241) +> +> Official code: [RL4LMs](https://github.com/allenai/RL4LMs) +> +> Other code referenced: [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) + + +### Main contribution + +- Change from **\{ trainer: \{ ppo: \{ env, rollout_buffer, policy/model \} \} \}** to + **\{trainer: \{env, rollout_buffer, agent: \{ ppo: \{ model \} \} \} \}** according PARL architecture. + +### Running command + +```bash +python train.py --config_path configs/summarization/t5_ppo.yml +``` \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/agents/__init__.py b/benchmark/torch/RL4LMs/agents/__init__.py index b3361187c..72f8da7f4 100644 --- a/benchmark/torch/RL4LMs/agents/__init__.py +++ b/benchmark/torch/RL4LMs/agents/__init__.py @@ -1 +1 @@ -from .rl4lms_summa_agent import RL4LMsSummaAgent \ No newline at end of file +from .rl4lms_agent import RL4LMsAgent \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/agents/rl4lms_summa_agent.py b/benchmark/torch/RL4LMs/agents/rl4lms_agent.py similarity index 98% rename from benchmark/torch/RL4LMs/agents/rl4lms_summa_agent.py rename to benchmark/torch/RL4LMs/agents/rl4lms_agent.py index e873c54c4..0c366df7d 100644 --- a/benchmark/torch/RL4LMs/agents/rl4lms_summa_agent.py +++ b/benchmark/torch/RL4LMs/agents/rl4lms_agent.py @@ -20,13 +20,13 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y -class RL4LMsSummaAgent(parl.Agent): +class RL4LMsAgent(parl.Agent): def __init__(self, algorithm, alg_config, norm_reward: bool = False, ): - super(RL4LMsSummaAgent, self).__init__(algorithm) + super(RL4LMsAgent, self).__init__(algorithm) self.dataset = None self.config = alg_config self.n_epochs = alg_config["args"]["n_epochs"] diff --git a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml index 783d6d246..50fe402ad 100644 --- a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml @@ -1,6 +1,5 @@ - tokenizer: model_name: t5-base padding_side: left @@ -20,8 +19,8 @@ datapool: env: ## CHANGE FOR DEBUG ## - n_envs: 10 -# n_envs: 2 +# n_envs: 10 + n_envs: 2 ## CHANGE FOR DEBUG ## args: max_prompt_length: 512 @@ -31,11 +30,12 @@ env: context_start_token: 0 alg: + agent_id: rl4lm_agent id: ppo args: - n_steps: 512 +# n_steps: 512 #####CHNAGE FOR DEBUG######## -# n_steps: 5 + n_steps: 5 #####CHANGE FOR DEBUG######## batch_size: 32 # verbose: 1 diff --git a/benchmark/torch/RL4LMs/models/__init__.py b/benchmark/torch/RL4LMs/models/__init__.py index 3a53cbfc4..ed9b32d20 100644 --- a/benchmark/torch/RL4LMs/models/__init__.py +++ b/benchmark/torch/RL4LMs/models/__init__.py @@ -1,2 +1 @@ -from .base_model import BaseModel, LMActorCriticModel from .seq2seq_model import Seq2SeqLMModel \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/models/base_model.py b/benchmark/torch/RL4LMs/models/base_model.py deleted file mode 100644 index 9d1b2c768..000000000 --- a/benchmark/torch/RL4LMs/models/base_model.py +++ /dev/null @@ -1,431 +0,0 @@ -from abc import abstractmethod -from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Type, Union - -import torch -from gym.spaces import Discrete -from gym.spaces.dict import Dict as DictSpace -from torch.distributions import Categorical -from transformers import AutoTokenizer, PreTrainedModel -from transformers.modeling_utils import unwrap_model -from torch import nn - -import gym -import numpy as np - -import parl - -TensorDict = Dict[Union[str, int], torch.Tensor] -from benchmark.torch.RL4LMs.utils import ( - - CategoricalDistribution, - - EvaluateActionsOutput, PolicyOutput, RefPolicyOutput, ValueOutput, - GenerationInputs, GenerationOutputs, PolicyType -) - - -# refer to stable_baselines3.common.policies -class BaseModel(parl.Model): - def __init__(self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - device=None): - super().__init__() - if optimizer_kwargs is None: - optimizer_kwargs = {} - - self.observation_space = observation_space - self.action_space = action_space - - self.optimizer_class = optimizer_class - self.optimizer_kwargs = optimizer_kwargs - self.optimizer = None - self.device = device - - @abstractmethod - def forward(self, *args, **kwargs): - pass - - @staticmethod - def _dummy_schedule(progress_remaining: float) -> float: - """(float) Useful for pickling policy.""" - del progress_remaining - return 0.0 - - - @staticmethod - def init_weights(module: nn.Module, gain: float = 1) -> None: - """ - Orthogonal initialization (used in PPO and A2C) - """ - if isinstance(module, (nn.Linear, nn.Conv2d)): - nn.init.orthogonal_(module.weight, gain=gain) - if module.bias is not None: - module.bias.data.fill_(0.0) - - @abstractmethod - def _predict(self, observation: torch.Tensor, deterministic: bool = False) -> torch.Tensor: - """ - Get the action according to the policy for a given observation. - - By default provides a dummy implementation -- not all BasePolicy classes - implement this, e.g. if they are a Critic in an Actor-Critic method. - - :param observation: - :param deterministic: Whether to use stochastic or deterministic actions - :return: Taken action according to the policy - """ - def _get_constructor_parameters(self) -> Dict[str, Any]: - return dict( - observation_space=self.observation_space, - action_space=self.action_space, - ) - - def save(self, path: str) -> None: - """ - Save model to a given location. - - :param path: - """ - torch.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) - - def set_training_mode(self, mode: bool) -> None: - self.train(mode) - - def predict( - self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, - episode_start: Optional[np.ndarray] = None, - deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: - """ - Get the policy action from an observation (and optional hidden state). - Includes sugar-coating to handle different observations (e.g. normalizing images). - - :param observation: the input observation - :param state: The last hidden states (can be None, used in recurrent policies) - :param episode_start: The last masks (can be None, used in recurrent policies) - this correspond to beginning of episodes, - where the hidden states of the RNN must be reset. - :param deterministic: Whether or not to return deterministic actions. - :return: the model's action and the next hidden state - (used in recurrent policies) - """ - # TODO (GH/1): add support for RNN policies - # if state is None: - # state = self.initial_state - # if episode_start is None: - # episode_start = [False for _ in range(self.n_envs)] - # Switch to eval mode (this affects batch norm / dropout) - self.set_training_mode(False) - - observation, vectorized_env = self.obs_to_tensor(observation) - - with torch.no_grad(): - actions = self._predict(observation, deterministic=deterministic) - # Convert to numpy - actions = actions.cpu().numpy() - - if isinstance(self.action_space, gym.spaces.Box): - if self.squash_output: - # Rescale to proper domain when using squashing - actions = self.unscale_action(actions) - else: - # Actions could be on arbitrary scale, so clip the actions to avoid - # out of bound error (e.g. if sampling from a Gaussian distribution) - actions = np.clip(actions, self.action_space.low, self.action_space.high) - - # Remove batch dimension if needed - if not vectorized_env: - actions = actions[0] - - return actions, state - - def scale_action(self, action: np.ndarray) -> np.ndarray: - """ - Rescale the action from [low, high] to [-1, 1] - (no need for symmetric action space) - - :param action: Action to scale - :return: Scaled action - """ - low, high = self.action_space.low, self.action_space.high - return 2.0 * ((action - low) / (high - low)) - 1.0 - - def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: - """ - Rescale the action from [-1, 1] to [low, high] - (no need for symmetric action space) - - :param scaled_action: Action to un-scale - """ - low, high = self.action_space.low, self.action_space.high - return low + (0.5 * (scaled_action + 1.0) * (high - low)) - -class LMActorCriticModel(BaseModel): - def __init__( - self, - observation_space: DictSpace, - action_space: Discrete, - model_name: str, - optimizer_kwargs: Dict[str, Any] = {}, - weight_decay: float = 1e-6, - apply_model_parallel: bool = True, - optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, - generation_kwargs: Dict[str, Any] = {}, - prompt_truncation_side: str = "left", - device=None - ): - """ - - Args: - observation_space (DictSpace): Observation space - action_space (Discrete): Action space - model_name (str): name of the causal or seq2seq model from transformers library - optimizer_kwargs (Dict[str, Any], optional): optimizer kwargs. Defaults to {}. - weight_decay (float, optional): weight decay. Defaults to 1e-6. - apply_model_parallel (bool, optional): whether to apply model parallel. Defaults to True. - optimizer_class (torch.optim.Optimizer, optional): Optimizer class. Defaults to torch.optim.AdamW. - generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}. - prompt_truncation_side (str, optional): truncation side for prompt text. Defaults to "left". - """ - super().__init__(observation_space, action_space, device=device) - self._action_space = action_space - self._apply_model_parallel = apply_model_parallel - self._build_model_heads(model_name) - self._setup_optimizer(optimizer_kwargs, weight_decay, optimizer_class) - self._action_dist = CategoricalDistribution(self._action_space.n) - self._generation_kwargs = generation_kwargs - self._prompt_truncation_side = prompt_truncation_side - - def _setup_optimizer( - self, - optimizer_kwargs: Dict[str, Any], - weight_decay: float, - optimizer_class: torch.optim, - ): - params = list(self.named_parameters()) - - no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in params if not any(nd in n for nd in no_decay)], - "weight_decay": weight_decay, - }, - { - "params": [p for n, p in params if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - self.optimizer = optimizer_class( - optimizer_grouped_parameters, **optimizer_kwargs - ) - - def forward(self, *args, **kwargs): - # dummy just to comply with base policy - pass - - - def _predict( - self, observation: Dict[str, torch.tensor], deterministic: bool = False - ) -> torch.Tensor: - # dummy just to comply with base policy - pass - - def is_encoder_decoder(self, model: PreTrainedModel): - return unwrap_model(model).config.is_encoder_decoder - - def generate( - self, - tokenizer: AutoTokenizer, - texts: List[str] = None, - max_prompt_length: int = None, - input_ids: torch.tensor = None, - attention_mask: torch.tensor = None, - gen_kwargs: Dict[str, Any] = None, - ) -> GenerationOutputs: - - # if it different from rollout gen kwargs - if gen_kwargs is None: - gen_kwargs = self._generation_kwargs - - # switch to eval - self._policy_model.eval() - - if ( - input_ids is None - and attention_mask is None - and texts is not None - and max_prompt_length is not None - ): - # override truncation side for prompt - prev_truncation_side = tokenizer.truncation_side - tokenizer.truncation_side = self._prompt_truncation_side - encodings = tokenizer( - texts, - padding="max_length", - max_length=max_prompt_length, - return_tensors="pt", - return_attention_mask=True, - truncation=True, - ) - input_ids = encodings.input_ids - attention_mask = encodings.attention_mask - tokenizer.truncation_side = prev_truncation_side - - # if min_length argument is set and if policy is not a seq2seq LM (ie. causal LM) - # then it has to be adjusted to input_size + min_length - if "min_length" in gen_kwargs.keys() and not self.is_encoder_decoder( - self._policy_model - ): - generation_kwargs_ = deepcopy(gen_kwargs) - generation_kwargs_["min_length"] = ( - input_ids.shape[1] + gen_kwargs["min_length"] - ) - else: - generation_kwargs_ = gen_kwargs - - # generate - gen_output = unwrap_model(self._policy_model).generate( - inputs=input_ids.to(self.get_policy_first_device()), - attention_mask=attention_mask.to(self.get_policy_first_device()), - return_dict_in_generate=True, - output_scores=True, - **generation_kwargs_, - ) - - # number of tokens generated - seq_length = len(gen_output["scores"]) - - # get only the generated text (excluding prompt) - gen_tokens = gen_output["sequences"][:, -seq_length:] - - # to texts - gen_texts = [ - tokenizer.decode(output, skip_special_tokens=True) - for output in gen_tokens.tolist() - ] - - # extract scores (logits) - step_wise_logprobs = [] - step_wise_actions = [] - for step, logits in enumerate(gen_output["scores"]): - raw_logits, _ = logits - actions_at_step = gen_tokens[:, step] - distribution = Categorical(logits=raw_logits) - log_probs = distribution.log_prob(actions_at_step) - step_wise_logprobs.append(log_probs) - step_wise_actions.append(actions_at_step) - - gen_output = GenerationOutputs( - step_wise_logprobs, step_wise_actions, gen_tokens, gen_texts - ) - return gen_output - - def get_language_model(self): - return unwrap_model(self._policy_model) - - # Following methods need to be implemented by sub-classing - @abstractmethod - def _build_model_heads(self, model_name: str): - """ - Builds policy and value models - and sets self._policy_model and self._value_model - """ - raise NotImplementedError - - @abstractmethod - def forward_policy( - self, - obs: TensorDict, - actions: torch.tensor, - past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, - ) -> PolicyOutput: - """ - Performs a forward pass on the policy and gets log_probs, entropy etc - corresponding to specified observation, actions - - This is invoked during rollout generation - - Args: - obs (TensorDict): observation - actions (torch.tensor): actions - past_model_kwargs (Optional[Dict[str, torch.tensor]], optional): Any cached past model activations which can be used for sequential foward passes. - Defaults to None. - """ - raise NotImplementedError - - @abstractmethod - def forward_value( - self, - obs: TensorDict, - past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, - ) -> ValueOutput: - """ - Performs a forward pass on the value network and gets values corresponding to observations - - This is invoked during rollout generation - - Args: - obs (TensorDict): observation - past_model_kwargs (Optional[Dict[str, torch.tensor]], optional): Any cached past model activations which can be used for sequential foward passes. - Defaults to None. - """ - raise NotImplementedError - - @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> EvaluateActionsOutput: - """ - Evaluates specified - and returns log_probs, values, entropy - - This is invoked for each mini-batch in rollout buffer during training iteration - """ - raise NotImplementedError - - @abstractmethod - def get_log_probs_ref_model( - self, - obs: TensorDict, - action: torch.tensor, - past_model_kwargs: Dict[str, Any] = None, - ) -> RefPolicyOutput: - """ - Performs a forward pass on the reference policy and gets log_probs - corresponding to specified observation, actions - - This is invoked during rollout generation to compute KL rewards - - Args: - obs (TensorDict): observation - past_model_kwargs (Optional[Dict[str, torch.tensor]], optional): Any cached past model activations which can be used for sequential foward passes. - Defaults to None. - """ - raise NotImplementedError - - @abstractmethod - def get_policy_first_device(self) -> torch.device: - """ - Returns the first device of the policy. Used in the case of model parallel - """ - raise NotImplementedError - - @abstractmethod - def get_policy_type(self) -> PolicyType: - """ - Returns the type of policy (causal or seq2seq) - """ - raise NotImplementedError - - @abstractmethod - def get_inputs_for_generation(self, obs: TensorDict) -> GenerationInputs: - """ - Extracts the prompt inputs and attention masks which is used as seed for generation - """ - raise NotImplementedError diff --git a/benchmark/torch/RL4LMs/models/seq2seq_model.py b/benchmark/torch/RL4LMs/models/seq2seq_model.py index 055e57d9d..08da7836c 100644 --- a/benchmark/torch/RL4LMs/models/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/models/seq2seq_model.py @@ -3,23 +3,25 @@ from gym.spaces import Discrete from gym.spaces.dict import Dict as DictSpace from torch import nn -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel from copy import deepcopy +from torch.distributions import Categorical from transformers.modeling_utils import unwrap_model + +import parl from benchmark.torch.RL4LMs.utils import ( override_generation_routines, - TensorDict, + TensorDict, CategoricalDistribution, GenerationInputs, PolicyOutput, RefPolicyOutput, ValueOutput, PolicyType, EvaluateActionsOutput, GenerationOutputs, ) -from .base_model import LMActorCriticModel -class Seq2SeqLMModel(LMActorCriticModel): +class Seq2SeqLMModel(parl.Model): def __init__( self, observation_space: DictSpace, @@ -34,18 +36,28 @@ def __init__( state_dict: Dict[str, Any] = None, device: torch.DeviceObjType = None, ): - super().__init__( - observation_space, - action_space, - model_name, - optimizer_kwargs, - weight_decay, - apply_model_parallel, - optimizer_class, - generation_kwargs, - prompt_truncation_side, - device=device - ) + super(Seq2SeqLMModel, self).__init__() + if optimizer_kwargs is None: + optimizer_kwargs = {} + + self.observation_space = observation_space + self.action_space = action_space + + self.optimizer_class = optimizer_class + self.optimizer_kwargs = optimizer_kwargs + self.optimizer = None + self.device = device + + self._action_space = action_space + self._apply_model_parallel = apply_model_parallel + self._build_model_heads(model_name) + self._setup_optimizer(optimizer_kwargs, weight_decay, optimizer_class) + self._action_dist = CategoricalDistribution(self._action_space.n) + self._generation_kwargs = generation_kwargs + self._prompt_truncation_side = prompt_truncation_side + + + # self.load_from_dict(state_dict) def _build_model_heads(self, model_name: str): @@ -323,3 +335,142 @@ def get_inputs_for_generation(self, obs: TensorDict) -> GenerationInputs: def get_policy_type(self): return PolicyType.SEQ2SEQ + + def get_language_model(self): + return unwrap_model(self._policy_model) + + def generate( + self, + tokenizer: AutoTokenizer, + texts: List[str] = None, + max_prompt_length: int = None, + input_ids: torch.tensor = None, + attention_mask: torch.tensor = None, + gen_kwargs: Dict[str, Any] = None, + ) -> GenerationOutputs: + + # if it different from rollout gen kwargs + if gen_kwargs is None: + gen_kwargs = self._generation_kwargs + + # switch to eval + self._policy_model.eval() + + if ( + input_ids is None + and attention_mask is None + and texts is not None + and max_prompt_length is not None + ): + # override truncation side for prompt + prev_truncation_side = tokenizer.truncation_side + tokenizer.truncation_side = self._prompt_truncation_side + encodings = tokenizer( + texts, + padding="max_length", + max_length=max_prompt_length, + return_tensors="pt", + return_attention_mask=True, + truncation=True, + ) + input_ids = encodings.input_ids + attention_mask = encodings.attention_mask + tokenizer.truncation_side = prev_truncation_side + + # if min_length argument is set and if policy is not a seq2seq LM (ie. causal LM) + # then it has to be adjusted to input_size + min_length + if "min_length" in gen_kwargs.keys() and not self.is_encoder_decoder( + self._policy_model + ): + generation_kwargs_ = deepcopy(gen_kwargs) + generation_kwargs_["min_length"] = ( + input_ids.shape[1] + gen_kwargs["min_length"] + ) + else: + generation_kwargs_ = gen_kwargs + + # generate + gen_output = unwrap_model(self._policy_model).generate( + inputs=input_ids.to(self.get_policy_first_device()), + attention_mask=attention_mask.to(self.get_policy_first_device()), + return_dict_in_generate=True, + output_scores=True, + **generation_kwargs_, + ) + + # number of tokens generated + seq_length = len(gen_output["scores"]) + + # get only the generated text (excluding prompt) + gen_tokens = gen_output["sequences"][:, -seq_length:] + + # to texts + gen_texts = [ + tokenizer.decode(output, skip_special_tokens=True) + for output in gen_tokens.tolist() + ] + + # extract scores (logits) + step_wise_logprobs = [] + step_wise_actions = [] + for step, logits in enumerate(gen_output["scores"]): + raw_logits, _ = logits + actions_at_step = gen_tokens[:, step] + distribution = Categorical(logits=raw_logits) + log_probs = distribution.log_prob(actions_at_step) + step_wise_logprobs.append(log_probs) + step_wise_actions.append(actions_at_step) + + gen_output = GenerationOutputs( + step_wise_logprobs, step_wise_actions, gen_tokens, gen_texts + ) + return gen_output + + + def is_encoder_decoder(self, model: PreTrainedModel): + return unwrap_model(model).config.is_encoder_decoder + + def set_training_mode(self, mode: bool) -> None: + self.train(mode) + + + def _get_constructor_parameters(self) -> Dict[str, Any]: + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + ) + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: + """ + torch.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) + + + def _setup_optimizer( + self, + optimizer_kwargs: Dict[str, Any], + weight_decay: float, + optimizer_class: torch.optim, + ): + params = list(self.named_parameters()) + + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in params if not any(nd in n for nd in no_decay)], + "weight_decay": weight_decay, + }, + { + "params": [p for n, p in params if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + self.optimizer = optimizer_class( + optimizer_grouped_parameters, **optimizer_kwargs + ) + + + diff --git a/benchmark/torch/RL4LMs/registry.py b/benchmark/torch/RL4LMs/registry.py index 3577fffae..3384332db 100644 --- a/benchmark/torch/RL4LMs/registry.py +++ b/benchmark/torch/RL4LMs/registry.py @@ -1,8 +1,8 @@ from typing import Any, Dict, Type, Union - +import parl from benchmark.torch.RL4LMs.algorithms import RL4LMPPO -from benchmark.torch.RL4LMs.agents import RL4LMsSummaAgent +from benchmark.torch.RL4LMs.agents import RL4LMsAgent from benchmark.torch.RL4LMs.utils import TextGenPool, CNNDailyMail # from rl4lms.envs.text_generation.alg_wrappers import wrap_onpolicy_alg @@ -24,7 +24,6 @@ TERMetric, chrFmetric, ) -from benchmark.torch.RL4LMs.models import LMActorCriticModel from benchmark.torch.RL4LMs.models import Seq2SeqLMModel @@ -125,7 +124,7 @@ class ModelRegistry: } @classmethod - def get(cls, model_id: str) -> Type[LMActorCriticModel]: + def get(cls, model_id: str) -> Type[parl.Model]: model_cls = cls._registry[model_id] return model_cls @@ -152,9 +151,9 @@ def add( AlgorithmRegistry._registry[id] = alg_cls -class WrapperRegistry: +class AgentRegistry: _registry = { - "ppo": RL4LMsSummaAgent, + "rl4lm_agent": RL4LMsAgent, } @classmethod @@ -166,22 +165,6 @@ def get(cls, alg_id: str): return wrapper_def @classmethod - def add(cls, id: str, wrapper_def): - WrapperRegistry._registry[id] = wrapper_def - - -class PostProcessorRegistry: - _registry = { - } - - @classmethod - def get(cls, post_processor_id: str): - try: - wrapper_def = cls._registry[post_processor_id] - except KeyError: - raise NotImplementedError - return wrapper_def + def add(cls, id: str, agent_def): + AgentRegistry._registry[id] = agent_def - @classmethod - def add(cls, id: str, post_processor_fn): - PostProcessorRegistry._registry[id] = post_processor_fn diff --git a/benchmark/torch/RL4LMs/trainers.py b/benchmark/torch/RL4LMs/trainers.py index 5dbd7bde8..8157ee4d0 100644 --- a/benchmark/torch/RL4LMs/trainers.py +++ b/benchmark/torch/RL4LMs/trainers.py @@ -5,15 +5,14 @@ from benchmark.torch.RL4LMs.utils import Sample, RewardFunction,\ evaluate_on_samples,\ - KLController, RolloutBuffer, MaskableDictRolloutBuffer,\ + KLController, MaskableDictRolloutBuffer,\ TransitionInfo, TensorDict, RefPolicyOutput, ValueOutput, PolicyOutput from benchmark.torch.RL4LMs.registry import DataPoolRegistry, MetricRegistry, RewardFunctionRegistry, \ - ModelRegistry, AlgorithmRegistry + ModelRegistry, AlgorithmRegistry, AgentRegistry from benchmark.torch.RL4LMs.env import TextGenEnv from transformers import AutoTokenizer from benchmark.torch.RL4LMs.env import LocalParallelVecEnv, make_vec_env from transformers import PreTrainedTokenizer -from benchmark.torch.RL4LMs.agents import RL4LMsSummaAgent import torch from parl.utils import logger @@ -90,6 +89,7 @@ def build_agent(alg_config: Dict[str, Any], model_config = alg_config["model"] model_cls = ModelRegistry.get(model_config["id"]) alg_cls = AlgorithmRegistry.get(alg_config["id"]) + agent_cls = AgentRegistry.get(alg_config["agent_id"]) model_args = model_config["args"] model_args["state_dict"] = model_state @@ -107,7 +107,7 @@ def build_agent(alg_config: Dict[str, Any], **alg_config.get("args") ) - rl4lm_agent = RL4LMsSummaAgent(rl4lm_alg_cls, alg_config) + rl4lm_agent = agent_cls(rl4lm_alg_cls, alg_config) return rl4lm_agent @@ -129,7 +129,7 @@ def unpack_observations(obs_tensor, n_envs: int): return unpacked_obs -class OnPolicyTrainer(): +class OnPolicyTrainer: """ A generic trainer for training LMs with onpolicy algorithms from SB3 """ @@ -143,80 +143,42 @@ def __init__(self, train_eval_config: Dict[str, Any], experiment_name: str = '' ): + # + self._tokenizer = None self._tokenizer_config = tokenizer_config + + # datapool self._datapool_config = datapool_config + self._samples_by_split = None + + # reward function & metrics self._reward_config = reward_config + self._reward_fn = None + self._metrics = None + self._norm_reward = False + + # env self._env_config = env_config + self._env = None + + # algorithm config & model config self._on_policy_alg_config = on_policy_alg_config + + # agent + self._agent = None + + # rollout buffer + self._rollout_buffer = None + self._train_eval_config = train_eval_config self._experiment_name = experiment_name - self._agent = None - self._env = None - self.num_timesteps = None + self._num_timesteps = None self._kl_controller = None self.device = torch.device("cuda" if torch.cuda. is_available() else "cpu") - self._norm_reward = False self._setup() - def _setup(self): - - # load trainer state from available previous checkpoint if available - # self.load_trainer_state(self._tracker) - - # build components - self._tokenizer = build_tokenizer(self._tokenizer_config) - self._reward_fn = build_reward_fn(self._reward_config) - self._metrics = build_metrics( - self._train_eval_config.get("metrics", [])) - self._samples_by_split = build_datapool(self._datapool_config) - self._env = build_env(self._env_config, self._reward_fn, - self._tokenizer, self._samples_by_split["train"]) - - - self._agent = build_agent(self._on_policy_alg_config, - self._env, device=self.device) - - self._rollout_buffer = MaskableDictRolloutBuffer( - buffer_size=self._agent.alg.n_steps * self._env.num_envs, - observation_space=self._env.observation_space, - action_space=self._env.action_space, - device=self.device, - gamma=self._agent.alg.gamma, - gae_lambda=self._agent.alg.gae_lambda, - n_envs=1, - ) - - self._kl_controller = KLController( - self._on_policy_alg_config["kl_div"]["coeff"], - self._on_policy_alg_config["kl_div"].get("target_kl", None)) - - # extract train params - self._max_episode_length = self._env_config["args"]["max_episode_length"] - self._max_prompt_length = self._env_config["args"]["max_prompt_length"] - self._eval_batch_size = self._train_eval_config["eval_batch_size"] - self._n_iters = int(self._train_eval_config["n_iters"]) - self._n_steps_per_iter = self._env.num_envs * self._agent.alg.n_steps - self._num_timesteps = 0 - - # gen kwargs for evaluation (if it is different from rollout gen kwargs) - self._eval_gen_kwargs = self._train_eval_config.get( - "generation_kwargs", None) - - def _evaluate_on_datapools(self, epoch: int, - splits: List[str] = ["val", "test"]): - for split in splits: - evaluate_on_samples(policy=self._agent.alg.model, - tokenizer=self._tokenizer, - samples=self._samples_by_split[split], - batch_size=self._eval_batch_size, - max_prompt_length=self._max_prompt_length, - metrics=self._metrics, - epoch=epoch, - split_name=split, - gen_kwargs=self._eval_gen_kwargs) - def train_and_eval(self): # evaluate on val and test set before fine-tuning once # iter_start = self._trainer_state["current_iter"] @@ -235,7 +197,7 @@ def train_and_eval(self): self._num_timesteps = 0 while self._num_timesteps < self._n_steps_per_iter: - self.collect_rollouts(self._env, self._rollout_buffer) + self._collect_rollouts(self._env, self._rollout_buffer) # inner rollout and learn loop for on-policy algorithm # self._agent.learn(self._n_steps_per_iter) self._agent.learn(self._rollout_buffer) @@ -265,8 +227,51 @@ def train_and_eval(self): # self._tracker.save_auto_model( # self._alg.policy.get_language_model()) + def _setup(self): + + # load trainer state from available previous checkpoint if available + # self.load_trainer_state(self._tracker) + + # build components + self._tokenizer = build_tokenizer(self._tokenizer_config) + self._reward_fn = build_reward_fn(self._reward_config) + self._metrics = build_metrics( + self._train_eval_config.get("metrics", [])) + self._samples_by_split = build_datapool(self._datapool_config) + self._env = build_env(self._env_config, self._reward_fn, + self._tokenizer, self._samples_by_split["train"]) + + + self._agent = build_agent(self._on_policy_alg_config, + self._env, device=self.device) - def get_policy_kwargs( + self._rollout_buffer = MaskableDictRolloutBuffer( + buffer_size=self._agent.alg.n_steps * self._env.num_envs, + observation_space=self._env.observation_space, + action_space=self._env.action_space, + device=self.device, + gamma=self._agent.alg.gamma, + gae_lambda=self._agent.alg.gae_lambda, + n_envs=1, + ) + + self._kl_controller = KLController( + self._on_policy_alg_config["kl_div"]["coeff"], + self._on_policy_alg_config["kl_div"].get("target_kl", None)) + + # extract train params + self._max_episode_length = self._env_config["args"]["max_episode_length"] + self._max_prompt_length = self._env_config["args"]["max_prompt_length"] + self._eval_batch_size = self._train_eval_config["eval_batch_size"] + self._n_iters = int(self._train_eval_config["n_iters"]) + self._n_steps_per_iter = self._env.num_envs * self._agent.alg.n_steps + self._num_timesteps = 0 + + # gen kwargs for evaluation (if it is different from rollout gen kwargs) + self._eval_gen_kwargs = self._train_eval_config.get( + "generation_kwargs", None) + + def _get_policy_kwargs( self, obs: TensorDict, action: torch.tensor, @@ -283,7 +288,7 @@ def get_policy_kwargs( policy_kwargs["action_masks"] = action_mask return policy_kwargs - def generate_batch( + def _generate_batch( self, rollout_buffer, tokenizer: PreTrainedTokenizer, @@ -331,7 +336,7 @@ def generate_batch( obs_tensor = dict_to_tensor(current_obs, self.device) # get log probs (TBD: generalize this a bit) - policy_kwargs = self.get_policy_kwargs( + policy_kwargs = self._get_policy_kwargs( obs_tensor, actions_tensor, policy_past_state, action_mask ) @@ -432,6 +437,19 @@ def generate_batch( ) return rollout_info + def _evaluate_on_datapools(self, epoch: int, + splits: List[str] = ["val", "test"]): + for split in splits: + evaluate_on_samples(policy=self._agent.alg.model, + tokenizer=self._tokenizer, + samples=self._samples_by_split[split], + batch_size=self._eval_batch_size, + max_prompt_length=self._max_prompt_length, + metrics=self._metrics, + epoch=epoch, + split_name=split, + gen_kwargs=self._eval_gen_kwargs) + def _add_to_buffer( self, rollout_buffer, episode_wise_transitions, rollout_info ): @@ -494,10 +512,10 @@ def _add_to_buffer( rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) return rollout_info - def collect_rollouts( + def _collect_rollouts( self, env, - rollout_buffer: RolloutBuffer, + rollout_buffer: MaskableDictRolloutBuffer, ) -> bool: # max episode steps max_steps = env.get_attr("max_steps", [0])[0] @@ -525,7 +543,7 @@ def collect_rollouts( } while not rollout_buffer.full: # generate batch of rollouts - rollout_info = self.generate_batch( + rollout_info = self._generate_batch( rollout_buffer, tokenizer, max_steps, rollout_info ) diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py index fe791503d..b093afb85 100644 --- a/benchmark/torch/RL4LMs/utils/__init__.py +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -11,8 +11,7 @@ from .sample_util import PrioritySampler -from .buffer import DictRolloutBuffer, RolloutBuffer,\ - MaskableDictRolloutBuffer, MaskableRolloutBuffer +from .buffer import MaskableDictRolloutBuffer from .kl_controller import KLController diff --git a/benchmark/torch/RL4LMs/utils/buffer.py b/benchmark/torch/RL4LMs/utils/buffer.py index 380dc1435..b0a99a7f5 100644 --- a/benchmark/torch/RL4LMs/utils/buffer.py +++ b/benchmark/torch/RL4LMs/utils/buffer.py @@ -6,8 +6,7 @@ import torch from gym import spaces -from .data_wrapper import RolloutBufferSamples, DictRolloutBufferSamples,\ - MaskableRolloutBufferSamples, MaskableDictRolloutBufferSamples +from .data_wrapper import MaskableDictRolloutBufferSamples try: # Check memory used by replay buffer when possible @@ -65,307 +64,7 @@ def get_obs_shape( raise NotImplementedError(f"{observation_space} observation space is not supported") -class BaseBuffer(ABC): - """ - Base class that represent a buffer (rollout or replay) - - :param buffer_size: Max number of element in the buffer - :param observation_space: Observation space - :param action_space: Action space - :param device: PyTorch device - to which the values will be converted - :param n_envs: Number of parallel environments - """ - - def __init__( - self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[torch.device, str] = "cpu", - n_envs: int = 1, - ): - super().__init__() - self.buffer_size = buffer_size - self.observation_space = observation_space - self.action_space = action_space - self.obs_shape = get_obs_shape(observation_space) - - self.action_dim = get_action_dim(action_space) - self.pos = 0 - self.full = False - self.device = device - self.n_envs = n_envs - - @staticmethod - def swap_and_flatten(arr: np.ndarray) -> np.ndarray: - """ - Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) - to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) - to [n_steps * n_envs, ...] (which maintain the order) - - :param arr: - :return: - """ - shape = arr.shape - if len(shape) < 3: - shape = shape + (1,) - return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) - - def size(self) -> int: - """ - :return: The current size of the buffer - """ - if self.full: - return self.buffer_size - return self.pos - - def add(self, *args, **kwargs) -> None: - """ - Add elements to the buffer. - """ - raise NotImplementedError() - - def extend(self, *args, **kwargs) -> None: - """ - Add a new batch of transitions to the buffer - """ - # Do a for loop along the batch axis - for data in zip(*args): - self.add(*data) - - def reset(self) -> None: - """ - Reset the buffer. - """ - self.pos = 0 - self.full = False - - def sample(self, batch_size: int, env = None): - """ - :param batch_size: Number of element to sample - :param env: associated gym VecEnv - to normalize the observations/rewards when sampling - :return: - """ - upper_bound = self.buffer_size if self.full else self.pos - batch_inds = np.random.randint(0, upper_bound, size=batch_size) - return self._get_samples(batch_inds, env=env) - - @abstractmethod - def _get_samples( - self, batch_inds: np.ndarray, env = None - ) -> RolloutBufferSamples: - """ - :param batch_inds: - :param env: - :return: - """ - raise NotImplementedError() - - def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: - """ - Convert a numpy array to a PyTorch tensor. - Note: it copies the data by default - - :param array: - :param copy: Whether to copy or not the data - (may be useful to avoid changing things be reference) - :return: - """ - if copy: - return torch.tensor(array).to(self.device) - return torch.as_tensor(array).to(self.device) - - @staticmethod - def _normalize_obs( - obs: Union[np.ndarray, Dict[str, np.ndarray]], - env = None, - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - if env is not None: - return env.normalize_obs(obs) - return obs - - @staticmethod - def _normalize_reward(reward: np.ndarray, env = None) -> np.ndarray: - if env is not None: - return env.normalize_reward(reward).astype(np.float32) - return reward - - - -class RolloutBuffer(BaseBuffer): - """ - Rollout buffer used in on-policy algorithms like A2C/PPO. - It corresponds to ``buffer_size`` transitions collected - using the current policy. - This experience will be discarded after the policy update. - In order to use PPO objective, we also store the current value of each state - and the log probability of each taken action. - - The term rollout here refers to the model-free notion and should not - be used with the concept of rollout used in model-based RL or planning. - Hence, it is only involved in policy and value function training but not action selection. - - :param buffer_size: Max number of element in the buffer - :param observation_space: Observation space - :param action_space: Action space - :param device: - :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - Equivalent to classic advantage when set to 1. - :param gamma: Discount factor - :param n_envs: Number of parallel environments - """ - - def __init__( - self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[torch.device, str] = "cpu", - gae_lambda: float = 1, - gamma: float = 0.99, - n_envs: int = 1, - ): - - super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) - self.gae_lambda = gae_lambda - self.gamma = gamma - self.observations, self.actions, self.rewards, self.advantages = None, None, None, None - self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None - self.generator_ready = False - self.reset() - - def reset(self) -> None: - - self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32) - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) - self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.generator_ready = False - super().reset() - - def compute_returns_and_advantage(self, last_values: torch.Tensor, dones: np.ndarray) -> None: - """ - Post-processing step: compute the lambda-return (TD(lambda) estimate) - and GAE(lambda) advantage. - - Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) - to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S)) - where R is the sum of discounted reward with value bootstrap - (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization. - - The TD(lambda) estimator has also two special cases: - - TD(1) is Monte-Carlo estimate (sum of discounted rewards) - - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1})) - - For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. - - :param last_values: state value estimation for the last step (one for each env) - :param dones: if the last step was a terminal step (one bool for each env). - """ - # Convert to numpy - last_values = last_values.clone().cpu().numpy().flatten() - - last_gae_lam = 0 - for step in reversed(range(self.buffer_size)): - if step == self.buffer_size - 1: - next_non_terminal = 1.0 - dones - next_values = last_values - else: - next_non_terminal = 1.0 - self.episode_starts[step + 1] - next_values = self.values[step + 1] - delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] - last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam - self.advantages[step] = last_gae_lam - # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" - # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA - self.returns = self.advantages + self.values - - def add( - self, - obs: np.ndarray, - action: np.ndarray, - reward: np.ndarray, - episode_start: np.ndarray, - value: torch.Tensor, - log_prob: torch.Tensor, - ) -> None: - """ - :param obs: Observation - :param action: Action - :param reward: - :param episode_start: Start of episode signal. - :param value: estimated value of the current state - following the current policy. - :param log_prob: log probability of the action - following the current policy. - """ - if len(log_prob.shape) == 0: - # Reshape 0-d tensor to avoid error - log_prob = log_prob.reshape(-1, 1) - - # Reshape needed when using multiple envs with discrete observations - # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) - if isinstance(self.observation_space, spaces.Discrete): - obs = obs.reshape((self.n_envs,) + self.obs_shape) - - self.observations[self.pos] = np.array(obs).copy() - self.actions[self.pos] = np.array(action).copy() - self.rewards[self.pos] = np.array(reward).copy() - self.episode_starts[self.pos] = np.array(episode_start).copy() - self.values[self.pos] = value.clone().cpu().numpy().flatten() - self.log_probs[self.pos] = log_prob.clone().cpu().numpy() - self.pos += 1 - if self.pos == self.buffer_size: - self.full = True - - def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: - assert self.full, "" - indices = np.random.permutation(self.buffer_size * self.n_envs) - # Prepare the data - if not self.generator_ready: - - _tensor_names = [ - "observations", - "actions", - "values", - "log_probs", - "advantages", - "returns", - ] - - for tensor in _tensor_names: - self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) - self.generator_ready = True - - # Return everything, don't create minibatches - if batch_size is None: - batch_size = self.buffer_size * self.n_envs - - start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: - yield self._get_samples(indices[start_idx : start_idx + batch_size]) - start_idx += batch_size - - def _get_samples(self, batch_inds: np.ndarray, env = None) -> RolloutBufferSamples: - data = ( - self.observations[batch_inds], - self.actions[batch_inds], - self.values[batch_inds].flatten(), - self.log_probs[batch_inds].flatten(), - self.advantages[batch_inds].flatten(), - self.returns[batch_inds].flatten(), - ) - return RolloutBufferSamples(*tuple(map(self.to_torch, data))) - - - -class DictRolloutBuffer(RolloutBuffer): +class MaskableDictRolloutBuffer: """ Dict Rollout buffer used in on-policy algorithms like A2C/PPO. Extends the RolloutBuffer to use dictionary observations @@ -385,7 +84,7 @@ class DictRolloutBuffer(RolloutBuffer): :param action_space: Action space :param device: :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - Equivalent to Monte-Carlo advantage estimate when set to 1. + Equivalent to classic advantage when set to 1. :param gamma: Discount factor :param n_envs: Number of parallel environments """ @@ -400,8 +99,17 @@ def __init__( gamma: float = 0.99, n_envs: int = 1, ): + self.action_masks = None + self.buffer_size = buffer_size + self.observation_space = observation_space + self.action_space = action_space + self.obs_shape = get_obs_shape(observation_space) - super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + self.action_dim = get_action_dim(action_space) + self.pos = 0 + self.full = False + self.device = device + self.n_envs = n_envs assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" @@ -413,6 +121,20 @@ def __init__( self.reset() def reset(self) -> None: + if isinstance(self.action_space, spaces.Discrete): + mask_dims = self.action_space.n + elif isinstance(self.action_space, spaces.MultiDiscrete): + mask_dims = sum(self.action_space.nvec) + elif isinstance(self.action_space, spaces.MultiBinary): + mask_dims = 2 * self.action_space.n # One mask per binary outcome + else: + raise ValueError( + f"Unsupported action space {type(self.action_space)}") + + self.mask_dims = mask_dims + self.action_masks = np.ones( + (self.buffer_size, self.n_envs, self.mask_dims)) # .to(self.device) + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" self.observations = {} for key, obs_input_shape in self.obs_shape.items(): @@ -425,17 +147,18 @@ def reset(self) -> None: self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.generator_ready = False - super(RolloutBuffer, self).reset() - def add( - self, - obs: Dict[str, np.ndarray], - action: np.ndarray, - reward: np.ndarray, - episode_start: np.ndarray, - value: torch.Tensor, - log_prob: torch.Tensor, - ) -> None: + self.pos = 0 + self.full = False + + def add(self, + obs: Dict[str, np.ndarray], + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: torch.Tensor, + log_prob: torch.Tensor, + action_masks: Optional[torch.Tensor] = None) -> None: """ :param obs: Observation :param action: Action @@ -445,7 +168,12 @@ def add( following the current policy. :param log_prob: log probability of the action following the current policy. + :param action_masks: Masks applied to constrain the choice of possible actions. """ + if action_masks is not None: + self.action_masks[self.pos] = action_masks.reshape( + (self.n_envs, self.mask_dims)) + if len(log_prob.shape) == 0: # Reshape 0-d tensor to avoid error log_prob = log_prob.reshape(-1, 1) @@ -467,195 +195,56 @@ def add( if self.pos == self.buffer_size: self.full = True - def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: - assert self.full, "" - indices = np.random.permutation(self.buffer_size * self.n_envs) - # Prepare the data - if not self.generator_ready: - - for key, obs in self.observations.items(): - self.observations[key] = self.swap_and_flatten(obs) - - _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] - - for tensor in _tensor_names: - self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) - self.generator_ready = True - - # Return everything, don't create minibatches - if batch_size is None: - batch_size = self.buffer_size * self.n_envs - - start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: - yield self._get_samples(indices[start_idx : start_idx + batch_size]) - start_idx += batch_size - - def _get_samples(self, batch_inds: np.ndarray, env = None) -> DictRolloutBufferSamples: - - return DictRolloutBufferSamples( - observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, - actions=self.to_torch(self.actions[batch_inds]), - old_values=self.to_torch(self.values[batch_inds].flatten()), - old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), - advantages=self.to_torch(self.advantages[batch_inds].flatten()), - returns=self.to_torch(self.returns[batch_inds].flatten()), - ) - - -class MaskableRolloutBuffer(RolloutBuffer): - """ - Rollout buffer that also stores the invalid action masks associated with each observation. - - :param buffer_size: Max number of element in the buffer - :param observation_space: Observation space - :param action_space: Action space - :param device: - :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - Equivalent to classic advantage when set to 1. - :param gamma: Discount factor - :param n_envs: Number of parallel environments - """ - - def __init__(self, *args, **kwargs): - self.action_masks = None - super().__init__(*args, **kwargs) - - def reset(self) -> None: - if isinstance(self.action_space, spaces.Discrete): - mask_dims = self.action_space.n - elif isinstance(self.action_space, spaces.MultiDiscrete): - mask_dims = sum(self.action_space.nvec) - elif isinstance(self.action_space, spaces.MultiBinary): - mask_dims = 2 * self.action_space.n # One mask per binary outcome - else: - raise ValueError( - f"Unsupported action space {type(self.action_space)}") - - self.mask_dims = mask_dims - self.action_masks = np.ones( - (self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) - - super().reset() - - def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: - """ - :param action_masks: Masks applied to constrain the choice of possible actions. + def compute_returns_and_advantage(self, last_values: torch.Tensor, dones: np.ndarray) -> None: """ - if action_masks is not None: - self.action_masks[self.pos] = action_masks.reshape( - (self.n_envs, self.mask_dims)) - - super().add(*args, **kwargs) - - def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: - assert self.full, "" - indices = np.random.permutation(self.buffer_size * self.n_envs) - # Prepare the data - if not self.generator_ready: - for tensor in [ - "observations", - "actions", - "values", - "log_probs", - "advantages", - "returns", - "action_masks", - ]: - self.__dict__[tensor] = self.swap_and_flatten( - self.__dict__[tensor]) - self.generator_ready = True - - # Return everything, don't create minibatches - if batch_size is None: - batch_size = self.buffer_size * self.n_envs - - start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: - yield self._get_samples(indices[start_idx: start_idx + batch_size]) - start_idx += batch_size - - def _get_samples(self, batch_inds: np.ndarray, env = None) -> MaskableRolloutBufferSamples: - data = ( - self.observations[batch_inds], - self.actions[batch_inds], - self.values[batch_inds].flatten(), - self.log_probs[batch_inds].flatten(), - self.advantages[batch_inds].flatten(), - self.returns[batch_inds].flatten(), - self.action_masks[batch_inds].reshape(-1, self.mask_dims), - ) - return MaskableRolloutBufferSamples(*map(self.to_torch, data)) - - - - - -class MaskableDictRolloutBuffer(DictRolloutBuffer): - """ - Dict Rollout buffer used in on-policy algorithms like A2C/PPO. - Extends the RolloutBuffer to use dictionary observations - - It corresponds to ``buffer_size`` transitions collected - using the current policy. - This experience will be discarded after the policy update. - In order to use PPO objective, we also store the current value of each state - and the log probability of each taken action. - - The term rollout here refers to the model-free notion and should not - be used with the concept of rollout used in model-based RL or planning. - Hence, it is only involved in policy and value function training but not action selection. + Post-processing step: compute the lambda-return (TD(lambda) estimate) + and GAE(lambda) advantage. - :param buffer_size: Max number of element in the buffer - :param observation_space: Observation space - :param action_space: Action space - :param device: - :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - Equivalent to classic advantage when set to 1. - :param gamma: Discount factor - :param n_envs: Number of parallel environments - """ + Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S)) + where R is the sum of discounted reward with value bootstrap + (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization. - def __init__( - self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[torch.device, str] = "cpu", - gae_lambda: float = 1, - gamma: float = 0.99, - n_envs: int = 1, - ): - self.action_masks = None - super().__init__(buffer_size, observation_space, - action_space, device, gae_lambda, gamma, n_envs=n_envs) + The TD(lambda) estimator has also two special cases: + - TD(1) is Monte-Carlo estimate (sum of discounted rewards) + - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1})) - def reset(self) -> None: - if isinstance(self.action_space, spaces.Discrete): - mask_dims = self.action_space.n - elif isinstance(self.action_space, spaces.MultiDiscrete): - mask_dims = sum(self.action_space.nvec) - elif isinstance(self.action_space, spaces.MultiBinary): - mask_dims = 2 * self.action_space.n # One mask per binary outcome - else: - raise ValueError( - f"Unsupported action space {type(self.action_space)}") + For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. - self.mask_dims = mask_dims - self.action_masks = np.ones( - (self.buffer_size, self.n_envs, self.mask_dims)) # .to(self.device) + :param last_values: state value estimation for the last step (one for each env) + :param dones: if the last step was a terminal step (one bool for each env). + """ + # Convert to numpy + last_values = last_values.clone().cpu().numpy().flatten() - super().reset() + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_values = last_values + else: + next_non_terminal = 1.0 - self.episode_starts[step + 1] + next_values = self.values[step + 1] + delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] + last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + self.advantages[step] = last_gae_lam + # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" + # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA + self.returns = self.advantages + self.values - def add(self, *args, action_masks: Optional[torch.Tensor] = None, **kwargs) -> None: - """ - :param action_masks: Masks applied to constrain the choice of possible actions. + def swap_and_flatten(self, arr: np.ndarray) -> np.ndarray: """ - if action_masks is not None: - self.action_masks[self.pos] = action_masks.reshape( - (self.n_envs, self.mask_dims)) + Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) + to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) + to [n_steps * n_envs, ...] (which maintain the order) - super().add(*args, **kwargs) + :param arr: + :return: + """ + shape = arr.shape + if len(shape) < 3: + shape = shape + (1,) + return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: assert self.full, "" @@ -683,6 +272,20 @@ def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRollout yield self._get_samples(indices[start_idx: start_idx + batch_size]) start_idx += batch_size + def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: + """ + Convert a numpy array to a PyTorch tensor. + Note: it copies the data by default + + :param array: + :param copy: Whether to copy or not the data + (may be useful to avoid changing things be reference) + :return: + """ + if copy: + return torch.tensor(array).to(self.device) + return torch.as_tensor(array).to(self.device) + def _get_samples(self, batch_inds: np.ndarray, env = None) -> MaskableDictRolloutBufferSamples: return MaskableDictRolloutBufferSamples( diff --git a/benchmark/torch/RL4LMs/utils/data_wrapper.py b/benchmark/torch/RL4LMs/utils/data_wrapper.py index a85cd291c..234b42c5a 100644 --- a/benchmark/torch/RL4LMs/utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/utils/data_wrapper.py @@ -26,16 +26,7 @@ class TransitionInfo: info: Dict[str, Any] -class MaskableRolloutBufferSamples(NamedTuple): - observations: torch.Tensor - actions: torch.Tensor - old_values: torch.Tensor - old_log_prob: torch.Tensor - advantages: torch.Tensor - returns: torch.Tensor - action_masks: torch.Tensor - -class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples): +class MaskableDictRolloutBufferSamples(NamedTuple): observations: TensorDict actions: torch.Tensor old_values: torch.Tensor @@ -45,24 +36,6 @@ class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples): action_masks: torch.Tensor -class RolloutBufferSamples(NamedTuple): - observations: torch.Tensor - actions: torch.Tensor - old_values: torch.Tensor - old_log_prob: torch.Tensor - advantages: torch.Tensor - returns: torch.Tensor - - -class DictRolloutBufferSamples(RolloutBufferSamples): - observations: TensorDict - actions: torch.Tensor - old_values: torch.Tensor - old_log_prob: torch.Tensor - advantages: torch.Tensor - returns: torch.Tensor - - @dataclass(init=True) class Sample: id: str From 89c4efb09130d8327040aa0fcd7deb001bca1d04 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Tue, 7 Mar 2023 13:01:22 +0800 Subject: [PATCH 06/34] simplified code v0.0 --- benchmark/torch/RL4LMs/README.md | 4 +- benchmark/torch/RL4LMs/agents/__init__.py | 1 - benchmark/torch/RL4LMs/algorithms/__init__.py | 1 - benchmark/torch/RL4LMs/env/__init__.py | 2 +- benchmark/torch/RL4LMs/env/text_gen_env.py | 50 +- benchmark/torch/RL4LMs/env/vec_env.py | 58 +- benchmark/torch/RL4LMs/metrics/__init__.py | 16 - benchmark/torch/RL4LMs/metrics/metric_util.py | 644 ------------------ benchmark/torch/RL4LMs/models/__init__.py | 1 - benchmark/torch/RL4LMs/registry.py | 170 ----- .../{algorithms/ppo.py => rl4lm_ppo.py} | 5 +- .../torch/RL4LMs/{agents => }/rl4lms_agent.py | 5 +- .../RL4LMs/{models => }/seq2seq_model.py | 99 ++- .../{configs/summarization => }/t5_ppo.yml | 3 - benchmark/torch/RL4LMs/train.py | 134 +++- benchmark/torch/RL4LMs/trainers.py | 566 --------------- benchmark/torch/RL4LMs/utils/__init__.py | 13 +- benchmark/torch/RL4LMs/utils/buffer.py | 96 +-- .../RL4LMs/utils/component_build_util.py | 52 ++ benchmark/torch/RL4LMs/utils/data_pool.py | 108 +-- benchmark/torch/RL4LMs/utils/data_wrapper.py | 11 +- .../RL4LMs/utils/distribution_wrapper.py | 18 +- .../torch/RL4LMs/utils/evaluation_util.py | 32 +- .../utils/huggingface_generation_util.py | 291 +------- benchmark/torch/RL4LMs/utils/kl_controller.py | 2 +- benchmark/torch/RL4LMs/utils/metric_util.py | 177 +++++ benchmark/torch/RL4LMs/utils/reward_util.py | 428 +----------- benchmark/torch/RL4LMs/utils/rollout_util.py | 281 ++++++++ benchmark/torch/RL4LMs/utils/sample_util.py | 10 +- benchmark/torch/RL4LMs/utils/type_wrapper.py | 7 - 30 files changed, 845 insertions(+), 2440 deletions(-) delete mode 100644 benchmark/torch/RL4LMs/agents/__init__.py delete mode 100644 benchmark/torch/RL4LMs/algorithms/__init__.py delete mode 100644 benchmark/torch/RL4LMs/metrics/__init__.py delete mode 100644 benchmark/torch/RL4LMs/metrics/metric_util.py delete mode 100644 benchmark/torch/RL4LMs/models/__init__.py delete mode 100644 benchmark/torch/RL4LMs/registry.py rename benchmark/torch/RL4LMs/{algorithms/ppo.py => rl4lm_ppo.py} (97%) rename benchmark/torch/RL4LMs/{agents => }/rl4lms_agent.py (96%) rename benchmark/torch/RL4LMs/{models => }/seq2seq_model.py (87%) rename benchmark/torch/RL4LMs/{configs/summarization => }/t5_ppo.yml (97%) delete mode 100644 benchmark/torch/RL4LMs/trainers.py create mode 100644 benchmark/torch/RL4LMs/utils/component_build_util.py create mode 100644 benchmark/torch/RL4LMs/utils/metric_util.py create mode 100644 benchmark/torch/RL4LMs/utils/rollout_util.py delete mode 100644 benchmark/torch/RL4LMs/utils/type_wrapper.py diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md index 2112925ab..172025756 100644 --- a/benchmark/torch/RL4LMs/README.md +++ b/benchmark/torch/RL4LMs/README.md @@ -10,10 +10,10 @@ ### Main contribution - Change from **\{ trainer: \{ ppo: \{ env, rollout_buffer, policy/model \} \} \}** to - **\{trainer: \{env, rollout_buffer, agent: \{ ppo: \{ model \} \} \} \}** according PARL architecture. + **\{trainer: \{env, rollout_buffer, agent: \{ ppo: \{ model \} \} \} \}** according to PARL architecture. ### Running command ```bash -python train.py --config_path configs/summarization/t5_ppo.yml +python train.py --config_path t5_ppo.yml ``` \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/agents/__init__.py b/benchmark/torch/RL4LMs/agents/__init__.py deleted file mode 100644 index 72f8da7f4..000000000 --- a/benchmark/torch/RL4LMs/agents/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .rl4lms_agent import RL4LMsAgent \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/algorithms/__init__.py b/benchmark/torch/RL4LMs/algorithms/__init__.py deleted file mode 100644 index 8bacfd707..000000000 --- a/benchmark/torch/RL4LMs/algorithms/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .ppo import RL4LMPPO \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/env/__init__.py b/benchmark/torch/RL4LMs/env/__init__.py index 39f83816f..09abf2026 100644 --- a/benchmark/torch/RL4LMs/env/__init__.py +++ b/benchmark/torch/RL4LMs/env/__init__.py @@ -1,2 +1,2 @@ from .text_gen_env import TextGenEnv -from .vec_env import LocalParallelVecEnv, make_vec_env \ No newline at end of file +from .vec_env import make_vec_env \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/env/text_gen_env.py b/benchmark/torch/RL4LMs/env/text_gen_env.py index faf9eafb1..7f9d2c9b7 100644 --- a/benchmark/torch/RL4LMs/env/text_gen_env.py +++ b/benchmark/torch/RL4LMs/env/text_gen_env.py @@ -1,27 +1,25 @@ from cmath import inf from typing import Dict, Tuple, Optional, List -import torch from gym import Env, spaces from gym.spaces.dict import Dict as DictSpace from gym.spaces.discrete import Discrete from benchmark.torch.RL4LMs.utils import Sample, Observation, PrioritySampler -from benchmark.torch.RL4LMs.utils import RewardFunction, BatchedRewardFunction from transformers import AutoTokenizer class TextGenEnv(Env): def __init__( self, - tokenizer: AutoTokenizer, - reward_function: RewardFunction, - samples: Tuple[List[Sample], float], - max_episode_length: int = 512, - priority_scale: float = 0.0, - max_prompt_length: Optional[int] = None, - terminate_on_eos: bool = False, - context_start_token: Optional[int] = None, - prompt_truncation_side: str = "left", + tokenizer, + reward_function, + samples, + max_episode_length = 512, + priority_scale = 0.0, + max_prompt_length = None, + terminate_on_eos = False, + context_start_token = None, + prompt_truncation_side = "left", ): """ A generic RL environment to generate textual sequences. @@ -99,7 +97,7 @@ def __init__( self.__current_obs = None self.__time_step = None - def step(self, action: int) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: + def step(self, action): self.__time_step += 1 # previous obs @@ -114,20 +112,18 @@ def step(self, action: int) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: ) # compute reward - if not isinstance(self.reward_function, BatchedRewardFunction): - reward = ( - None - if self.reward_function is None - else self.reward_function( - previous_obs, - action, - self.__current_obs, - done, - self.__current_obs.meta_info, - ) + reward = ( + None + if self.reward_function is None + else self.reward_function( + previous_obs, + action, + self.__current_obs, + done, + self.__current_obs.meta_info, ) - else: - reward = -inf # will be overridden later + ) + # populate additional info info = { @@ -141,7 +137,7 @@ def step(self, action: int) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: return self.__current_obs.to_dict(), reward, done, info - def reset(self, sample: Sample = None) -> Dict[str, torch.tensor]: + def reset(self, sample = None): """ Resets the environment and starts a new episode """ @@ -173,5 +169,5 @@ def render(self): def close(self): pass - def add_sample(self, sample: Sample, weight: int = 1.0): + def add_sample(self, sample, weight = 1.0): self.sampler_for_replaying.add(sample, weight) diff --git a/benchmark/torch/RL4LMs/env/vec_env.py b/benchmark/torch/RL4LMs/env/vec_env.py index 717a80fd9..e4f5e4136 100644 --- a/benchmark/torch/RL4LMs/env/vec_env.py +++ b/benchmark/torch/RL4LMs/env/vec_env.py @@ -13,10 +13,9 @@ def __init__(self, var): def __getstate__(self): return cloudpickle.dumps(self.var) - def __setstate__(self, var) -> None: + def __setstate__(self, var): self.var = cloudpickle.loads(var) - def _flatten_obs(obs, space: gym.spaces.Space): assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" assert len(obs) > 0, "need observations from at least one environment" @@ -32,10 +31,9 @@ def _flatten_obs(obs, space: gym.spaces.Space): else: return np.stack(obs) - def _worker( remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper -) -> None: +): # Import here to avoid a circular import parent_remote.close() @@ -75,27 +73,6 @@ def _worker( except EOFError: break - -def make_vec_env( - env_id: Union[str, Type[gym.Env]], - vec_env_cls, - n_envs: int = 1, - seed: Optional[int] = None, - start_index: int = 0, - env_kwargs: Optional[Dict[str, Any]] = None, -): - def make_env(rank): - def _init(): - env = env_id(**env_kwargs) - if seed is not None: - env.seed(seed + rank) - env.action_space.seed(seed + rank) - return env - return _init - - return vec_env_cls([make_env(i + start_index) for i in range(n_envs)]) - - class LocalParallelVecEnv: def __init__(self, env_fns, start_method = None): @@ -127,7 +104,7 @@ def __init__(self, env_fns, start_method = None): self.observation_space = observation_space self.action_space = action_space - def step_async(self, actions: np.ndarray) -> None: + def step_async(self, actions): for remote, action in zip(self.remotes, actions): remote.send(("step", action)) self.waiting = True @@ -207,4 +184,31 @@ def step(self, actions: np.ndarray): :return: observation, reward, done, information """ self.step_async(actions) - return self.step_wait() \ No newline at end of file + return self.step_wait() + +def make_vec_env( + env_id: Union[str, Type[gym.Env]], + seed: Optional[int] = None, + start_index: int = 0, + env_config = None, + reward_fn = None, + tokenizer = None, + train_samples = None +): + n_envs = env_config["n_envs"] + env_kwargs = { + "reward_function": reward_fn, + "tokenizer": tokenizer, + "samples": train_samples, + } + env_kwargs = {**env_kwargs, **env_config.get("args", {})} + def make_env(rank): + def _init(): + env = env_id(**env_kwargs) + if seed is not None: + env.seed(seed + rank) + env.action_space.seed(seed + rank) + return env + return _init + + return LocalParallelVecEnv([make_env(i + start_index) for i in range(n_envs)]) diff --git a/benchmark/torch/RL4LMs/metrics/__init__.py b/benchmark/torch/RL4LMs/metrics/__init__.py deleted file mode 100644 index 30fe430fa..000000000 --- a/benchmark/torch/RL4LMs/metrics/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from.metric_util import ( - BaseMetric, - BERTScoreMetric, - BLEUMetric, - BLEURTMetric, - BLEUToTTo, - DiversityMetrics, - LearnedRewardMetric, - MeteorMetric, - Perplexity, - RougeLMax, - RougeMetric, - SacreBLEUMetric, - TERMetric, - chrFmetric, -) \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/metrics/metric_util.py b/benchmark/torch/RL4LMs/metrics/metric_util.py deleted file mode 100644 index e06c4aace..000000000 --- a/benchmark/torch/RL4LMs/metrics/metric_util.py +++ /dev/null @@ -1,644 +0,0 @@ -from transformers import AutoModelForSequenceClassification, AutoTokenizer -from transformers import PreTrainedModel -import torch -from typing import List, Dict, Tuple, Any -from abc import abstractmethod -import numpy as np -from datasets import load_metric -from gem_metrics.msttr import MSTTR -from gem_metrics.ngrams import NGramStats -from gem_metrics.texts import Predictions -from tqdm import tqdm -import copy -import rouge -import json -from tempfile import TemporaryDirectory -import subprocess -import os -import jsonlines - -# Cider, Spice, SummaCConv, SummaCZS, compute_parent, - - -def compute_bleu(predicted_texts: List[str], - raw_tables: List[dict]): - - def _read_results(path): - try: - with open(path) as fp: - score = json.load(fp)["score"]/100 - except: - score = 0.0 - return score - - with TemporaryDirectory() as temp_dir: - - # write tables - target_path = os.path.join(temp_dir, "samples.jsonl") - with jsonlines.open(target_path, "w") as writer: - for table in raw_tables: - writer.write(table) - - # write gen texts - prediction_path = os.path.join(temp_dir, "predictions.txt") - with open(prediction_path, "w") as fp: - predicted_texts = '\n'.join(predicted_texts) - fp.write(predicted_texts) - - cmd = ['bash', 'totto_bleu_eval.sh', - '-p', prediction_path, - '-t', target_path, - '--output_dir', temp_dir, - ] - subprocess.check_call(cmd, - cwd=os.path.dirname(os.path.abspath(__file__)), - stdout=subprocess.DEVNULL) - - # read the results back - bleu_overall = _read_results( - os.path.join(temp_dir, "bleu_overall.json")) - bleu_overlap = _read_results( - os.path.join(temp_dir, "bleu_overlap.json")) - bleu_non_overlap = _read_results( - os.path.join(temp_dir, "bleu_non_overlap.json")) - return bleu_overall, bleu_overlap, bleu_non_overlap - - - - -class BaseMetric: - @abstractmethod - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ): - """ - Returns a dict where key is the metric name and value is again a dict consisting of tuple of individual scores (if any) and corpus level score - - eg. { - metric_name: (individual_scores, corpus_level_score) - "metric_1": ([0.5, 0.5, 0.8], 0.1) - } - - """ - raise NotImplementedError - - -class LearnedRewardMetric(BaseMetric): - def __init__( - self, - model_name: str, - label_ix: int, - batch_size: int, - include_prompt_for_eval: bool = True, - ) -> None: - super().__init__() - self._device = "cuda" if torch.cuda.is_available() else "cpu" - self._tokenizer = AutoTokenizer.from_pretrained(model_name) - self._tokenizer.truncation_side = "left" - self._model = AutoModelForSequenceClassification.from_pretrained(model_name).to( - self._device - ) - self._label_ix = label_ix - self._batch_size = batch_size - self._include_prompt_for_eval = include_prompt_for_eval - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Dict[str, float]: - all_scores = [] - current_ix = 0 - n_texts = len(generated_texts) - while current_ix < n_texts: - batch_gen_texts = generated_texts[ - current_ix : current_ix + self._batch_size - ] - batch_prompt_texts = prompt_texts[ - current_ix : current_ix + self._batch_size - ] - - if self._include_prompt_for_eval: - batch_gen_texts = [ - (prompt + gen) - for gen, prompt in zip(batch_gen_texts, batch_prompt_texts) - ] - encoded = self._tokenizer( - batch_gen_texts, return_tensors="pt", truncation=True, padding=True - ) - with torch.no_grad(): - outputs = self._model( - input_ids=encoded.input_ids.to(self._device), - attention_mask=encoded.attention_mask.to(self._device), - ) - scores = torch.softmax(outputs.logits, dim=1) - scores = scores[:, self._label_ix].tolist() - all_scores.extend(scores) - current_ix += self._batch_size - - metric_dict = { - "semantic/learned_automodel_metric": (all_scores, np.mean(all_scores)) - } - return metric_dict - - -class MeteorMetric(BaseMetric): - def __init__(self) -> None: - super().__init__() - self._metric = load_metric("meteor") - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ): - - score = self._metric.compute( - predictions=generated_texts, references=reference_texts - )["meteor"] - - metric_dict = {"lexical/meteor": (None, score)} - return metric_dict - - -class RougeMetric(BaseMetric): - def __init__(self, use_single_ref: bool = True) -> None: - super().__init__() - self._metric = load_metric("rouge") - self._use_single_ref = use_single_ref - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ): - if self._use_single_ref: - # TBD: this is required for CNN/DM dataset, without this we get low scores - # TBD: needs investigation - ref_texts = [ref[0] for ref in reference_texts] - else: - ref_texts = reference_texts - - metric_results = self._metric.compute( - predictions=generated_texts, references=ref_texts, use_stemmer=True - ) - score_keys = ["rouge1", "rouge2", "rougeL", "rougeLsum"] - metric_dict = {} - for rouge_type in score_keys: - rouge_score = metric_results[rouge_type].mid.fmeasure - metric_dict[f"lexical/rouge_{rouge_type}"] = (None, rouge_score) - return metric_dict - - -class BERTScoreMetric(BaseMetric): - def __init__(self, language: str) -> None: - super().__init__() - self._metric = load_metric("bertscore") - self._language = language - # since models are loaded heavily on cuda:0, use the last one to avoid memory - self._last_gpu = f"cuda:{torch.cuda.device_count() - 1}" - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - with torch.no_grad(): - metric_results = self._metric.compute( - predictions=generated_texts, - references=reference_texts, - lang=self._language, - device=self._last_gpu, - ) - bert_scores = metric_results["f1"] - corpus_level_score = np.mean(bert_scores) - metric_dict = {"semantic/bert_score": (bert_scores, corpus_level_score)} - return metric_dict - - -class BLEUMetric(BaseMetric): - def __init__(self) -> None: - super().__init__() - self._metric = load_metric("bleu") - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - - tokenized_predictions = [] - tokenized_reference_texts = [] - for prediction, refs in zip(generated_texts, reference_texts): - tokenized_prediction = prediction.split() - tokenized_refs = [ref.split() for ref in refs] - tokenized_predictions.append(tokenized_prediction) - tokenized_reference_texts.append(tokenized_refs) - - try: - metric_results = self._metric.compute( - predictions=tokenized_predictions, references=tokenized_reference_texts - ) - bleu_score = metric_results["bleu"] - metric_dict = {"lexical/bleu": (None, bleu_score)} - return metric_dict - except Exception as e: - return {"lexical/bleu": (None, "n/a")} - - -class BLEURTMetric(BaseMetric): - def __init__(self, config_name: str = None) -> None: - super().__init__() - self._metric = load_metric("bleurt", config_name=config_name) - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - metric_results = self._metric.compute( - predictions=generated_texts, references=reference_texts - ) - corpus_score = np.mean(metric_results["scores"]) - metric_dict = {"semantic/bleurt": (metric_results["scores"], corpus_score)} - return metric_dict - - -def get_generated_and_predictions( - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - split_name: str, -): - split_name = "" if split_name is None else split_name - preds = {} - refs = {} - for ix, (prompt_text, gen_text, ref_text) in enumerate( - zip(prompt_texts, generated_texts, reference_texts) - ): - preds[split_name + prompt_text] = [gen_text] - refs[split_name + prompt_text] = ref_text - return preds, refs - - -def get_individual_scores( - prompt_texts: List[str], split_name: str, scores_dict: Dict[str, float] -): - split_name = "" if split_name is None else split_name - scores = [] - for prompt_text in prompt_texts: - scores.append(scores_dict.get(split_name + prompt_text, "n/a")) - return scores - - - - -class DiversityMetrics(BaseMetric): - def __init__(self, window_size: int = 100) -> None: - self._msttr_metric = MSTTR(window_size=window_size) - self._n_gram_metric = NGramStats() - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - - predictions = Predictions(data={"filename": "", "values": generated_texts}) - diversity_metrics = {} - msttr_metrics = self._msttr_metric.compute(None, predictions) - n_gram_metrics = self._n_gram_metric.compute(None, predictions) - - for key, value in msttr_metrics.items(): - diversity_metrics[f"diversity_metrics/{key}"] = (None, value) - for key, value in n_gram_metrics.items(): - diversity_metrics[f"diversity_metrics/{key}"] = (None, value) - - return diversity_metrics - - -# class SummaCZSMetric(BaseMetric): -# """ -# Consistency metric for summarization -# -# https://github.com/tingofurro/summac/ -# """ -# -# def __init__(self, **kwargs) -> None: -# super().__init__() -# self._scorer = SummaCZS(**kwargs) -# -# def compute( -# self, -# prompt_texts: List[str], -# generated_texts: List[str], -# reference_texts: List[List[str]], -# meta_infos: List[Dict[str, Any]] = None, -# model: PreTrainedModel = None, -# split_name: str = None, -# ) -> Tuple[List[float], float]: -# metric_results = self._scorer.score(prompt_texts, generated_texts) -# corpus_score = np.mean(metric_results["scores"]) -# metric_dict = {"consistency/summaczs": (metric_results["scores"], corpus_score)} -# return metric_dict - - - - - -class Perplexity(BaseMetric): - def __init__( - self, - stride: int, - tokenizer_id: str, - model_type: str = "causal", - use_text_from_meta_data: bool = False, - ) -> None: - super().__init__() - self._tokenizer_id = tokenizer_id - self._model_type = model_type - self._stride = stride - self._use_text_from_meta_data = use_text_from_meta_data - - def get_device(self, model: PreTrainedModel): - try: - return model.transformer.first_device - except: - return model.device - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - if split_name == "train": - return {} - - if self._model_type != "causal": - raise NotImplementedError - - # we compute perplexity on reference texts - if self._use_text_from_meta_data: - reference_texts = [info["reference"] for info in meta_infos] - else: - reference_texts = [ref for refs in reference_texts for ref in refs] - tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_id) - encodings = tokenizer("\n\n".join(reference_texts), return_tensors="pt") - - device = self.get_device(model) - - nlls = [] - max_length = model.config.n_positions - for i in tqdm(range(0, encodings.input_ids.size(1), self._stride)): - begin_loc = max(i + self._stride - max_length, 0) - end_loc = min(i + self._stride, encodings.input_ids.size(1)) - trg_len = end_loc - i # may be different from stride on last loop - - # run on last device - input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) - target_ids = input_ids.clone() - target_ids[:, :-trg_len] = -100 - - with torch.no_grad(): - outputs = model(input_ids, labels=target_ids) - neg_log_likelihood = outputs[0] * trg_len - - nlls.append(neg_log_likelihood) - - return { - "fluency_metrics/perplexity": ( - None, - torch.exp(torch.stack(nlls).sum() / end_loc).item(), - ) - } - - - - - -class BLEUToTTo: - """ - Official version - """ - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]], - model: PreTrainedModel = None, - split_name: str = None, - ): - tables = [info["raw_table"] for info in meta_infos] - bleu_overall, bleu_overlap, bleu_non_overlap = compute_bleu( - generated_texts, tables - ) - - metric_results = { - "table_to_text/bleu_overall": (None, bleu_overall), - "table_to_text/bleu_overlap": (None, bleu_overlap), - "table_to_text/bleu_non_overlap": (None, bleu_non_overlap), - } - return metric_results - - -class RougeLMax(BaseMetric): - def __init__(self, **args) -> None: - super().__init__() - self._metric = rouge.Rouge(metrics=["rouge-l"], **args) - - def _rouge_max_over_ground_truths(self, prediction, ground_truths): - """ - Computes max of Rouge-L (https://github.com/allenai/unifiedqa/blob/bad6ef339db6286f0d8bd0661a2daeeb0f800f59/evaluation/evaluate_narrativeqa.py#L25) - """ - # load stemmer - self._metric.load_stemmer(self._metric.ensure_compatibility) - - scores_for_ground_truths = [] - for ground_truth in ground_truths: - score = self._metric.get_scores(prediction, [ground_truth]) - scores_for_ground_truths.append(score) - max_score = copy.deepcopy(score) - max_score = max([score["rouge-l"]["f"] for score in scores_for_ground_truths]) - return max_score - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ): - all_scores = [] - for gen_text, ref_texts in zip(generated_texts, reference_texts): - rouge_max_score = self._rouge_max_over_ground_truths(gen_text, ref_texts) - all_scores.append(rouge_max_score) - - metric_dict = {"lexical/rouge_l_max": (all_scores, np.mean(all_scores))} - return metric_dict - - -class SacreBLEUMetric(BaseMetric): - def __init__(self, **args) -> None: - super().__init__() - self._args = args - self._metric = load_metric("sacrebleu") - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - - metric_results = self._metric.compute( - predictions=generated_texts, references=reference_texts, **self._args - ) - bleu_score = metric_results["score"] / 100 - metric_dict = {"lexical/sacrebleu": (None, bleu_score)} - return metric_dict - - -class TERMetric(BaseMetric): - def __init__(self) -> None: - super().__init__() - self._metric = load_metric("ter") - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - - metric_results = self._metric.compute( - predictions=generated_texts, references=reference_texts - ) - score = metric_results["score"] / 100 - metric_dict = {"lexical/ter": (None, score)} - return metric_dict - - -class chrFmetric(BaseMetric): - def __init__(self) -> None: - super().__init__() - self._metric = load_metric("chrf") - - def compute( - self, - prompt_texts: List[str], - generated_texts: List[str], - reference_texts: List[List[str]], - meta_infos: List[Dict[str, Any]] = None, - model: PreTrainedModel = None, - split_name: str = None, - ) -> Tuple[List[float], float]: - - metric_results = self._metric.compute( - predictions=generated_texts, references=reference_texts - ) - score = metric_results["score"] / 100 - metric_dict = {"lexical/chrf": (None, score)} - return metric_dict - - - - -if __name__ == "__main__": - prompt_texts = [""] - gen_texts = ["Hello there general kenobi", "foo bar foobar"] - reference_texts = [["Hello there general kenobi"], ["foo bar foobar"]] - # metric = MeteorMetric() - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = RougeMetric() - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = SacreBLEUMetric(tokenize="intl") - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = TERMetric() - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = chrFmetric() - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = BERTScoreMetric(language="en") - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = BLEUMetric() - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = BLEURTMetric() - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # metric = DiversityMetrics() - # print(metric.compute(prompt_texts, gen_texts, reference_texts)) - - # document = """Jeff joined Microsoft in 1992 to lead corporate developer evangelism for Windows NT. He then served as a Group Program manager in Microsoft’s Internet Business Unit. In 1998, he led the creation of SharePoint Portal Server, which became one of Microsoft’s fastest-growing businesses, exceeding $2 billion in revenues. Jeff next served as Corporate Vice President for Program Management across Office 365 Services and Servers, which is the foundation of Microsoft’s enterprise cloud leadership. He then led Corporate Strategy supporting Satya Nadella and Amy Hood on Microsoft’s mobile-first/cloud-first transformation and acquisitions. Prior to joining Microsoft, Jeff was vice president for software development for an investment firm in New York. He leads Office shared experiences and core applications, as well as OneDrive and SharePoint consumer and business services in Office 365. Jeff holds a Master of Business Administration degree from Harvard Business School and a Bachelor of Science degree in information systems and finance from New York University.""" - # summary = "Jeff joined Microsoft in 1992 to lead the company's corporate evangelism. He then served as a Group Manager in Microsoft's Internet Business Unit. In 1998, Jeff led Sharepoint Portal Server, which became the company's fastest-growing business, surpassing $3 million in revenue. Jeff next leads corporate strategy for SharePoint and Servers which is the basis of Microsoft's cloud-first strategy. He leads corporate strategy for Satya Nadella and Amy Hood on Microsoft's mobile-first." - - # metric = SummaCZSMetric(granularity="sentence", - # use_ent=True, - # use_con=False) - # print(metric.compute([document], [summary], [])) - - # metric = SummaCConvMetric(granularity="sentence") - # print(metric.compute([document], [summary], [])) - - prompt_texts = ["1", "2"] - gen_texts = [ - "The dog is the boy's cat.", - "A boy is picking apples from trees and put them into bags.", - ] - reference_texts = [ - ["The dog is the boy's cat.", "The dog eats the cat of the boy."], - ["A boy is picking apples from trees."], - ] diff --git a/benchmark/torch/RL4LMs/models/__init__.py b/benchmark/torch/RL4LMs/models/__init__.py deleted file mode 100644 index ed9b32d20..000000000 --- a/benchmark/torch/RL4LMs/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .seq2seq_model import Seq2SeqLMModel \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/registry.py b/benchmark/torch/RL4LMs/registry.py deleted file mode 100644 index 3384332db..000000000 --- a/benchmark/torch/RL4LMs/registry.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import Any, Dict, Type, Union - -import parl -from benchmark.torch.RL4LMs.algorithms import RL4LMPPO -from benchmark.torch.RL4LMs.agents import RL4LMsAgent - -from benchmark.torch.RL4LMs.utils import TextGenPool, CNNDailyMail -# from rl4lms.envs.text_generation.alg_wrappers import wrap_onpolicy_alg -from parl.utils import logger - -from benchmark.torch.RL4LMs.metrics import ( - BaseMetric, - BERTScoreMetric, - BLEUMetric, - BLEURTMetric, - BLEUToTTo, - DiversityMetrics, - LearnedRewardMetric, - MeteorMetric, - Perplexity, - RougeLMax, - RougeMetric, - SacreBLEUMetric, - TERMetric, - chrFmetric, -) - -from benchmark.torch.RL4LMs.models import Seq2SeqLMModel - -from benchmark.torch.RL4LMs.utils import ( - BERTScoreRewardFunction, - BLEURewardFunction, - BLEURTRewardFunction, - CommonGenPenaltyShapingFunction, - LearnedRewardFunction, - MeteorRewardFunction, - RewardFunction, - RougeCombined, - RougeLMaxRewardFunction, - RougeRewardFunction, - SacreBleu, -) - - - -class DataPoolRegistry: - _registry = { - "cnn_daily_mail": CNNDailyMail, - } - - @classmethod - def get(cls, datapool_id: str, kwargs: Dict[str, Any]) -> TextGenPool: - logger.info(f"loading split of dataset: {datapool_id} -- {kwargs['split']}") - datapool_cls = cls._registry[datapool_id] - datapool = datapool_cls.prepare(**kwargs) - return datapool - - @classmethod - def add(cls, id: str, datapool_cls: Type[TextGenPool]): - DataPoolRegistry._registry[id] = datapool_cls - - -class RewardFunctionRegistry: - _registry = { - "learned_reward": LearnedRewardFunction, - "meteor": MeteorRewardFunction, - "rouge": RougeRewardFunction, - "bert_score": BERTScoreRewardFunction, - "bleu": BLEURewardFunction, - "bleurt": BLEURTRewardFunction, - "rouge_combined": RougeCombined, - "common_gen_repeat_penalty": CommonGenPenaltyShapingFunction, - "sacre_bleu": SacreBleu, - "rouge_l_max": RougeLMaxRewardFunction, - } - - @classmethod - def get(cls, reward_fn_id: str, kwargs: Dict[str, Any]) -> RewardFunction: - logger.info(f"loading reward function: {reward_fn_id}") - reward_cls = cls._registry[reward_fn_id] - reward_fn = reward_cls(**kwargs) - return reward_fn - - @classmethod - def add(cls, id: str, reward_fn_cls: Type[RewardFunction]): - RewardFunctionRegistry._registry[id] = reward_fn_cls - - -class MetricRegistry: - _registry = { - "learned_reward": LearnedRewardMetric, - "meteor": MeteorMetric, - "rouge": RougeMetric, - "bert_score": BERTScoreMetric, - "bleu": BLEUMetric, - "bleurt": BLEURTMetric, - "diversity": DiversityMetrics, - - "causal_perplexity": Perplexity, - - "bleu_totto": BLEUToTTo, - "rouge_l_max": RougeLMax, - "sacre_bleu": SacreBLEUMetric, - "ter": TERMetric, - "chrf": chrFmetric, - - } - - @classmethod - def get(cls, metric_id: str, kwargs: Dict[str, Any]) -> BaseMetric: - logger.info(f"loading metric: {metric_id}") - metric_cls = cls._registry[metric_id] - metric = metric_cls(**kwargs) - return metric - - @classmethod - def add(cls, id: str, metric_cls: Type[BaseMetric]): - MetricRegistry._registry[id] = metric_cls - - -class ModelRegistry: - _registry = { - "seq2seq_lm_actor_critic_model": Seq2SeqLMModel, - } - - @classmethod - def get(cls, model_id: str) -> Type[parl.Model]: - model_cls = cls._registry[model_id] - return model_cls - - -class AlgorithmRegistry: - _registry = { - "ppo": RL4LMPPO, - } - - @classmethod - def get( - cls, alg_id: str - ): - try: - alg_cls = cls._registry[alg_id] - except KeyError: - raise NotImplementedError - return alg_cls - - @classmethod - def add( - cls, id: str, alg_cls - ): - AlgorithmRegistry._registry[id] = alg_cls - - -class AgentRegistry: - _registry = { - "rl4lm_agent": RL4LMsAgent, - } - - @classmethod - def get(cls, alg_id: str): - try: - wrapper_def = cls._registry[alg_id] - except KeyError: - raise NotImplementedError - return wrapper_def - - @classmethod - def add(cls, id: str, agent_def): - AgentRegistry._registry[id] = agent_def - diff --git a/benchmark/torch/RL4LMs/algorithms/ppo.py b/benchmark/torch/RL4LMs/rl4lm_ppo.py similarity index 97% rename from benchmark/torch/RL4LMs/algorithms/ppo.py rename to benchmark/torch/RL4LMs/rl4lm_ppo.py index 060da4fc9..4a1c753e7 100644 --- a/benchmark/torch/RL4LMs/algorithms/ppo.py +++ b/benchmark/torch/RL4LMs/rl4lm_ppo.py @@ -1,5 +1,4 @@ import parl -from benchmark.torch.RL4LMs.utils import Schedule from typing import Union, Optional, Dict, Any import torch from gym import spaces @@ -12,13 +11,13 @@ class RL4LMPPO(parl.Algorithm): def __init__(self, model: parl.Model, - learning_rate: Union[float, Schedule] = 3e-4, + learning_rate = 3e-4, n_steps: int = 2048, batch_size: int = 64, n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, - clip_range: Union[float, Schedule] = 0.2, + clip_range = 0.2, normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, diff --git a/benchmark/torch/RL4LMs/agents/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py similarity index 96% rename from benchmark/torch/RL4LMs/agents/rl4lms_agent.py rename to benchmark/torch/RL4LMs/rl4lms_agent.py index 0c366df7d..468a9a3d2 100644 --- a/benchmark/torch/RL4LMs/agents/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -1,11 +1,10 @@ import parl import numpy as np -from typing import List import torch from parl.utils import logger -def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: +def explained_variance(y_pred, y_true): """ Computes fraction of variance that ypred explains about y. Returns 1 - Var[y-ypred] / Var[y] @@ -24,7 +23,7 @@ class RL4LMsAgent(parl.Agent): def __init__(self, algorithm, alg_config, - norm_reward: bool = False, + norm_reward = False, ): super(RL4LMsAgent, self).__init__(algorithm) self.dataset = None diff --git a/benchmark/torch/RL4LMs/models/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py similarity index 87% rename from benchmark/torch/RL4LMs/models/seq2seq_model.py rename to benchmark/torch/RL4LMs/seq2seq_model.py index 08da7836c..7704a7d6f 100644 --- a/benchmark/torch/RL4LMs/models/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -1,9 +1,6 @@ -from typing import Any, Dict, Optional, List, Union import torch -from gym.spaces import Discrete -from gym.spaces.dict import Dict as DictSpace from torch import nn -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel +from transformers import AutoModelForSeq2SeqLM from copy import deepcopy from torch.distributions import Categorical @@ -11,12 +8,10 @@ import parl from benchmark.torch.RL4LMs.utils import ( - override_generation_routines, - - TensorDict, CategoricalDistribution, + override_generation_routines, CategoricalDistribution, GenerationInputs, PolicyOutput, RefPolicyOutput, ValueOutput, - PolicyType, EvaluateActionsOutput, GenerationOutputs, + EvaluateActionsOutput, GenerationOutputs, ) @@ -24,17 +19,17 @@ class Seq2SeqLMModel(parl.Model): def __init__( self, - observation_space: DictSpace, - action_space: Discrete, - model_name: str, - optimizer_kwargs: Dict[str, Any] = {}, - weight_decay: float = 1e-6, - apply_model_parallel: bool = True, - optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, - generation_kwargs: Dict[str, Any] = {}, - prompt_truncation_side: str = "left", - state_dict: Dict[str, Any] = None, - device: torch.DeviceObjType = None, + observation_space, + action_space, + model_name, + optimizer_kwargs = {}, + weight_decay = 1e-6, + apply_model_parallel = True, + optimizer_class = torch.optim.AdamW, + generation_kwargs = {}, + prompt_truncation_side = "left", + state_dict = None, + device = None, ): super(Seq2SeqLMModel, self).__init__() if optimizer_kwargs is None: @@ -57,10 +52,7 @@ def __init__( self._prompt_truncation_side = prompt_truncation_side - - # self.load_from_dict(state_dict) - - def _build_model_heads(self, model_name: str): + def _build_model_heads(self, model_name): self._policy_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) self._policy_model.__class__ = override_generation_routines(type(self._policy_model)) @@ -88,10 +80,10 @@ def _build_model_heads(self, model_name: str): def forward_policy( self, - obs: TensorDict, - actions: torch.tensor, - past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, - ) -> PolicyOutput: + obs, + actions, + past_model_kwargs = None, + ): # Temp workaround for Seq2seq policy past_model_kwargs = None @@ -162,9 +154,9 @@ def forward_policy( def forward_value( self, - obs: TensorDict, - past_model_kwargs: Optional[Dict[str, torch.tensor]] = None, - ) -> ValueOutput: + obs, + past_model_kwargs = None, + ): # Temp workaround for Seq2seq policy past_model_kwargs = None @@ -231,8 +223,8 @@ def forward_value( return value_output def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> EvaluateActionsOutput: + self, obs, actions + ): policy_outputs = self.forward_policy(obs=obs, actions=actions) value_outputs = self.forward_value(obs) @@ -244,7 +236,7 @@ def evaluate_actions( ) return eval_outputs - def to(self, device: str): + def to(self, device): if self._apply_model_parallel: self._value_head = self._value_head.to(device) return self @@ -253,10 +245,10 @@ def to(self, device: str): def get_log_probs_ref_model( self, - obs: TensorDict, - action: torch.tensor, - model_kwarpast_model_kwargsgs: Dict[str, Any] = None, - ) -> RefPolicyOutput: + obs, + action, + model_kwarpast_model_kwargsgs = None, + ): # Temp workaround for Seq2seq policy past_model_kwargs = None @@ -326,28 +318,25 @@ def get_policy_first_device(self): else self.device ) - def get_inputs_for_generation(self, obs: TensorDict) -> GenerationInputs: + def get_inputs_for_generation(self, obs): generation_inputs = GenerationInputs( obs["prompt_or_input_encoded_pt"], obs["prompt_or_input_attention_mask_pt"] ) return generation_inputs - def get_policy_type(self): - return PolicyType.SEQ2SEQ - def get_language_model(self): return unwrap_model(self._policy_model) def generate( self, - tokenizer: AutoTokenizer, - texts: List[str] = None, - max_prompt_length: int = None, - input_ids: torch.tensor = None, - attention_mask: torch.tensor = None, - gen_kwargs: Dict[str, Any] = None, - ) -> GenerationOutputs: + tokenizer, + texts = None, + max_prompt_length = None, + input_ids = None, + attention_mask = None, + gen_kwargs = None, + ): # if it different from rollout gen kwargs if gen_kwargs is None: @@ -427,20 +416,20 @@ def generate( return gen_output - def is_encoder_decoder(self, model: PreTrainedModel): + def is_encoder_decoder(self, model): return unwrap_model(model).config.is_encoder_decoder - def set_training_mode(self, mode: bool) -> None: + def set_training_mode(self, mode): self.train(mode) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self): return dict( observation_space=self.observation_space, action_space=self.action_space, ) - def save(self, path: str) -> None: + def save(self, path): """ Save model to a given location. @@ -451,9 +440,9 @@ def save(self, path: str) -> None: def _setup_optimizer( self, - optimizer_kwargs: Dict[str, Any], - weight_decay: float, - optimizer_class: torch.optim, + optimizer_kwargs, + weight_decay, + optimizer_class, ): params = list(self.named_parameters()) diff --git a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml similarity index 97% rename from benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml rename to benchmark/torch/RL4LMs/t5_ppo.yml index 50fe402ad..75e18162d 100644 --- a/benchmark/torch/RL4LMs/configs/summarization/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -7,7 +7,6 @@ tokenizer: pad_token_as_eos_token: False reward_fn: - id: rouge args: rouge_type: "rouge1" @@ -30,8 +29,6 @@ env: context_start_token: 0 alg: - agent_id: rl4lm_agent - id: ppo args: # n_steps: 512 #####CHNAGE FOR DEBUG######## diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index df81a8dc3..e2811fc02 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -4,9 +4,27 @@ import datetime import yaml import collections -from trainers import OnPolicyTrainer from parl.utils import logger +import torch +import time + +# env and reward function +from utils import build_reward_fn +from env import TextGenEnv, make_vec_env + +# evaluation, metrics, tokenizer & dataset +from utils import build_metrics, build_tokenizer, build_datapool +from utils import evaluate_on_samples + +# rollout +from utils import MaskableDictRolloutBuffer, RolloutUtil + +# agent, algorithm and model +from rl4lm_ppo import RL4LMPPO +from rl4lms_agent import RL4LMsAgent +from seq2seq_model import Seq2SeqLMModel + def recursive_dict_update(d, u): for k, v in u.items(): @@ -18,25 +36,110 @@ def recursive_dict_update(d, u): def main(config): + device = torch.device("cuda" if torch.cuda. + is_available() else "cpu") + + rollout_util = RolloutUtil(config["alg"]["kl_div"]) + + tokenizer = build_tokenizer(config["tokenizer"]) + + # reward function & metrics + reward_fn = build_reward_fn(config["reward_fn"]) + metrics = build_metrics(config["train_evaluation"]["metrics"]) + + # datapool + samples_by_split = build_datapool(config["datapool"]) + + env = make_vec_env(env_id=TextGenEnv, + env_config=config["env"], + reward_fn=reward_fn, + tokenizer=tokenizer, + train_samples= samples_by_split["train"]) - # instantiate the trainer here - # TODO: currently only complete ppo - if "ppo" == config["alg"]["id"]: - trainer = OnPolicyTrainer( - tokenizer_config=config["tokenizer"], - datapool_config=config["datapool"], - reward_config=config["reward_fn"], - env_config=config["env"], - on_policy_alg_config=config["alg"], - train_eval_config=config["train_evaluation"], - ) - else: - raise NotImplementedError - trainer.train_and_eval() + rl4lms_model = Seq2SeqLMModel( + observation_space = env.observation_space, + action_space= env.action_space, + device=device, + **config["alg"]["model"]["args"] + ) + rl4lm_alg = RL4LMPPO(model=rl4lms_model, device=device, **config["alg"]["args"]) + agent = RL4LMsAgent(rl4lm_alg, config["alg"]) + + rollout_buffer = MaskableDictRolloutBuffer( + buffer_size=agent.alg.n_steps * env.num_envs, + observation_space=env.observation_space, + action_space=env.action_space, + device=device, + gamma=agent.alg.gamma, + gae_lambda=agent.alg.gae_lambda, + n_envs=1, + ) + + n_iters = int(config["train_evaluation"]["n_iters"]) + n_steps_per_iter = env.num_envs * agent.alg.n_steps + + max_prompt_length = config["env"]["args"]["max_prompt_length"] + # gen kwargs for evaluation + eval_gen_kwargs = config["train_evaluation"]["generation_kwargs"] + eval_batch_size = config["train_evaluation"]["eval_batch_size"] + eval_splits = ["val", "test"] + iter_start = 0 + for sp in eval_splits: + evaluate_on_samples(policy=agent.alg.model, + tokenizer=tokenizer, + samples=samples_by_split[sp], + batch_size=eval_batch_size, + max_prompt_length=max_prompt_length, + metrics=metrics, + epoch=iter_start, + split_name=sp, + gen_kwargs=eval_gen_kwargs) + epoch = 0 + for epoch in range(iter_start, n_iters): + print("========== BEGIN ==========") + print(f"outer epoch: {epoch} / {n_iters - 1}") + print("========== BEGIN ==========") + outer_start_time = time.time() + num_timesteps = 0 + while num_timesteps < n_steps_per_iter: + run_timesteps = rollout_util.collect_rollouts(agent, env, rollout_buffer, device) + num_timesteps += run_timesteps + agent.learn(rollout_buffer) + + outer_end_time = time.time() + print("========== END ==========") + print(f"outer epoch: {epoch} / {n_iters - 1}") + print(f"time used: {outer_end_time - outer_start_time} second(s), left time:" + f" {1.0 * (outer_end_time - outer_start_time) * (n_iters - epoch - 1) / 60 / 60} hour(s)") + print("========== END ==========") + + # evaluate on val set in the given intervals + if (epoch + 1) % config["train_evaluation"]["eval_every"] == 0: + evaluate_on_samples(policy=agent.alg.model, + tokenizer=tokenizer, + samples=samples_by_split["val"], + batch_size=eval_batch_size, + max_prompt_length=max_prompt_length, + metrics=metrics, + epoch=epoch, + split_name="val", + gen_kwargs=eval_gen_kwargs) + + + for sp in eval_splits: + evaluate_on_samples(policy=agent.alg.model, + tokenizer=tokenizer, + samples=samples_by_split[sp], + batch_size=eval_batch_size, + max_prompt_length=max_prompt_length, + metrics=metrics, + epoch=epoch, + split_name=sp, + gen_kwargs=eval_gen_kwargs) if __name__ == '__main__': @@ -73,5 +176,6 @@ def main(config): config["sys_arg"] = sys.argv logger.info(config) logger.set_level("DEBUG") + main(config) diff --git a/benchmark/torch/RL4LMs/trainers.py b/benchmark/torch/RL4LMs/trainers.py deleted file mode 100644 index 8157ee4d0..000000000 --- a/benchmark/torch/RL4LMs/trainers.py +++ /dev/null @@ -1,566 +0,0 @@ -import time -from typing import Any, Dict, List -import numpy as np - - -from benchmark.torch.RL4LMs.utils import Sample, RewardFunction,\ - evaluate_on_samples,\ - KLController, MaskableDictRolloutBuffer,\ - TransitionInfo, TensorDict, RefPolicyOutput, ValueOutput, PolicyOutput -from benchmark.torch.RL4LMs.registry import DataPoolRegistry, MetricRegistry, RewardFunctionRegistry, \ - ModelRegistry, AlgorithmRegistry, AgentRegistry -from benchmark.torch.RL4LMs.env import TextGenEnv -from transformers import AutoTokenizer -from benchmark.torch.RL4LMs.env import LocalParallelVecEnv, make_vec_env -from transformers import PreTrainedTokenizer -import torch -from parl.utils import logger - -def build_tokenizer(tokenizer_config: Dict[str, Any]): - logger.info(f"loading tokenizer of [{tokenizer_config['model_name']}] model") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_config["model_name"]) - if tokenizer.pad_token is None and tokenizer_config.get("pad_token_as_eos_token", True): - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = tokenizer_config.get( - "padding_side", "left") - tokenizer.truncation_side = tokenizer_config.get( - "truncation_side", "left") - return tokenizer - - -def build_reward_fn(reward_config: Dict[str, Any]): - reward_fn = RewardFunctionRegistry.get(reward_config["id"], - reward_config.get("args", {})) - return reward_fn - - -def build_metrics(metric_configs: List[Dict[str, Any]]): - metrics = [MetricRegistry.get(metric_config["id"], metric_config.get("args", {})) - for metric_config in metric_configs] - return metrics - - -def build_datapool(datapool_config: Dict[str, Any]): - - def _get_datapool_by_split(split: str): - kwargs = datapool_config.get("args", {}) - kwargs["split"] = split - dp_split = DataPoolRegistry.get(datapool_config["id"], kwargs) - return dp_split - - train_datapool = _get_datapool_by_split("train") - val_datapool = _get_datapool_by_split("val") - test_datapool = _get_datapool_by_split("test") - - samples_by_split = { - "train": [(sample, weight) - for sample, weight in train_datapool], - "val": [sample for sample, _ in val_datapool], - "test": [sample for sample, _ in test_datapool] - } - return samples_by_split - - -def build_env(env_config: Dict[str, Any], - reward_fn: RewardFunction, - tokenizer: AutoTokenizer, - train_samples: List[Sample]): - # vectoried env - env_kwargs = { - "reward_function": reward_fn, - "tokenizer": tokenizer, - "samples": train_samples, - } - env_kwargs = {**env_kwargs, **env_config.get("args", {})} - envs = make_vec_env(TextGenEnv, - n_envs=env_config.get( - "n_envs", 1), - vec_env_cls=LocalParallelVecEnv, - env_kwargs=env_kwargs) - return envs - -def build_agent(alg_config: Dict[str, Any], - env: LocalParallelVecEnv, - model_state: Dict[str, Any] = None, # TODO: save model checkpoint - device = None, - alg_state: Dict[str, Any] = None # TODO: save alg checkpoint - ): - model_config = alg_config["model"] - model_cls = ModelRegistry.get(model_config["id"]) - alg_cls = AlgorithmRegistry.get(alg_config["id"]) - agent_cls = AgentRegistry.get(alg_config["agent_id"]) - - model_args = model_config["args"] - model_args["state_dict"] = model_state - - rl4lms_model = model_cls( - observation_space = env.observation_space, - action_space= env.action_space, - device=device, - **model_args - ) - - rl4lm_alg_cls = alg_cls( - model=rl4lms_model, - device=device, - **alg_config.get("args") - ) - - rl4lm_agent = agent_cls(rl4lm_alg_cls, alg_config) - return rl4lm_agent - - -def dict_to_tensor(obs, device): - return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} - - -def unpack_observations(obs_tensor, n_envs: int): - """ - Unpacks vectorized dict observations into separate dict observations - """ - unpacked_obs = [] - keys = obs_tensor.keys() - for env_ix in range(n_envs): - obs_dict = {} - for key in keys: - obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() - unpacked_obs.append(obs_dict) - return unpacked_obs - - -class OnPolicyTrainer: - """ - A generic trainer for training LMs with onpolicy algorithms from SB3 - """ - - def __init__(self, - tokenizer_config: Dict[str, Any], - datapool_config: Dict[str, Any], - reward_config: Dict[str, Any], - env_config: Dict[str, Any], - on_policy_alg_config: Dict[str, Any], - train_eval_config: Dict[str, Any], - experiment_name: str = '' - ): - # - self._tokenizer = None - self._tokenizer_config = tokenizer_config - - # datapool - self._datapool_config = datapool_config - self._samples_by_split = None - - # reward function & metrics - self._reward_config = reward_config - self._reward_fn = None - self._metrics = None - self._norm_reward = False - - # env - self._env_config = env_config - self._env = None - - # algorithm config & model config - self._on_policy_alg_config = on_policy_alg_config - - # agent - self._agent = None - - # rollout buffer - self._rollout_buffer = None - - self._train_eval_config = train_eval_config - self._experiment_name = experiment_name - self._num_timesteps = None - self._kl_controller = None - self.device = torch.device("cuda" if torch.cuda. - is_available() else "cpu") - - self._setup() - - def train_and_eval(self): - # evaluate on val and test set before fine-tuning once - # iter_start = self._trainer_state["current_iter"] - iter_start = 0 - self._evaluate_on_datapools(epoch=iter_start) - - # train for given number of iters - for epoch in range(iter_start, self._n_iters): - print("========== BEGIN ==========") - print(f"outer epoch: {epoch} / {self._n_iters - 1}") - print("========== BEGIN ==========") - outer_start_time = time.time() - # current state - # self._trainer_state["current_iter"] = epoch - - self._num_timesteps = 0 - - while self._num_timesteps < self._n_steps_per_iter: - self._collect_rollouts(self._env, self._rollout_buffer) - # inner rollout and learn loop for on-policy algorithm - # self._agent.learn(self._n_steps_per_iter) - self._agent.learn(self._rollout_buffer) - - # save the policy checkpoint - # if (epoch + 1) % self._train_eval_config.get("save_every", 20) == 0: - # self.save_trainer_state( - # self._tracker, self._alg.policy, self._trainer_state) - - # evaluate on val set in the given intervals - if (epoch + 1) % self._train_eval_config["eval_every"] == 0: - self._evaluate_on_datapools(epoch=epoch, splits=["val"]) - - outer_end_time = time.time() - print("========== END ==========") - print(f"outer epoch: {epoch} / {self._n_iters - 1}") - print(f"time used: {outer_end_time - outer_start_time} second(s), left time:" - f" {1.0 * (outer_end_time - outer_start_time) * (self._n_iters - epoch - 1) / 60 / 60} hour(s)") - print("========== END ==========") - - - # finally evaluate on val and test samples - self._evaluate_on_datapools(epoch=epoch) - - # # save model here - we save only the language model - # if self._tracker is not None: - # self._tracker.save_auto_model( - # self._alg.policy.get_language_model()) - - def _setup(self): - - # load trainer state from available previous checkpoint if available - # self.load_trainer_state(self._tracker) - - # build components - self._tokenizer = build_tokenizer(self._tokenizer_config) - self._reward_fn = build_reward_fn(self._reward_config) - self._metrics = build_metrics( - self._train_eval_config.get("metrics", [])) - self._samples_by_split = build_datapool(self._datapool_config) - self._env = build_env(self._env_config, self._reward_fn, - self._tokenizer, self._samples_by_split["train"]) - - - self._agent = build_agent(self._on_policy_alg_config, - self._env, device=self.device) - - self._rollout_buffer = MaskableDictRolloutBuffer( - buffer_size=self._agent.alg.n_steps * self._env.num_envs, - observation_space=self._env.observation_space, - action_space=self._env.action_space, - device=self.device, - gamma=self._agent.alg.gamma, - gae_lambda=self._agent.alg.gae_lambda, - n_envs=1, - ) - - self._kl_controller = KLController( - self._on_policy_alg_config["kl_div"]["coeff"], - self._on_policy_alg_config["kl_div"].get("target_kl", None)) - - # extract train params - self._max_episode_length = self._env_config["args"]["max_episode_length"] - self._max_prompt_length = self._env_config["args"]["max_prompt_length"] - self._eval_batch_size = self._train_eval_config["eval_batch_size"] - self._n_iters = int(self._train_eval_config["n_iters"]) - self._n_steps_per_iter = self._env.num_envs * self._agent.alg.n_steps - self._num_timesteps = 0 - - # gen kwargs for evaluation (if it is different from rollout gen kwargs) - self._eval_gen_kwargs = self._train_eval_config.get( - "generation_kwargs", None) - - def _get_policy_kwargs( - self, - obs: TensorDict, - action: torch.tensor, - past_state: Dict[str, torch.tensor], - action_mask: torch.tensor, - ): - - policy_kwargs = { - "obs": obs, - "actions": action, - "past_model_kwargs": past_state, - } - if action_mask is not None: - policy_kwargs["action_masks"] = action_mask - return policy_kwargs - - def _generate_batch( - self, - rollout_buffer, - tokenizer: PreTrainedTokenizer, - max_steps: int, - rollout_info: Dict[str, Any], - ): - # if rollout buffer is already full, do not continue - if rollout_buffer.full: - return - - # start parallel episodes - current_obs = self._env.reset() - episode_starts = np.ones((self._env.num_envs,), dtype=bool) - - # generate text using the model - obs_tensor = dict_to_tensor(current_obs, self.device) - generation_inputs = self._agent.get_inputs_for_generation(obs_tensor) - gen_output = self._agent.generate( - input_ids=generation_inputs.inputs, - attention_mask=generation_inputs.attention_masks, - tokenizer=tokenizer, - ) - - # process them one step at a time to collect rollout info - episode_wise_transitions = [[] for _ in range(self._env.num_envs)] - ep_terminated = np.zeros((self._env.num_envs,), dtype=bool) - value_past_state = None - ref_past_state = None - policy_past_state = None - masks = ( - gen_output.action_masks - if gen_output.action_masks is not None - else [None] * len(gen_output.step_wise_logprobs) - ) - - for actions_tensor, _, action_mask in zip( - gen_output.step_wise_actions, gen_output.step_wise_logprobs, masks - ): - # if all episodes are done, just break and do not continue - if np.all(ep_terminated): - break - - # evaluate actions with actions from rollout - with torch.no_grad(): - obs_tensor = dict_to_tensor(current_obs, self.device) - - # get log probs (TBD: generalize this a bit) - policy_kwargs = self._get_policy_kwargs( - obs_tensor, actions_tensor, policy_past_state, action_mask - ) - - policy_outputs: PolicyOutput = self._agent.forward_policy( - **policy_kwargs - ) - raw_log_probs, log_probs, policy_past_state = ( - policy_outputs.raw_log_probs, - policy_outputs.log_probs, - policy_outputs.past_model_kwargs, - ) - - # sanity check - assert torch.all( - torch.isfinite(log_probs) - ), "Infinite values in log probs" - - # sanity check - assert torch.all( - torch.isfinite(raw_log_probs) - ), "Infinite values in log probs" - - # get values - value_outputs: ValueOutput = self._agent.forward_value( - obs_tensor, value_past_state - ) - values, value_past_state = ( - value_outputs.values, - value_outputs.past_model_kwargs, - ) - - # get reference log probs - ref_policy_outputs: RefPolicyOutput = ( - self._agent.get_log_probs_ref_model( - obs_tensor, actions_tensor, ref_past_state - ) - ) - ref_log_probs, ref_past_state = ( - ref_policy_outputs.log_probs, - ref_policy_outputs.past_model_kwargs, - ) - - # sanity check - assert torch.all( - torch.isfinite(ref_log_probs) - ), "Infinite values in log probs" - - # compute KL rewards - kl_div = raw_log_probs - ref_log_probs - kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div - - # step into env to get rewards - actions = actions_tensor.cpu().numpy() - new_obs, rewards, dones, infos = self._env.step(actions) - - self._num_timesteps += self._env.num_envs - - # compute total rewards - total_rewards = rewards + kl_rewards.cpu().numpy() - - # unpack individual observations - unpacked_obs = unpack_observations(obs_tensor, self._env.num_envs) - - # store episode wise transitions separately - for env_ix in range(self._env.num_envs): - # only if not terminated already - if not ep_terminated[env_ix]: - transtion = TransitionInfo( - observation=unpacked_obs[env_ix], - action=actions[env_ix], - task_reward=rewards[env_ix], - total_reward=total_rewards[env_ix], - kl_div=kl_div.cpu().numpy()[env_ix], - episode_start=episode_starts[env_ix], - value=values[env_ix].cpu(), - log_prob=log_probs[env_ix].cpu(), - done=dones[env_ix], - ref_log_prob=ref_log_probs[env_ix].cpu(), - kl_reward=kl_rewards.cpu().numpy()[env_ix], - action_mask=action_mask[env_ix].cpu().numpy() - if action_mask is not None - else None, - info=infos[env_ix], - ) - - episode_wise_transitions[env_ix].append(transtion) - - # mark this episode to terminated if done occurs once - if dones[env_ix]: - ep_terminated[env_ix] = True - - episode_starts = np.zeros((self._env.num_envs,), dtype=bool) - current_obs = new_obs - - # now we flush all episode wise info to the 1-D buffer - rollout_info = self._add_to_buffer( - rollout_buffer, episode_wise_transitions, rollout_info - ) - return rollout_info - - def _evaluate_on_datapools(self, epoch: int, - splits: List[str] = ["val", "test"]): - for split in splits: - evaluate_on_samples(policy=self._agent.alg.model, - tokenizer=self._tokenizer, - samples=self._samples_by_split[split], - batch_size=self._eval_batch_size, - max_prompt_length=self._max_prompt_length, - metrics=self._metrics, - epoch=epoch, - split_name=split, - gen_kwargs=self._eval_gen_kwargs) - - def _add_to_buffer( - self, rollout_buffer, episode_wise_transitions, rollout_info - ): - # if the reward function is batchable, we override the rewards here - # if isinstance(self.reward_fn, BatchedRewardFunction): - # compute_batched_rewards(episode_wise_transitions, self.reward_fn) - - advantages_computed = False - for ep_ix, transitions in enumerate(episode_wise_transitions): - ep_length = len(transitions) - total_reward = 0.0 - total_kl_reward = 0.0 - for transition_ix, transition in enumerate(transitions): - total_reward += transition.task_reward - total_kl_reward += transition.kl_reward - rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) - rollout_info["rollout_info/log_prob"].append(transition.log_prob) - rollout_info["rollout_info/ref_log_prob"].append( - transition.ref_log_prob - ) - rollout_info["rollout_info/values"].append(transition.value.numpy()) - - if not rollout_buffer.full: - rollout_buffer.add( - transition.observation, - transition.action, - transition.total_reward, - transition.episode_start, - transition.value, - transition.log_prob, - action_masks=transition.action_mask, - ) - - # if the buffer is full, compute advantages - if rollout_buffer.full and not advantages_computed: - - # normalize the rewards - if self._norm_reward: - mean = rollout_buffer.rewards.mean() - std = rollout_buffer.rewards.std() - rollout_buffer.rewards = (rollout_buffer.rewards - mean) / ( - std + 1e-8 - ) - - # we fetch the last value for the last time step - # values come from the next transitions's values - next_values = ( - transitions[transition_ix + 1].value - if (transition_ix + 1) < ep_length - else torch.tensor([0.0]) - ) - - rollout_buffer.compute_returns_and_advantage( - last_values=next_values, dones=transition.done - ) - advantages_computed = True - - rollout_info["rollout_info/ep_rew"].append(total_reward) - rollout_info["rollout_info/ep_lens"].append(ep_length) - rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) - return rollout_info - - def _collect_rollouts( - self, - env, - rollout_buffer: MaskableDictRolloutBuffer, - ) -> bool: - # max episode steps - max_steps = env.get_attr("max_steps", [0])[0] - - # get tokenizer - tokenizer = env.get_attr("tokenizer", [0]) - tokenizer = tokenizer[0] - - # Switch to eval mode - # self._agent.alg.model.set_training_mode(False) - self._agent.eval_mode() - - # reset rollout buffer and stats - rollout_buffer.reset() - - # start the rollout process - rollout_info = { - "rollout_info/ep_rew": [], - "rollout_info/kl_div_mean": [], - "rollout_info/ep_lens": [], - "rollout_info/ep_kl_rew": [], - "rollout_info/log_prob": [], - "rollout_info/ref_log_prob": [], - "rollout_info/values": [], - } - while not rollout_buffer.full: - # generate batch of rollouts - rollout_info = self._generate_batch( - rollout_buffer, tokenizer, max_steps, rollout_info - ) - - # aggregate rollout info - aggregated_rollout_info = {} - for key, values in rollout_info.items(): - aggregated_rollout_info[key] = np.mean(values).item() - aggregated_rollout_info[f"{key}_std"] = np.std(values).item() - aggregated_rollout_info[ - "rollout_info/kl_coeff" - ] = self._kl_controller.kl_coeff - - # if self.tracker is not None: - # self.tracker.log_rollout_infos(aggregated_rollout_info) - - # adapt the KL coeff - self._kl_controller.step( - torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) - ) - return True \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py index b093afb85..363e8e266 100644 --- a/benchmark/torch/RL4LMs/utils/__init__.py +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -5,8 +5,6 @@ from .huggingface_generation_util import override_generation_routines -from .type_wrapper import TensorDict, Schedule - from .distribution_wrapper import CategoricalDistribution from .sample_util import PrioritySampler @@ -17,8 +15,11 @@ from .evaluation_util import evaluate_on_samples -from .data_pool import TextGenPool, CNNDailyMail +from .data_pool import CNNDailyMail + +from .reward_util import RougeRewardFunction + +from .component_build_util import build_tokenizer, build_metrics, build_reward_fn,\ + build_datapool -from .reward_util import RewardFunction, RougeRewardFunction, RougeLMaxRewardFunction, \ - BatchedRewardFunction, BERTScoreRewardFunction, BLEURewardFunction, BLEURTRewardFunction, MeteorRewardFunction,\ - LearnedRewardFunction, SacreBleu, CommonGenPenaltyShapingFunction, RougeCombined +from .rollout_util import RolloutUtil diff --git a/benchmark/torch/RL4LMs/utils/buffer.py b/benchmark/torch/RL4LMs/utils/buffer.py index b0a99a7f5..7e9ec9123 100644 --- a/benchmark/torch/RL4LMs/utils/buffer.py +++ b/benchmark/torch/RL4LMs/utils/buffer.py @@ -1,11 +1,6 @@ -import warnings -from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Union, Tuple - import numpy as np import torch from gym import spaces - from .data_wrapper import MaskableDictRolloutBufferSamples try: @@ -14,32 +9,9 @@ except ImportError: psutil = None - -def get_action_dim(action_space: spaces.Space) -> int: - """ - Get the dimension of the action space. - - :param action_space: - :return: - """ - if isinstance(action_space, spaces.Box): - return int(np.prod(action_space.shape)) - elif isinstance(action_space, spaces.Discrete): - # Action is an int - return 1 - elif isinstance(action_space, spaces.MultiDiscrete): - # Number of discrete actions - return int(len(action_space.nvec)) - elif isinstance(action_space, spaces.MultiBinary): - # Number of binary actions - return int(action_space.n) - else: - raise NotImplementedError(f"{action_space} action space is not supported") - - def get_obs_shape( - observation_space: spaces.Space, -) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]: + observation_space, +): """ Get the shape of the observation (useful for the buffers). @@ -91,21 +63,20 @@ class MaskableDictRolloutBuffer: def __init__( self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[torch.device, str] = "cpu", - gae_lambda: float = 1, - gamma: float = 0.99, - n_envs: int = 1, + buffer_size, + observation_space, + action_space, + device = "cpu", + gae_lambda = 1, + gamma = 0.99, + n_envs = 1, ): - self.action_masks = None self.buffer_size = buffer_size self.observation_space = observation_space self.action_space = action_space self.obs_shape = get_obs_shape(observation_space) - self.action_dim = get_action_dim(action_space) + self.action_dim = 1 self.pos = 0 self.full = False self.device = device @@ -120,20 +91,8 @@ def __init__( self.generator_ready = False self.reset() - def reset(self) -> None: - if isinstance(self.action_space, spaces.Discrete): - mask_dims = self.action_space.n - elif isinstance(self.action_space, spaces.MultiDiscrete): - mask_dims = sum(self.action_space.nvec) - elif isinstance(self.action_space, spaces.MultiBinary): - mask_dims = 2 * self.action_space.n # One mask per binary outcome - else: - raise ValueError( - f"Unsupported action space {type(self.action_space)}") - - self.mask_dims = mask_dims - self.action_masks = np.ones( - (self.buffer_size, self.n_envs, self.mask_dims)) # .to(self.device) + def reset(self): + self.mask_dims = self.action_space.n assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" self.observations = {} @@ -152,13 +111,12 @@ def reset(self) -> None: self.full = False def add(self, - obs: Dict[str, np.ndarray], - action: np.ndarray, - reward: np.ndarray, - episode_start: np.ndarray, - value: torch.Tensor, - log_prob: torch.Tensor, - action_masks: Optional[torch.Tensor] = None) -> None: + obs, + action, + reward, + episode_start, + value, + log_prob,): """ :param obs: Observation :param action: Action @@ -168,11 +126,7 @@ def add(self, following the current policy. :param log_prob: log probability of the action following the current policy. - :param action_masks: Masks applied to constrain the choice of possible actions. """ - if action_masks is not None: - self.action_masks[self.pos] = action_masks.reshape( - (self.n_envs, self.mask_dims)) if len(log_prob.shape) == 0: # Reshape 0-d tensor to avoid error @@ -195,7 +149,7 @@ def add(self, if self.pos == self.buffer_size: self.full = True - def compute_returns_and_advantage(self, last_values: torch.Tensor, dones: np.ndarray) -> None: + def compute_returns_and_advantage(self, last_values, dones): """ Post-processing step: compute the lambda-return (TD(lambda) estimate) and GAE(lambda) advantage. @@ -232,7 +186,7 @@ def compute_returns_and_advantage(self, last_values: torch.Tensor, dones: np.nda # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA self.returns = self.advantages + self.values - def swap_and_flatten(self, arr: np.ndarray) -> np.ndarray: + def swap_and_flatten(self, arr): """ Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) @@ -246,7 +200,7 @@ def swap_and_flatten(self, arr: np.ndarray) -> np.ndarray: shape = shape + (1,) return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) - def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: + def get(self, batch_size): assert self.full, "" indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data @@ -256,7 +210,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRollout self.observations[key] = self.swap_and_flatten(obs) _tensor_names = ["actions", "values", "log_probs", - "advantages", "returns", "action_masks"] + "advantages", "returns"] for tensor in _tensor_names: self.__dict__[tensor] = self.swap_and_flatten( @@ -272,7 +226,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRollout yield self._get_samples(indices[start_idx: start_idx + batch_size]) start_idx += batch_size - def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: + def to_torch(self, array, copy = True): """ Convert a numpy array to a PyTorch tensor. Note: it copies the data by default @@ -286,7 +240,7 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: return torch.tensor(array).to(self.device) return torch.as_tensor(array).to(self.device) - def _get_samples(self, batch_inds: np.ndarray, env = None) -> MaskableDictRolloutBufferSamples: + def _get_samples(self, batch_inds): return MaskableDictRolloutBufferSamples( observations={key: self.to_torch(obs[batch_inds]) for ( @@ -296,6 +250,4 @@ def _get_samples(self, batch_inds: np.ndarray, env = None) -> MaskableDictRollou old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - action_masks=self.to_torch( - self.action_masks[batch_inds].reshape(-1, self.mask_dims)), ) \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/component_build_util.py b/benchmark/torch/RL4LMs/utils/component_build_util.py new file mode 100644 index 000000000..79dd20809 --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/component_build_util.py @@ -0,0 +1,52 @@ +from transformers import AutoTokenizer +from parl.utils import logger +from .reward_util import RougeRewardFunction +from .metric_util import MetricRegistry +from .data_pool import CNNDailyMail + +def build_tokenizer(tokenizer_config): + logger.info(f"loading tokenizer of [{tokenizer_config['model_name']}] model") + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_config["model_name"]) + if tokenizer.pad_token is None and tokenizer_config.get("pad_token_as_eos_token", True): + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = tokenizer_config.get( + "padding_side", "left") + tokenizer.truncation_side = tokenizer_config.get( + "truncation_side", "left") + return tokenizer + + +def build_reward_fn(reward_config): + logger.info(f"loading reward function: rouge") + reward_fn = RougeRewardFunction(**reward_config.get("args", {})) + return reward_fn + + +def build_metrics(metric_configs): + metrics = [MetricRegistry.get(metric_config["id"], metric_config.get("args", {})) + for metric_config in metric_configs] + return metrics + + +def build_datapool(datapool_config): + def _get_datapool_by_split(split): + kwargs = datapool_config.get("args", {}) + kwargs["split"] = split + logger.info(f"loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") + dp_split = CNNDailyMail.prepare(**kwargs) + return dp_split + + train_datapool = _get_datapool_by_split("train") + val_datapool = _get_datapool_by_split("val") + test_datapool = _get_datapool_by_split("test") + + samples_by_split = { + "train": [(sample, weight) + for sample, weight in train_datapool], + "val": [sample for sample, _ in val_datapool], + "test": [sample for sample, _ in test_datapool] + } + return samples_by_split + + diff --git a/benchmark/torch/RL4LMs/utils/data_pool.py b/benchmark/torch/RL4LMs/utils/data_pool.py index ad7de7769..d149d0e42 100644 --- a/benchmark/torch/RL4LMs/utils/data_pool.py +++ b/benchmark/torch/RL4LMs/utils/data_pool.py @@ -1,97 +1,39 @@ from datasets import load_dataset from .data_wrapper import Sample -from typing import Any, List, Dict import random -from abc import abstractclassmethod from tqdm import tqdm from nltk.tokenize import word_tokenize -class TextGenPool: - def __init__(self, samples: List[Sample]): + + +class CNNDailyMail: + + def __init__(self, samples): self._samples = samples def __len__(self): return len(self._samples) - def __getitem__(self, ix: int) -> Sample: + def __getitem__(self, ix): if ix >= len(self): raise StopIteration sample = self._samples[ix] return sample, 1.0 - def sample(self) -> Sample: - random_sample = random.choice(self._samples) - return random_sample - - @abstractclassmethod - def prepare(cls, **args) -> 'TextGenPool': - """ - A factory method to instantiate data pool - """ - raise NotImplementedError - - def split(self, split_ratios: List[float]) -> List['TextGenPool']: - start_ix = 0 - pools = [] - for ratio in split_ratios: - count = int(len(self) * ratio) - end_ix = start_ix + count - pools.append(type(self)(self._samples[start_ix: end_ix])) - start_ix = end_ix - return pools - -class CommonGen(TextGenPool): - @classmethod - def prepare(cls, split: str, - concept_separator_token: str = " ", - concept_end_token=" ", - prefix: str = "summarize: ") -> 'TextGenPool': - ds = load_dataset("gem", "common_gen") - samples = [] - split_id = CommonGen.gen_split_name(split) - for ix, item in enumerate(ds[split_id]): - concepts = concept_separator_token.join(item["concepts"]) - concepts = prefix + concepts - concepts += concept_end_token - if item["target"] == "": - # just to avoid breaking of metric computation - item["target"] = "empty reference" - targets = [item["target"]] - sample = Sample(id=f"{split}_{ix}", - prompt_or_input_text=concepts, - references=targets, - meta_data={ - "concepts": item["concepts"] - } - ) - samples.append(sample) - pool_instance = cls(samples) - return pool_instance - - @staticmethod - def gen_split_name(split: str): - if split == "train": - split_name = "train" - elif split == "val": - split_name = "validation" - elif split == "test": - split_name = "test" - else: - raise NotImplementedError - return split_name - - - -class CNNDailyMail(TextGenPool): @classmethod def prepare(cls, - split: str, - prompt_suffix: str = "", - prompt_prefix: str = "", - truncate_article: int = None, - max_size: int = None): + split, + prompt_suffix = "", + prompt_prefix = "", + truncate_article = None, + max_size = None): + split2name = { + "train": "train", + "val": "validation", + "test": "test" + } dataset = load_dataset("cnn_dailymail", "3.0.0") - dataset_split = CommonGen.gen_split_name(split) + dataset_split = split2name[split] samples = [] for ix, item in tqdm(enumerate(dataset[dataset_split]), desc="Tokenizing dataset", @@ -113,4 +55,18 @@ def prepare(cls, break pool_instance = cls(samples) - return pool_instance \ No newline at end of file + return pool_instance + + def sample(self): + random_sample = random.choice(self._samples) + return random_sample + + def split(self, split_ratios): + start_ix = 0 + pools = [] + for ratio in split_ratios: + count = int(len(self) * ratio) + end_ix = start_ix + count + pools.append(type(self)(self._samples[start_ix: end_ix])) + start_ix = end_ix + return pools \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/data_wrapper.py b/benchmark/torch/RL4LMs/utils/data_wrapper.py index 234b42c5a..f69ee23c9 100644 --- a/benchmark/torch/RL4LMs/utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/utils/data_wrapper.py @@ -1,13 +1,16 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Dict, List from transformers import AutoTokenizer from copy import deepcopy -from .type_wrapper import TensorDict from typing import NamedTuple import torch import numpy as np +from typing import Any, Union + +TensorDict = Dict[Union[str, int], torch.Tensor] + @dataclass class TransitionInfo: @@ -22,7 +25,6 @@ class TransitionInfo: done: np.ndarray ref_log_prob: torch.Tensor kl_reward: np.ndarray - action_mask: np.ndarray info: Dict[str, Any] @@ -33,7 +35,6 @@ class MaskableDictRolloutBufferSamples(NamedTuple): old_log_prob: torch.Tensor advantages: torch.Tensor returns: torch.Tensor - action_masks: torch.Tensor @dataclass(init=True) @@ -156,7 +157,7 @@ class Observation: # other meta info meta_info: Dict[str, Any] - def to_dict(self) -> Dict[str, torch.tensor]: + def to_dict(self): """ For stable baselines (only return tensor items) """ diff --git a/benchmark/torch/RL4LMs/utils/distribution_wrapper.py b/benchmark/torch/RL4LMs/utils/distribution_wrapper.py index bcb5bca5a..e5824e239 100644 --- a/benchmark/torch/RL4LMs/utils/distribution_wrapper.py +++ b/benchmark/torch/RL4LMs/utils/distribution_wrapper.py @@ -15,7 +15,7 @@ def __init__(self, action_dim: int): super().__init__() self.action_dim = action_dim - def proba_distribution_net(self, latent_dim: int) -> nn.Module: + def proba_distribution_net(self, latent_dim): """ Create the layer that represents the distribution: it will be the logits of the Categorical distribution. @@ -28,35 +28,35 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: action_logits = nn.Linear(latent_dim, self.action_dim) return action_logits - def proba_distribution(self, action_logits: torch.Tensor) -> "CategoricalDistribution": + def proba_distribution(self, action_logits: torch.Tensor): self.distribution = Categorical(logits=action_logits) return self - def log_prob(self, actions: torch.Tensor) -> torch.Tensor: + def log_prob(self, actions): return self.distribution.log_prob(actions) - def entropy(self) -> torch.Tensor: + def entropy(self): return self.distribution.entropy() - def sample(self) -> torch.Tensor: + def sample(self): return self.distribution.sample() - def mode(self) -> torch.Tensor: + def mode(self): return torch.argmax(self.distribution.probs, dim=1) - def actions_from_params(self, action_logits: torch.Tensor, deterministic: bool = False) -> torch.Tensor: + def actions_from_params(self, action_logits, deterministic = False): # Update the proba distribution self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def log_prob_from_params(self, action_logits): actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob - def get_actions(self, deterministic: bool = False) -> torch.Tensor: + def get_actions(self, deterministic = False): """ Return actions according to the probability distribution. diff --git a/benchmark/torch/RL4LMs/utils/evaluation_util.py b/benchmark/torch/RL4LMs/utils/evaluation_util.py index 0a86b0f09..f69e16e4a 100644 --- a/benchmark/torch/RL4LMs/utils/evaluation_util.py +++ b/benchmark/torch/RL4LMs/utils/evaluation_util.py @@ -4,11 +4,10 @@ from transformers import AutoTokenizer from . import Sample -from benchmark.torch.RL4LMs.metrics import BaseMetric from parl.utils import logger -def get_batch(samples: List[Sample], batch_size: int): +def get_batch(samples, batch_size): current_ix = 0 n_samples = len(samples) while current_ix < n_samples: @@ -17,17 +16,18 @@ def get_batch(samples: List[Sample], batch_size: int): current_ix += batch_size + def evaluate_on_samples( policy, - tokenizer: AutoTokenizer, - samples: List[Sample], - batch_size: int, - max_prompt_length: int, - metrics: List[BaseMetric], - epoch: int, - split_name: str, - dt_control_token: str = "", - gen_kwargs: Dict[str, Any] = None, + tokenizer, + samples, + batch_size, + max_prompt_length, + metrics, + epoch, + split_name, + dt_control_token = "", + gen_kwargs = None, ): # generate text by batch all_generated_texts = [] @@ -110,11 +110,11 @@ def evaluate_on_samples( def generate_text( policy, - tokenizer: AutoTokenizer, - samples: List[Sample], - max_prompt_length: int, - dt_control_token: str, - gen_kwargs: Dict[str, Any], + tokenizer, + samples, + max_prompt_length, + dt_control_token, + gen_kwargs, ): prompt_texts = [ dt_control_token + sample.prompt_or_input_text for sample in samples diff --git a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py index 421510446..20426fb0a 100644 --- a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py @@ -50,83 +50,11 @@ StoppingCriteriaList, validate_stopping_criteria, ) -from transformers.pytorch_utils import torch_int_div from transformers.utils import ModelOutput, logging logger = logging.get_logger(__name__) - -@dataclass -class GreedySearchDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using greedy search. - - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each - tensor of shape `(batch_size, config.vocab_size)`). - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class GreedySearchEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention - weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the - encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size, config.vocab_size)`). - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - @dataclass class SampleDecoderOnlyOutput(ModelOutput): """ @@ -199,183 +127,7 @@ class SampleEncoderDecoderOutput(ModelOutput): decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None -@dataclass -class BeamSearchDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using beam search. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class BeamSearchEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights - of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states - attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, - sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class BeamSampleDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using beam sample. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class BeamSampleEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention - weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the - encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size*num_beams, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, - GreedySearchDecoderOnlyOutput] SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] -BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, - BeamSearchDecoderOnlyOutput] -BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, - BeamSampleDecoderOnlyOutput] class GenerationMixinWithRawScores: @@ -905,7 +657,7 @@ def generate( synced_gpus: Optional[bool] = False, exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, **model_kwargs, - ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: + ): r""" Generates sequences of token ids for models with a language modeling head. The method supports the following @@ -1066,18 +818,12 @@ def generate( If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - - [`~generation_utils.GreedySearchDecoderOnlyOutput`], - [`~generation_utils.SampleDecoderOnlyOutput`], - - [`~generation_utils.BeamSearchDecoderOnlyOutput`], - - [`~generation_utils.BeamSampleDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - - [`~generation_utils.GreedySearchEncoderDecoderOutput`], - [`~generation_utils.SampleEncoderDecoderOutput`], - - [`~generation_utils.BeamSearchEncoderDecoderOutput`], - - [`~generation_utils.BeamSampleEncoderDecoderOutput`] Examples: @@ -1590,41 +1336,6 @@ def sample( -def top_k_top_p_filtering( - logits: torch.FloatTensor, - top_k: int = 0, - top_p: float = 1.0, - filter_value: float = -float("Inf"), - min_tokens_to_keep: int = 1, -) -> torch.FloatTensor: - """ - Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - - Args: - logits: logits distribution shape (batch size, vocabulary size) - top_k (`int`, *optional*, defaults to 0): - If > 0, only keep the top k tokens with highest probability (top-k filtering) - top_p (`float`, *optional*, defaults to 1.0): - If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus - filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimumber of tokens we keep per batch example in the output. - - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if top_k > 0: - logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( - None, logits - ) - - if 0 <= top_p <= 1.0: - logits = TopPLogitsWarper( - top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) - - return logits - - - def override_generation_routines(cls): bases = list(cls.__bases__) for base_ix in range(len(bases)): diff --git a/benchmark/torch/RL4LMs/utils/kl_controller.py b/benchmark/torch/RL4LMs/utils/kl_controller.py index 377d196aa..20b1f7034 100644 --- a/benchmark/torch/RL4LMs/utils/kl_controller.py +++ b/benchmark/torch/RL4LMs/utils/kl_controller.py @@ -3,7 +3,7 @@ class KLController: - def __init__(self, kl_coeff: float, target_kl: Optional[float] = None) -> None: + def __init__(self, kl_coeff, target_kl = None): self._kl_coeff = kl_coeff self._target_kl = target_kl diff --git a/benchmark/torch/RL4LMs/utils/metric_util.py b/benchmark/torch/RL4LMs/utils/metric_util.py new file mode 100644 index 000000000..374ca3a3c --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/metric_util.py @@ -0,0 +1,177 @@ +import torch +import numpy as np +from datasets import load_metric +from gem_metrics.msttr import MSTTR +from gem_metrics.ngrams import NGramStats +from gem_metrics.texts import Predictions +from parl.utils import logger + + + +class MeteorMetric: + def __init__(self): + super().__init__() + self._metric = load_metric("meteor") + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos = None, + model = None, + split_name = None, + ): + + score = self._metric.compute( + predictions=generated_texts, references=reference_texts + )["meteor"] + + metric_dict = {"lexical/meteor": (None, score)} + return metric_dict + + +class RougeMetric: + def __init__(self, use_single_ref = True): + super().__init__() + self._metric = load_metric("rouge") + self._use_single_ref = use_single_ref + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos = None, + model = None, + split_name = None, + ): + if self._use_single_ref: + # TBD: this is required for CNN/DM dataset, without this we get low scores + # TBD: needs investigation + ref_texts = [ref[0] for ref in reference_texts] + else: + ref_texts = reference_texts + + metric_results = self._metric.compute( + predictions=generated_texts, references=ref_texts, use_stemmer=True + ) + score_keys = ["rouge1", "rouge2", "rougeL", "rougeLsum"] + metric_dict = {} + for rouge_type in score_keys: + rouge_score = metric_results[rouge_type].mid.fmeasure + metric_dict[f"lexical/rouge_{rouge_type}"] = (None, rouge_score) + return metric_dict + + +class BERTScoreMetric: + def __init__(self, language): + super().__init__() + self._metric = load_metric("bertscore") + self._language = language + # since models are loaded heavily on cuda:0, use the last one to avoid memory + self._last_gpu = f"cuda:{torch.cuda.device_count() - 1}" + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos = None, + model = None, + split_name = None, + ): + with torch.no_grad(): + metric_results = self._metric.compute( + predictions=generated_texts, + references=reference_texts, + lang=self._language, + device=self._last_gpu, + ) + bert_scores = metric_results["f1"] + corpus_level_score = np.mean(bert_scores) + metric_dict = {"semantic/bert_score": (bert_scores, corpus_level_score)} + return metric_dict + + +class BLEUMetric: + def __init__(self): + super().__init__() + self._metric = load_metric("bleu") + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos = None, + model = None, + split_name = None, + ): + + tokenized_predictions = [] + tokenized_reference_texts = [] + for prediction, refs in zip(generated_texts, reference_texts): + tokenized_prediction = prediction.split() + tokenized_refs = [ref.split() for ref in refs] + tokenized_predictions.append(tokenized_prediction) + tokenized_reference_texts.append(tokenized_refs) + + try: + metric_results = self._metric.compute( + predictions=tokenized_predictions, references=tokenized_reference_texts + ) + bleu_score = metric_results["bleu"] + metric_dict = {"lexical/bleu": (None, bleu_score)} + return metric_dict + except Exception as e: + return {"lexical/bleu": (None, "n/a")} + + +class DiversityMetrics: + def __init__(self, window_size = 100): + self._msttr_metric = MSTTR(window_size=window_size) + self._n_gram_metric = NGramStats() + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos = None, + model = None, + split_name = None, + ): + + predictions = Predictions(data={"filename": "", "values": generated_texts}) + diversity_metrics = {} + msttr_metrics = self._msttr_metric.compute(None, predictions) + n_gram_metrics = self._n_gram_metric.compute(None, predictions) + + for key, value in msttr_metrics.items(): + diversity_metrics[f"diversity_metrics/{key}"] = (None, value) + for key, value in n_gram_metrics.items(): + diversity_metrics[f"diversity_metrics/{key}"] = (None, value) + + return diversity_metrics + + +class MetricRegistry: + _registry = { + "meteor": MeteorMetric, + "rouge": RougeMetric, + "bert_score": BERTScoreMetric, + "bleu": BLEUMetric, + "diversity": DiversityMetrics, + } + + @classmethod + def get(cls, metric_id, kwargs): + logger.info(f"loading metric: {metric_id}") + metric_cls = cls._registry[metric_id] + metric = metric_cls(**kwargs) + return metric + + @classmethod + def add(cls, id, metric_cls): + MetricRegistry._registry[id] = metric_cls diff --git a/benchmark/torch/RL4LMs/utils/reward_util.py b/benchmark/torch/RL4LMs/utils/reward_util.py index 7ab5da1bc..a8d4ca77f 100644 --- a/benchmark/torch/RL4LMs/utils/reward_util.py +++ b/benchmark/torch/RL4LMs/utils/reward_util.py @@ -1,158 +1,24 @@ -from abc import ABC, abstractclassmethod - -import torch from datasets import load_metric -from .data_wrapper import Observation -from transformers import AutoModelForSequenceClassification, AutoTokenizer -from benchmark.torch.RL4LMs.metrics import ( - MeteorMetric, - BERTScoreMetric, - BLEUMetric, - RougeLMax, -) -import numpy as np -from typing import List, Dict, Any - - -class RewardFunction(ABC): - @abstractclassmethod - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - """ - Callable for reward functions for text generation - - Args: - current_observation (Observation): previous observation (s) - action (int): action performed (a) at s - next_observation (Observation): observation after the action was performed (s') - done (bool): whether the episode is finished or not - meta_info (dict) - other information regarding textual sample - Returns: - float: scalar reward - """ - raise NotImplementedError - - -class BatchedRewardFunction(ABC): - """ - Computes rewards for several instances at once - """ - - @abstractclassmethod - def __call__( - self, - prompt_texts: List[str], - gen_texts: List[str], - ref_texts: List[List[str]], - dones: List[bool], - meta_infos: List[Dict[str, Any]] = None, - ) -> List[float]: - """ - An abstract class for batched reward functions for text generation - """ - raise NotImplementedError - - -### Automated reward functions ########################### - - -class CommonGenPenaltyShapingFunction(RewardFunction): - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - prompt_text = next_observation.prompt_or_input_text - prefix = "generate a sentence with: " - concept_n_grams = prompt_text.split(prefix)[1][:-1] - if ( - concept_n_grams.lower() in next_observation.context_text.lower() - or prefix in next_observation.context_text.lower() - or "generate" in next_observation.context_text.lower() - or "sentence" in next_observation.context_text.lower() - ): - penalty_score = -1 - else: - penalty_score = 0 - return penalty_score - return 0 - - - - - -class MeteorRewardFunction(RewardFunction): - def __init__(self, shaping_fn: str = None) -> None: - super().__init__() - self._metric = MeteorMetric() - from benchmark.torch.RL4LMs.registry import RewardFunctionRegistry - - self._shaping_fn = ( - RewardFunctionRegistry.get(shaping_fn, {}) - if shaping_fn is not None - else shaping_fn - ) - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - - # compute meteor at the end of episode - if done: - references = [next_observation.target_or_reference_texts] - predicted = [next_observation.context_text] - metric_dict = self._metric.compute(None, predicted, references) - score = metric_dict["lexical/meteor"][1] - - if self._shaping_fn is not None: - aux_score = self._shaping_fn( - current_observation, action, next_observation, done, meta_info - ) - score = score + aux_score - return score - return 0 - - -class RougeRewardFunction(RewardFunction): +class RougeRewardFunction: def __init__( - self, rouge_type: str, shaping_fn: str = None, use_single_ref: bool = True - ) -> None: + self, rouge_type, use_single_ref = True + ): super().__init__() self._metric = load_metric("rouge") self._rouge_type = rouge_type - from benchmark.torch.RL4LMs.registry import RewardFunctionRegistry - self._shaping_fn = ( - RewardFunctionRegistry.get(shaping_fn, {}) - if shaping_fn is not None - else shaping_fn - ) + self._shaping_fn = None self._use_single_ref = use_single_ref def __call__( self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: + current_observation, + action, + next_observation, + done, + meta_info = None, + ): if done: # TBD: considers only one reference for now if self._use_single_ref: @@ -171,276 +37,4 @@ def __call__( ) reward = reward + aux_score return reward - return 0 - - -class RougeCombined(RewardFunction): - def __init__(self, shaping_fn: str = None) -> None: - super().__init__() - self._metric = load_metric("rouge") - from benchmark.torch.RL4LMs.registry import RewardFunctionRegistry - - self._shaping_fn = ( - RewardFunctionRegistry.get(shaping_fn, {}) - if shaping_fn is not None - else shaping_fn - ) - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - # TBD: considers only one reference for now - references = [next_observation.target_or_reference_texts[0]] - predicted = [next_observation.context_text] - - metric_results = self._metric.compute( - predictions=predicted, references=references, use_stemmer=True - ) - - rouge_keys = ["rouge1", "rouge2", "rougeL"] - scores = [ - metric_results[rouge_type].mid.fmeasure for rouge_type in rouge_keys - ] - reward = np.mean(scores) - if self._shaping_fn is not None: - aux_score = self._shaping_fn( - current_observation, action, next_observation, done, meta_info - ) - reward = reward + aux_score - return reward - return 0 - - -class BERTScoreRewardFunction(RewardFunction): - def __init__(self, language: str = "en") -> None: - super().__init__() - self._metric = BERTScoreMetric(language) - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - references = [next_observation.target_or_reference_texts] - predicted = [next_observation.context_text] - metric_results = self._metric.compute(None, predicted, references) - bert_score = metric_results["semantic/bert_score"][1] - return bert_score - return 0 - - -class BLEURewardFunction(RewardFunction): - def __init__(self) -> None: - super().__init__() - self._metric = BLEUMetric() - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - references = [next_observation.target_or_reference_texts] - predicted = [next_observation.context_text] - metric_results = self._metric.compute(None, predicted, references) - bleu_score = metric_results["lexical/bleu"][1] - return bleu_score - return 0 - - -class SacreBleu(RewardFunction): - def __init__(self, **args) -> None: - super().__init__() - self._metric = load_metric("sacrebleu") - self._args = args - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - references = [next_observation.target_or_reference_texts] - predicted = [next_observation.context_text] - metric_results = self._metric.compute( - predictions=predicted, references=references, **self._args - ) - return metric_results["score"] / 100 - return 0 - - - - -############################################################################# - -########## Learned Reward Functions########################################## - - -class LearnedRewardFunction(RewardFunction): - def __init__( - self, model_name: str, label_ix: int, include_prompt_for_eval: bool = True - ) -> None: - super().__init__() - self._device = "cuda" if torch.cuda.is_available() else "cpu" - self._metric_tokenizer = AutoTokenizer.from_pretrained(model_name) - self._metric_tokenizer.truncation_side = "left" - self._metric_model = AutoModelForSequenceClassification.from_pretrained( - model_name - ).to(self._device) - self._label_ix = label_ix - self._include_prompt_for_eval = include_prompt_for_eval - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - generated_text = ( - current_observation.prompt_or_input_text - if self._include_prompt_for_eval - else "" - ) - generated_text += next_observation.context_text - - with torch.no_grad(): - encoded = self._metric_tokenizer( - generated_text, return_tensors="pt", truncation=True, padding=True - ) - outputs = self._metric_model( - input_ids=encoded.input_ids.to(self._device), - attention_mask=encoded.attention_mask.to(self._device), - ) - scores = torch.softmax(outputs.logits.flatten(), dim=0) - score = scores[self._label_ix].item() - return score - return 0 - - -class BLEURTRewardFunction(RewardFunction): - def __init__(self, checkpoint: str = None): - super().__init__() - self._metric = load_metric("bleurt", checkpoint=checkpoint) - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - references = [next_observation.target_or_reference_texts] - predicted = [next_observation.context_text] - metric_results = self._metric.compute( - predictions=predicted, references=references - ) - score = metric_results["scores"][0] - return score - return 0 - - -# class PARENTRewardFunction(RewardFunction): -# """ -# PARENT F1 score as the reward -# """ -# -# def __init__(self) -> None: -# super().__init__() -# self._metric = ParentToTTo() -# -# def __call__( -# self, -# current_observation: Observation, -# action: int, -# next_observation: Observation, -# done: bool, -# meta_info: Dict[str, Any] = None, -# ) -> float: -# if done: -# generated_texts = [next_observation.context_text] -# meta_infos = [meta_info] -# scores = self._metric.compute(None, generated_texts, None, meta_infos) -# reward = scores["table_to_text/parent_overall_f_score"][0][0] -# return reward -# return 0 - - -class RougeLMaxRewardFunction(RewardFunction): - def __init__(self, **args) -> None: - super().__init__() - self._metric = RougeLMax(**args) - - def __call__( - self, - current_observation: Observation, - action: int, - next_observation: Observation, - done: bool, - meta_info: Dict[str, Any] = None, - ) -> float: - if done: - references = [next_observation.target_or_reference_texts] - predicted = [next_observation.context_text] - meta_infos = [meta_info] - scores = self._metric.compute(None, predicted, references, meta_infos) - reward = scores["lexical/rouge_l_max"][0][0] - return reward - return 0 - - - - -if __name__ == "__main__": - predictions = "hello there general kenobi" - references = ["hello there general kenobi", "hello there!!"] - observation = Observation( - None, None, None, None, None, predictions, references, None, None, None, None - ) - - reward_fn = MeteorRewardFunction() - print(reward_fn(None, None, observation, True)) - - # reward_fn = chrF() - # print(reward_fn(None, None, observation, True)) - - reward_fn = RougeCombined() - print(reward_fn(None, None, observation, True)) - - reward_fn = RougeRewardFunction(rouge_type="rouge1") - print(reward_fn(None, None, observation, True)) - - reward_fn = RougeRewardFunction(rouge_type="rouge2") - print(reward_fn(None, None, observation, True)) - - reward_fn = RougeRewardFunction(rouge_type="rougeL") - print(reward_fn(None, None, observation, True)) - - reward_fn = BERTScoreRewardFunction(language="en") - print(reward_fn(None, None, observation, True)) - - reward_fn = BLEURewardFunction() - print(reward_fn(None, None, observation, True)) - - reward_fn = BLEURTRewardFunction() - print(reward_fn(None, None, observation, True)) + return 0 \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/rollout_util.py b/benchmark/torch/RL4LMs/utils/rollout_util.py new file mode 100644 index 000000000..b1a1a228c --- /dev/null +++ b/benchmark/torch/RL4LMs/utils/rollout_util.py @@ -0,0 +1,281 @@ +import torch +import numpy as np +from .data_wrapper import TransitionInfo +from .kl_controller import KLController + + +def dict_to_tensor(obs, device): + return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + +def unpack_observations(obs_tensor, n_envs): + """ + Unpacks vectorized dict observations into separate dict observations + """ + unpacked_obs = [] + keys = obs_tensor.keys() + for env_ix in range(n_envs): + obs_dict = {} + for key in keys: + obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() + unpacked_obs.append(obs_dict) + return unpacked_obs + + +def add_to_buffer( + rollout_buffer, episode_wise_transitions, rollout_info +): + advantages_computed = False + for ep_ix, transitions in enumerate(episode_wise_transitions): + ep_length = len(transitions) + total_reward = 0.0 + total_kl_reward = 0.0 + for transition_ix, transition in enumerate(transitions): + total_reward += transition.task_reward + total_kl_reward += transition.kl_reward + rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) + rollout_info["rollout_info/log_prob"].append(transition.log_prob) + rollout_info["rollout_info/ref_log_prob"].append( + transition.ref_log_prob + ) + rollout_info["rollout_info/values"].append(transition.value.numpy()) + + if not rollout_buffer.full: + rollout_buffer.add( + transition.observation, + transition.action, + transition.total_reward, + transition.episode_start, + transition.value, + transition.log_prob, + ) + + # if the buffer is full, compute advantages + if rollout_buffer.full and not advantages_computed: + # we fetch the last value for the last time step + # values come from the next transitions's values + next_values = ( + transitions[transition_ix + 1].value + if (transition_ix + 1) < ep_length + else torch.tensor([0.0]) + ) + + rollout_buffer.compute_returns_and_advantage( + last_values=next_values, dones=transition.done + ) + advantages_computed = True + + rollout_info["rollout_info/ep_rew"].append(total_reward) + rollout_info["rollout_info/ep_lens"].append(ep_length) + rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) + return rollout_info + + +class RolloutUtil: + def __init__(self, kl_args): + self._kl_controller = KLController(kl_args["coeff"], + kl_args["target_kl"]) + + + def _generate_batch( + self, + agent=None, + env=None, + rollout_buffer=None, + tokenizer=None, + rollout_info=None, + device=None + ): + num_timesteps = 0 + # if rollout buffer is already full, do not continue + if rollout_buffer.full: + return + + # start parallel episodes + current_obs = env.reset() + episode_starts = np.ones((env.num_envs,), dtype=bool) + + # generate text using the model + obs_tensor = dict_to_tensor(current_obs, device) + generation_inputs = agent.get_inputs_for_generation(obs_tensor) + gen_output = agent.generate( + input_ids=generation_inputs.inputs, + attention_mask=generation_inputs.attention_masks, + tokenizer=tokenizer, + ) + + # process them one step at a time to collect rollout info + episode_wise_transitions = [[] for _ in range(env.num_envs)] + ep_terminated = np.zeros((env.num_envs,), dtype=bool) + value_past_state = None + ref_past_state = None + policy_past_state = None + + for actions_tensor, _ in zip( + gen_output.step_wise_actions, gen_output.step_wise_logprobs + ): + # if all episodes are done, just break and do not continue + if np.all(ep_terminated): + break + + # evaluate actions with actions from rollout + with torch.no_grad(): + obs_tensor = dict_to_tensor(current_obs, device) + + # get log probs (TBD: generalize this a bit) + policy_kwargs = { + "obs": obs_tensor, + "actions": actions_tensor, + "past_model_kwargs": policy_past_state, + } + + policy_outputs = agent.forward_policy( + **policy_kwargs + ) + raw_log_probs, log_probs, policy_past_state = ( + policy_outputs.raw_log_probs, + policy_outputs.log_probs, + policy_outputs.past_model_kwargs, + ) + + # sanity check + assert torch.all( + torch.isfinite(log_probs) + ), "Infinite values in log probs" + + # sanity check + assert torch.all( + torch.isfinite(raw_log_probs) + ), "Infinite values in log probs" + + # get values + value_outputs = agent.forward_value( + obs_tensor, value_past_state + ) + values, value_past_state = ( + value_outputs.values, + value_outputs.past_model_kwargs, + ) + + # get reference log probs + ref_policy_outputs = ( + agent.get_log_probs_ref_model( + obs_tensor, actions_tensor, ref_past_state + ) + ) + ref_log_probs, ref_past_state = ( + ref_policy_outputs.log_probs, + ref_policy_outputs.past_model_kwargs, + ) + + # sanity check + assert torch.all( + torch.isfinite(ref_log_probs) + ), "Infinite values in log probs" + + # compute KL rewards + kl_div = raw_log_probs - ref_log_probs + kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div + + # step into env to get rewards + actions = actions_tensor.cpu().numpy() + new_obs, rewards, dones, infos = env.step(actions) + + num_timesteps += env.num_envs + + # compute total rewards + total_rewards = rewards + kl_rewards.cpu().numpy() + + # unpack individual observations + unpacked_obs = unpack_observations(obs_tensor, env.num_envs) + + # store episode wise transitions separately + for env_ix in range(env.num_envs): + # only if not terminated already + if not ep_terminated[env_ix]: + transtion = TransitionInfo( + observation=unpacked_obs[env_ix], + action=actions[env_ix], + task_reward=rewards[env_ix], + total_reward=total_rewards[env_ix], + kl_div=kl_div.cpu().numpy()[env_ix], + episode_start=episode_starts[env_ix], + value=values[env_ix].cpu(), + log_prob=log_probs[env_ix].cpu(), + done=dones[env_ix], + ref_log_prob=ref_log_probs[env_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[env_ix], + info=infos[env_ix], + ) + + episode_wise_transitions[env_ix].append(transtion) + + # mark this episode to terminated if done occurs once + if dones[env_ix]: + ep_terminated[env_ix] = True + + episode_starts = np.zeros((env.num_envs,), dtype=bool) + current_obs = new_obs + + # now we flush all episode wise info to the 1-D buffer + rollout_info = add_to_buffer( + rollout_buffer, episode_wise_transitions, rollout_info + ) + return rollout_info, num_timesteps + + + def collect_rollouts( + self, + agent, + env, + rollout_buffer, + device + ): + used_timesteps = 0 + # get tokenizer + tokenizer = env.get_attr("tokenizer", [0]) + tokenizer = tokenizer[0] + + # Switch to eval mode + # self._agent.alg.model.set_training_mode(False) + agent.eval_mode() + + # reset rollout buffer and stats + rollout_buffer.reset() + + # start the rollout process + rollout_info = { + "rollout_info/ep_rew": [], + "rollout_info/kl_div_mean": [], + "rollout_info/ep_lens": [], + "rollout_info/ep_kl_rew": [], + "rollout_info/log_prob": [], + "rollout_info/ref_log_prob": [], + "rollout_info/values": [], + } + num_timesteps = 0 + while not rollout_buffer.full: + # generate batch of rollouts + rollout_info, run_timestamps = self._generate_batch( + agent=agent, + env=env, + rollout_buffer=rollout_buffer, + tokenizer=tokenizer, + rollout_info=rollout_info, + device=device + ) + num_timesteps += run_timestamps + + # aggregate rollout info + aggregated_rollout_info = {} + for key, values in rollout_info.items(): + aggregated_rollout_info[key] = np.mean(values).item() + aggregated_rollout_info[f"{key}_std"] = np.std(values).item() + aggregated_rollout_info[ + "rollout_info/kl_coeff" + ] = self._kl_controller.kl_coeff + + # adapt the KL coeff + self._kl_controller.step( + torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) + ) + return num_timesteps \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/sample_util.py b/benchmark/torch/RL4LMs/utils/sample_util.py index d403fd741..097539b32 100644 --- a/benchmark/torch/RL4LMs/utils/sample_util.py +++ b/benchmark/torch/RL4LMs/utils/sample_util.py @@ -1,8 +1,6 @@ from collections import deque -from typing import Any, List import numpy as np - class PrioritySampler: def __init__(self, max_size: int = None, priority_scale: float = 0.0): """ @@ -17,11 +15,11 @@ def __init__(self, max_size: int = None, priority_scale: float = 0.0): self.item_priorities = deque(maxlen=self.max_size) self.priority_scale = priority_scale - def add(self, item: Any, priority: float): + def add(self, item, priority: float): self.items.append(item) self.item_priorities.append(priority) - def sample(self, size: int) -> List[Any]: + def sample(self, size: int): min_sample_size = min(len(self.items), size) scaled_item_priorities = np.array( self.item_priorities) ** self.priority_scale @@ -30,11 +28,11 @@ def sample(self, size: int) -> List[Any]: a=self.items, p=sample_probs, size=min_sample_size) return samples - def update(self, item: Any, priority: float): + def update(self, item, priority): index = self.items.index(item) del self.items[index] del self.item_priorities[index] self.add(item, priority) - def get_all_samples(self) -> List[Any]: + def get_all_samples(self): return self.items diff --git a/benchmark/torch/RL4LMs/utils/type_wrapper.py b/benchmark/torch/RL4LMs/utils/type_wrapper.py deleted file mode 100644 index 17f81ddd8..000000000 --- a/benchmark/torch/RL4LMs/utils/type_wrapper.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Any, Dict, Optional, List, Union, Callable -import torch - - -# refer to stable_baselines3.common.type_aliases -TensorDict = Dict[Union[str, int], torch.Tensor] -Schedule = Callable[[float], float] From 0b69359f0ff665b0f4bbc4a0273fd0ee5f577a3c Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Tue, 7 Mar 2023 16:05:36 +0800 Subject: [PATCH 07/34] remove distribution_wrapper.py and sample_util.py --- benchmark/torch/RL4LMs/env/text_gen_env.py | 24 +-- benchmark/torch/RL4LMs/env/vec_env.py | 67 ++------ benchmark/torch/RL4LMs/rl4lm_ppo.py | 46 +++--- benchmark/torch/RL4LMs/rl4lms_agent.py | 15 +- benchmark/torch/RL4LMs/seq2seq_model.py | 143 +++++++----------- benchmark/torch/RL4LMs/train.py | 8 +- benchmark/torch/RL4LMs/utils/__init__.py | 6 +- benchmark/torch/RL4LMs/utils/buffer.py | 30 ++-- benchmark/torch/RL4LMs/utils/data_wrapper.py | 2 +- .../RL4LMs/utils/distribution_wrapper.py | 68 --------- .../torch/RL4LMs/utils/evaluation_util.py | 9 +- .../utils/huggingface_generation_util.py | 46 +----- benchmark/torch/RL4LMs/utils/rollout_util.py | 19 +-- benchmark/torch/RL4LMs/utils/sample_util.py | 38 ----- 14 files changed, 128 insertions(+), 393 deletions(-) delete mode 100644 benchmark/torch/RL4LMs/utils/distribution_wrapper.py delete mode 100644 benchmark/torch/RL4LMs/utils/sample_util.py diff --git a/benchmark/torch/RL4LMs/env/text_gen_env.py b/benchmark/torch/RL4LMs/env/text_gen_env.py index 7f9d2c9b7..79287e340 100644 --- a/benchmark/torch/RL4LMs/env/text_gen_env.py +++ b/benchmark/torch/RL4LMs/env/text_gen_env.py @@ -4,8 +4,9 @@ from gym import Env, spaces from gym.spaces.dict import Dict as DictSpace from gym.spaces.discrete import Discrete -from benchmark.torch.RL4LMs.utils import Sample, Observation, PrioritySampler -from transformers import AutoTokenizer +from benchmark.torch.RL4LMs.utils import Sample, Observation +from collections import deque +import numpy as np class TextGenEnv(Env): @@ -15,12 +16,12 @@ def __init__( reward_function, samples, max_episode_length = 512, - priority_scale = 0.0, max_prompt_length = None, terminate_on_eos = False, context_start_token = None, prompt_truncation_side = "left", ): + """ A generic RL environment to generate textual sequences. For eg: text generation, summarization, machine translation, text simplification @@ -29,7 +30,6 @@ def __init__( reward_function (RewardFunction): reward functiom samples (Tuple[List[Sample], float]): list of samples max_episode_length (int, optional): Max steps to the model Defaults to 512. - priority_scale (float, optional): weight for the priority sampler Defaults to 0.0. max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. terminate_on_eos (bool, optional): whether to terminate on EOS. Defaults to False. context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) @@ -82,9 +82,9 @@ def __init__( elif 't5' in self.tokenizer.name_or_path: n = 32128 self.action_space = Discrete(n=n) - self.sampler_for_replaying = PrioritySampler(priority_scale=priority_scale) + self.samples_for_replaying = deque() for sample, weight in samples: - self.sampler_for_replaying.add(sample, weight) + self.samples_for_replaying.append(sample) # check the tokenizer and add padding tokens if self.tokenizer.pad_token is None: @@ -112,18 +112,13 @@ def step(self, action): ) # compute reward - reward = ( - None - if self.reward_function is None - else self.reward_function( + reward = self.reward_function( previous_obs, action, self.__current_obs, done, self.__current_obs.meta_info, ) - ) - # populate additional info info = { @@ -143,7 +138,7 @@ def reset(self, sample = None): """ # gets a new sample if not provided if sample is None: - sample = self.sampler_for_replaying.sample(size=1)[0] + sample = np.random.choice(a=self.samples_for_replaying, size=min(len(self.samples_for_replaying), 1))[0] self.__current_sample = sample # init the observation @@ -168,6 +163,3 @@ def render(self): def close(self): pass - - def add_sample(self, sample, weight = 1.0): - self.sampler_for_replaying.add(sample, weight) diff --git a/benchmark/torch/RL4LMs/env/vec_env.py b/benchmark/torch/RL4LMs/env/vec_env.py index e4f5e4136..d6a21ee9a 100644 --- a/benchmark/torch/RL4LMs/env/vec_env.py +++ b/benchmark/torch/RL4LMs/env/vec_env.py @@ -2,7 +2,6 @@ import cloudpickle import gym from collections import OrderedDict -from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union, Dict import multiprocessing as mp @@ -16,7 +15,7 @@ def __getstate__(self): def __setstate__(self, var): self.var = cloudpickle.loads(var) -def _flatten_obs(obs, space: gym.spaces.Space): +def _flatten_obs(obs, space): assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" assert len(obs) > 0, "need observations from at least one environment" @@ -32,7 +31,7 @@ def _flatten_obs(obs, space: gym.spaces.Space): return np.stack(obs) def _worker( - remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper + remote, parent_remote, env_fn_wrapper ): # Import here to avoid a circular import @@ -53,21 +52,12 @@ def _worker( elif cmd == "reset": observation = env.reset() remote.send(observation) - elif cmd == "render": - remote.send(env.render(data)) elif cmd == "close": env.close() remote.close() break elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) - elif cmd == "env_method": - method = getattr(env, data[0]) - remote.send(method(*data[1], **data[2])) - elif cmd == "get_attr": - remote.send(getattr(env, data)) - elif cmd == "set_attr": - remote.send(setattr(env, data[0], data[1])) else: raise NotImplementedError(f"`{cmd}` is not implemented in the worker") except EOFError: @@ -75,10 +65,11 @@ def _worker( class LocalParallelVecEnv: - def __init__(self, env_fns, start_method = None): + def __init__(self, env_fns, tokenizer=None, start_method = None): self.waiting = False self.closed = False n_envs = len(env_fns) + self.tokenizer = tokenizer if start_method is None: # Fork is not a thread safe method (see issue #217) @@ -115,7 +106,7 @@ def step_wait(self): obs, rews, dones, infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos - def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + def seed(self, seed = None): if seed is None: seed = np.random.randint(0, 2**32 - 1) for idx, remote in enumerate(self.remotes): @@ -128,7 +119,7 @@ def reset(self): obs = [remote.recv() for remote in self.remotes] return _flatten_obs(obs, self.observation_space) - def close(self) -> None: + def close(self): if self.closed: return if self.waiting: @@ -140,43 +131,7 @@ def close(self) -> None: process.join() self.closed = True - def get_attr(self, attr_name: str, indices) -> List[Any]: - """Return attribute from vectorized environment (see base class).""" - target_remotes = self._get_target_remotes(indices) - for remote in target_remotes: - remote.send(("get_attr", attr_name)) - return [remote.recv() for remote in target_remotes] - - def set_attr(self, attr_name: str, value: Any, indices = None) -> None: - """Set attribute inside vectorized environments (see base class).""" - target_remotes = self._get_target_remotes(indices) - for remote in target_remotes: - remote.send(("set_attr", (attr_name, value))) - for remote in target_remotes: - remote.recv() - - def env_method(self, method_name: str, *method_args, indices = None, **method_kwargs) -> List[Any]: - """Call instance methods of vectorized environments.""" - target_remotes = self._get_target_remotes(indices) - for remote in target_remotes: - remote.send(("env_method", (method_name, method_args, method_kwargs))) - return [remote.recv() for remote in target_remotes] - - def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices = None) -> List[bool]: - """Check if worker environments are wrapped with a given wrapper""" - target_remotes = self._get_target_remotes(indices) - for remote in target_remotes: - remote.send(("is_wrapped", wrapper_class)) - return [remote.recv() for remote in target_remotes] - - def _get_target_remotes(self, indices) -> List[Any]: - if indices is None: - indices = range(self.num_envs) - elif isinstance(indices, int): - indices = [indices] - return [self.remotes[i] for i in indices] - - def step(self, actions: np.ndarray): + def step(self, actions): """ Step the environments with the given action @@ -187,9 +142,9 @@ def step(self, actions: np.ndarray): return self.step_wait() def make_vec_env( - env_id: Union[str, Type[gym.Env]], - seed: Optional[int] = None, - start_index: int = 0, + env_id, + seed = None, + start_index = 0, env_config = None, reward_fn = None, tokenizer = None, @@ -211,4 +166,4 @@ def _init(): return env return _init - return LocalParallelVecEnv([make_env(i + start_index) for i in range(n_envs)]) + return LocalParallelVecEnv([make_env(i + start_index) for i in range(n_envs)], tokenizer=tokenizer) diff --git a/benchmark/torch/RL4LMs/rl4lm_ppo.py b/benchmark/torch/RL4LMs/rl4lm_ppo.py index 4a1c753e7..47c2bda90 100644 --- a/benchmark/torch/RL4LMs/rl4lm_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lm_ppo.py @@ -10,22 +10,22 @@ class RL4LMPPO(parl.Algorithm): def __init__(self, - model: parl.Model, + model, learning_rate = 3e-4, - n_steps: int = 2048, - batch_size: int = 64, - n_epochs: int = 10, - gamma: float = 0.99, - gae_lambda: float = 0.95, + n_steps = 2048, + batch_size = 64, + n_epochs = 10, + gamma = 0.99, + gae_lambda = 0.95, clip_range = 0.2, - normalize_advantage: bool = True, - ent_coef: float = 0.0, - vf_coef: float = 0.5, - max_grad_norm: float = 0.5, - target_kl: Optional[float] = None, - seed: Optional[int] = None, - device: Union[torch.device, str] = "auto", - _init_setup_model: bool = True, + normalize_advantage = True, + ent_coef = 0.0, + vf_coef = 0.5, + max_grad_norm = 0.5, + target_kl = None, + seed = None, + device = "auto", + _init_setup_model = True, ): super(RL4LMPPO, self).__init__(model=model) self.learning_rate = learning_rate @@ -127,10 +127,6 @@ def learn(self, rollout_buffer, log_info): return continue_training, loss - - def sample(self, obs): - pass - def predict(self, obs): pass @@ -140,20 +136,17 @@ def value(self, obs): def forward_value( self, obs, - past_model_kwargs = None, ): - return self.model.forward_value(obs, past_model_kwargs) + return self.model.forward_value(obs) def forward_policy( self, obs, - actions: torch.tensor, - past_model_kwargs = None, + actions, ): return self.model.forward_policy( obs = obs, actions = actions, - past_model_kwargs = past_model_kwargs, ) @@ -161,11 +154,10 @@ def get_log_probs_ref_model( self, obs, action, - model_kwarpast_model_kwargsgs = None, ): - return self.model.get_log_probs_ref_model(obs, action, model_kwarpast_model_kwargsgs) + return self.model.get_log_probs_ref_model(obs, action) - def generate( + def sample( self, tokenizer, texts = None, @@ -174,7 +166,7 @@ def generate( attention_mask = None, gen_kwargs = None, ): - return self.model.generate( + return self.model.sample( input_ids=input_ids, attention_mask=attention_mask, tokenizer=tokenizer, diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index 468a9a3d2..4e9357f14 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -101,26 +101,20 @@ def get_inputs_for_generation(self, obs_tensor): def predict(self, *args, **kwargs): pass - def sample(self, *args, **kwargs): - pass - def forward_value( self, obs, - past_model_kwargs = None, ): - return self.alg.forward_value(obs, past_model_kwargs) + return self.alg.forward_value(obs) def forward_policy( self, obs, actions, - past_model_kwargs = None, ): return self.alg.forward_policy( obs = obs, actions = actions, - past_model_kwargs = past_model_kwargs, ) @@ -128,11 +122,10 @@ def get_log_probs_ref_model( self, obs, action, - model_kwarpast_model_kwargsgs = None, ): - return self.alg.get_log_probs_ref_model(obs, action, model_kwarpast_model_kwargsgs) + return self.alg.get_log_probs_ref_model(obs, action) - def generate( + def sample( self, tokenizer, texts = None, @@ -141,7 +134,7 @@ def generate( attention_mask = None, gen_kwargs = None, ): - return self.alg.generate( + return self.alg.sample( input_ids=input_ids, attention_mask=attention_mask, tokenizer=tokenizer, diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index 7704a7d6f..fde2b85b1 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -8,7 +8,7 @@ import parl from benchmark.torch.RL4LMs.utils import ( - override_generation_routines, CategoricalDistribution, + override_generation_routines, GenerationInputs, PolicyOutput, RefPolicyOutput, ValueOutput, EvaluateActionsOutput, GenerationOutputs, @@ -28,7 +28,6 @@ def __init__( optimizer_class = torch.optim.AdamW, generation_kwargs = {}, prompt_truncation_side = "left", - state_dict = None, device = None, ): super(Seq2SeqLMModel, self).__init__() @@ -47,7 +46,6 @@ def __init__( self._apply_model_parallel = apply_model_parallel self._build_model_heads(model_name) self._setup_optimizer(optimizer_kwargs, weight_decay, optimizer_class) - self._action_dist = CategoricalDistribution(self._action_space.n) self._generation_kwargs = generation_kwargs self._prompt_truncation_side = prompt_truncation_side @@ -82,36 +80,27 @@ def forward_policy( self, obs, actions, - past_model_kwargs = None, ): + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( + self._policy_model + )._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs + ) - # Temp workaround for Seq2seq policy - past_model_kwargs = None - - if past_model_kwargs is None: - # 1. prepare model inputs - past_model_kwargs = { - "attention_mask": obs["prompt_or_input_attention_mask_pt"], - } - inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( - self._policy_model - )._prepare_model_inputs( - obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs - ) - - # 2. prepare encoder outputs - past_model_kwargs = unwrap_model( - self._policy_model - )._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, past_model_kwargs, model_input_name - ) + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model( + self._policy_model + )._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name + ) - # 3. Prepare input_ids for auto-regressive generation - input_ids = obs["context_encoded_pt"].int() - decoder_attn_mask = obs["context_attention_mask_pt"] - else: - input_ids = obs["context_encoded_pt"].int() - decoder_attn_mask = past_model_kwargs.pop("decoder_attention_mask") + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] # all set to get into auto-regressive mode # prepare all of the model inputs for the decoder @@ -127,7 +116,7 @@ def forward_policy( next_token_logits = outputs.logits[:, -1, :] # get log probs - dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + dist = Categorical(logits=next_token_logits) log_prob = dist.log_prob(actions) entropy = dist.entropy() @@ -155,35 +144,27 @@ def forward_policy( def forward_value( self, obs, - past_model_kwargs = None, ): - # Temp workaround for Seq2seq policy - past_model_kwargs = None - - if past_model_kwargs is None: - # 1. prepare model inputs - past_model_kwargs = { - "attention_mask": obs["prompt_or_input_attention_mask_pt"], - } - inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( - self._value_model - )._prepare_model_inputs( - obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs - ) + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( + self._value_model + )._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs + ) - # 2. prepare encoder outputs - past_model_kwargs = unwrap_model( - self._value_model - )._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, past_model_kwargs, model_input_name - ) + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model( + self._value_model + )._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name + ) - # 3. Prepare input_ids for auto-regressive generation - input_ids = obs["context_encoded_pt"].int() - decoder_attn_mask = obs["context_attention_mask_pt"] - else: - input_ids = obs["context_encoded_pt"].int() - decoder_attn_mask = past_model_kwargs.pop("decoder_attention_mask") + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] # all set to get into auto-regressive mode # prepare all of the model inputs for the decoder @@ -247,35 +228,27 @@ def get_log_probs_ref_model( self, obs, action, - model_kwarpast_model_kwargsgs = None, ): - # Temp workaround for Seq2seq policy - past_model_kwargs = None - - if past_model_kwargs is None: - # 1. prepare model inputs - past_model_kwargs = { - "attention_mask": obs["prompt_or_input_attention_mask_pt"], - } - inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( - self._ref_model - )._prepare_model_inputs( - obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs - ) + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( + self._ref_model + )._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs + ) - # 2. prepare encoder outputs - past_model_kwargs = unwrap_model( - self._ref_model - )._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, past_model_kwargs, model_input_name - ) + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model( + self._ref_model + )._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name + ) - # 3. Prepare input_ids for auto-regressive generation - input_ids = obs["context_encoded_pt"].int() - decoder_attn_mask = obs["context_attention_mask_pt"] - else: - input_ids = obs["context_encoded_pt"].int() - decoder_attn_mask = past_model_kwargs.pop("decoder_attention_mask") + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] # all set to get into auto-regressive mode # prepare all of the model inputs for the decoder @@ -291,7 +264,7 @@ def get_log_probs_ref_model( next_token_logits = outputs.logits[:, -1, :] # get log probs - dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + dist = Categorical(logits=next_token_logits) log_prob = dist.log_prob(action) # update the model kwargs for further generation @@ -328,7 +301,7 @@ def get_inputs_for_generation(self, obs): def get_language_model(self): return unwrap_model(self._policy_model) - def generate( + def sample( self, tokenizer, texts = None, diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index e2811fc02..81f9372ac 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -18,7 +18,7 @@ from utils import evaluate_on_samples # rollout -from utils import MaskableDictRolloutBuffer, RolloutUtil +from utils import DictRolloutBuffer, RolloutUtil # agent, algorithm and model from rl4lm_ppo import RL4LMPPO @@ -39,8 +39,6 @@ def main(config): device = torch.device("cuda" if torch.cuda. is_available() else "cpu") - rollout_util = RolloutUtil(config["alg"]["kl_div"]) - tokenizer = build_tokenizer(config["tokenizer"]) # reward function & metrics @@ -65,15 +63,15 @@ def main(config): rl4lm_alg = RL4LMPPO(model=rl4lms_model, device=device, **config["alg"]["args"]) agent = RL4LMsAgent(rl4lm_alg, config["alg"]) - rollout_buffer = MaskableDictRolloutBuffer( + rollout_buffer = DictRolloutBuffer( buffer_size=agent.alg.n_steps * env.num_envs, observation_space=env.observation_space, action_space=env.action_space, device=device, gamma=agent.alg.gamma, gae_lambda=agent.alg.gae_lambda, - n_envs=1, ) + rollout_util = RolloutUtil(config["alg"]["kl_div"]) n_iters = int(config["train_evaluation"]["n_iters"]) n_steps_per_iter = env.num_envs * agent.alg.n_steps diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py index 363e8e266..9a0576015 100644 --- a/benchmark/torch/RL4LMs/utils/__init__.py +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -5,11 +5,7 @@ from .huggingface_generation_util import override_generation_routines -from .distribution_wrapper import CategoricalDistribution - -from .sample_util import PrioritySampler - -from .buffer import MaskableDictRolloutBuffer +from .buffer import DictRolloutBuffer from .kl_controller import KLController diff --git a/benchmark/torch/RL4LMs/utils/buffer.py b/benchmark/torch/RL4LMs/utils/buffer.py index 7e9ec9123..2d27e2591 100644 --- a/benchmark/torch/RL4LMs/utils/buffer.py +++ b/benchmark/torch/RL4LMs/utils/buffer.py @@ -36,7 +36,7 @@ def get_obs_shape( raise NotImplementedError(f"{observation_space} observation space is not supported") -class MaskableDictRolloutBuffer: +class DictRolloutBuffer: """ Dict Rollout buffer used in on-policy algorithms like A2C/PPO. Extends the RolloutBuffer to use dictionary observations @@ -69,7 +69,6 @@ def __init__( device = "cpu", gae_lambda = 1, gamma = 0.99, - n_envs = 1, ): self.buffer_size = buffer_size self.observation_space = observation_space @@ -80,7 +79,6 @@ def __init__( self.pos = 0 self.full = False self.device = device - self.n_envs = n_envs assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" @@ -92,19 +90,17 @@ def __init__( self.reset() def reset(self): - self.mask_dims = self.action_space.n - assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" self.observations = {} for key, obs_input_shape in self.obs_shape.items(): - self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32) - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) - self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.observations[key] = np.zeros((self.buffer_size, 1) + obs_input_shape, dtype=np.float32) + self.actions = np.zeros((self.buffer_size, 1, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.values = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, 1), dtype=np.float32) self.generator_ready = False self.pos = 0 @@ -137,7 +133,7 @@ def add(self, # Reshape needed when using multiple envs with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space.spaces[key], spaces.Discrete): - obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) + obs_ = obs_.reshape((1,) + self.obs_shape[key]) self.observations[key][self.pos] = obs_ self.actions[self.pos] = np.array(action).copy() @@ -202,7 +198,7 @@ def swap_and_flatten(self, arr): def get(self, batch_size): assert self.full, "" - indices = np.random.permutation(self.buffer_size * self.n_envs) + indices = np.random.permutation(self.buffer_size * 1) # Prepare the data if not self.generator_ready: @@ -219,10 +215,10 @@ def get(self, batch_size): # Return everything, don't create minibatches if batch_size is None: - batch_size = self.buffer_size * self.n_envs + batch_size = self.buffer_size * 1 start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: + while start_idx < self.buffer_size * 1: yield self._get_samples(indices[start_idx: start_idx + batch_size]) start_idx += batch_size diff --git a/benchmark/torch/RL4LMs/utils/data_wrapper.py b/benchmark/torch/RL4LMs/utils/data_wrapper.py index f69ee23c9..e591e3091 100644 --- a/benchmark/torch/RL4LMs/utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/utils/data_wrapper.py @@ -192,7 +192,7 @@ def _concat(prompt: torch.tensor, prompt_mask: torch.tensor, actual_size:] = 1 return concatenated, concatenated_mask - def update(self, action: int, tokenizer: AutoTokenizer) -> "Observation": + def update(self, action: int, tokenizer: AutoTokenizer): """ Updates the observation using the given action """ diff --git a/benchmark/torch/RL4LMs/utils/distribution_wrapper.py b/benchmark/torch/RL4LMs/utils/distribution_wrapper.py deleted file mode 100644 index e5824e239..000000000 --- a/benchmark/torch/RL4LMs/utils/distribution_wrapper.py +++ /dev/null @@ -1,68 +0,0 @@ -# refer to stable_baselines3.common.distributions -from torch import nn -from torch.distributions import Categorical -from typing import Tuple -import torch - -class CategoricalDistribution: - """ - Categorical distribution for discrete actions. - - :param action_dim: Number of discrete actions - """ - - def __init__(self, action_dim: int): - super().__init__() - self.action_dim = action_dim - - def proba_distribution_net(self, latent_dim): - """ - Create the layer that represents the distribution: - it will be the logits of the Categorical distribution. - You can then get probabilities using a softmax. - - :param latent_dim: Dimension of the last layer - of the policy network (before the action layer) - :return: - """ - action_logits = nn.Linear(latent_dim, self.action_dim) - return action_logits - - def proba_distribution(self, action_logits: torch.Tensor): - self.distribution = Categorical(logits=action_logits) - return self - - def log_prob(self, actions): - return self.distribution.log_prob(actions) - - def entropy(self): - return self.distribution.entropy() - - def sample(self): - return self.distribution.sample() - - def mode(self): - return torch.argmax(self.distribution.probs, dim=1) - - - def actions_from_params(self, action_logits, deterministic = False): - # Update the proba distribution - self.proba_distribution(action_logits) - return self.get_actions(deterministic=deterministic) - - def log_prob_from_params(self, action_logits): - actions = self.actions_from_params(action_logits) - log_prob = self.log_prob(actions) - return actions, log_prob - - - def get_actions(self, deterministic = False): - """ - Return actions according to the probability distribution. - - :param deterministic: - :return: - """ - if deterministic: - return self.mode() - return self.sample() \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/utils/evaluation_util.py b/benchmark/torch/RL4LMs/utils/evaluation_util.py index f69e16e4a..28602724e 100644 --- a/benchmark/torch/RL4LMs/utils/evaluation_util.py +++ b/benchmark/torch/RL4LMs/utils/evaluation_util.py @@ -16,7 +16,6 @@ def get_batch(samples, batch_size): current_ix += batch_size - def evaluate_on_samples( policy, tokenizer, @@ -36,13 +35,11 @@ def evaluate_on_samples( all_meta_infos = [] ###########CHANGE FOR DEBUG############ tem = [] - for i in range(200): + for i in range(100): tem.append(samples[i]) samples = tem ###########CHANGE FOR DEBUG############ - - n_samples = len(samples) for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"): batch_generated_texts = generate_text( @@ -97,8 +94,6 @@ def evaluate_on_samples( sample_prediction[metric_key] = sample_scores[ix] sample_predictions_dict.append(sample_prediction) - - metrics_dict_ = { "epoch": epoch, "metrics": corpus_level_metrics @@ -119,7 +114,7 @@ def generate_text( prompt_texts = [ dt_control_token + sample.prompt_or_input_text for sample in samples ] - generated_texts = policy.generate( + generated_texts = policy.sample( tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs ).gen_texts return generated_texts diff --git a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py index 20426fb0a..d082f8cfb 100644 --- a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py @@ -55,35 +55,6 @@ logger = logging.get_logger(__name__) -@dataclass -class SampleDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using sampling. - - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each - tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`). - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, - sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - @dataclass class SampleEncoderDecoderOutput(ModelOutput): """ @@ -127,8 +98,6 @@ class SampleEncoderDecoderOutput(ModelOutput): decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None -SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] - class GenerationMixinWithRawScores: """ @@ -226,12 +195,6 @@ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) - """ return {"input_ids": input_ids} - def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: - """ - Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. - """ - return logits - def _prepare_input_ids_for_generation( self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] ) -> torch.LongTensor: @@ -1080,7 +1043,7 @@ def sample( return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, - ) -> Union[SampleOutput, torch.LongTensor]: + ): r""" Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. @@ -1325,12 +1288,7 @@ def sample( decoder_hidden_states=decoder_hidden_states, ) else: - return SampleDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) + raise NotImplementedError else: return input_ids diff --git a/benchmark/torch/RL4LMs/utils/rollout_util.py b/benchmark/torch/RL4LMs/utils/rollout_util.py index b1a1a228c..0c392cfe2 100644 --- a/benchmark/torch/RL4LMs/utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/utils/rollout_util.py @@ -75,7 +75,6 @@ def __init__(self, kl_args): self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) - def _generate_batch( self, agent=None, @@ -97,7 +96,7 @@ def _generate_batch( # generate text using the model obs_tensor = dict_to_tensor(current_obs, device) generation_inputs = agent.get_inputs_for_generation(obs_tensor) - gen_output = agent.generate( + gen_output = agent.sample( input_ids=generation_inputs.inputs, attention_mask=generation_inputs.attention_masks, tokenizer=tokenizer, @@ -106,9 +105,6 @@ def _generate_batch( # process them one step at a time to collect rollout info episode_wise_transitions = [[] for _ in range(env.num_envs)] ep_terminated = np.zeros((env.num_envs,), dtype=bool) - value_past_state = None - ref_past_state = None - policy_past_state = None for actions_tensor, _ in zip( gen_output.step_wise_actions, gen_output.step_wise_logprobs @@ -125,7 +121,6 @@ def _generate_batch( policy_kwargs = { "obs": obs_tensor, "actions": actions_tensor, - "past_model_kwargs": policy_past_state, } policy_outputs = agent.forward_policy( @@ -149,7 +144,7 @@ def _generate_batch( # get values value_outputs = agent.forward_value( - obs_tensor, value_past_state + obs_tensor ) values, value_past_state = ( value_outputs.values, @@ -159,7 +154,7 @@ def _generate_batch( # get reference log probs ref_policy_outputs = ( agent.get_log_probs_ref_model( - obs_tensor, actions_tensor, ref_past_state + obs_tensor, actions_tensor ) ) ref_log_probs, ref_past_state = ( @@ -230,10 +225,8 @@ def collect_rollouts( rollout_buffer, device ): - used_timesteps = 0 # get tokenizer - tokenizer = env.get_attr("tokenizer", [0]) - tokenizer = tokenizer[0] + tokenizer = env.tokenizer # Switch to eval mode # self._agent.alg.model.set_training_mode(False) @@ -255,7 +248,7 @@ def collect_rollouts( num_timesteps = 0 while not rollout_buffer.full: # generate batch of rollouts - rollout_info, run_timestamps = self._generate_batch( + rollout_info, run_timesteps = self._generate_batch( agent=agent, env=env, rollout_buffer=rollout_buffer, @@ -263,7 +256,7 @@ def collect_rollouts( rollout_info=rollout_info, device=device ) - num_timesteps += run_timestamps + num_timesteps += run_timesteps # aggregate rollout info aggregated_rollout_info = {} diff --git a/benchmark/torch/RL4LMs/utils/sample_util.py b/benchmark/torch/RL4LMs/utils/sample_util.py deleted file mode 100644 index 097539b32..000000000 --- a/benchmark/torch/RL4LMs/utils/sample_util.py +++ /dev/null @@ -1,38 +0,0 @@ -from collections import deque -import numpy as np - -class PrioritySampler: - def __init__(self, max_size: int = None, priority_scale: float = 0.0): - """ - Creates a priority sampler - - Args: - max_size (int): maximum size of the queue - priority_scale (float): 0.0 is a pure uniform sampling, 1.0 is completely priority sampling - """ - self.max_size = max_size - self.items = deque(maxlen=self.max_size) - self.item_priorities = deque(maxlen=self.max_size) - self.priority_scale = priority_scale - - def add(self, item, priority: float): - self.items.append(item) - self.item_priorities.append(priority) - - def sample(self, size: int): - min_sample_size = min(len(self.items), size) - scaled_item_priorities = np.array( - self.item_priorities) ** self.priority_scale - sample_probs = scaled_item_priorities / np.sum(scaled_item_priorities) - samples = np.random.choice( - a=self.items, p=sample_probs, size=min_sample_size) - return samples - - def update(self, item, priority): - index = self.items.index(item) - del self.items[index] - del self.item_priorities[index] - self.add(item, priority) - - def get_all_samples(self): - return self.items From 23735cbe44733ac96491b68b9d81358091153215 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Tue, 7 Mar 2023 20:42:39 +0800 Subject: [PATCH 08/34] remove EvaluateActionsOutput, ValueOutput and PolicyOutput --- benchmark/torch/RL4LMs/rl4lm_ppo.py | 7 +-- benchmark/torch/RL4LMs/seq2seq_model.py | 30 +++--------- benchmark/torch/RL4LMs/utils/__init__.py | 3 +- benchmark/torch/RL4LMs/utils/buffer.py | 4 +- benchmark/torch/RL4LMs/utils/data_wrapper.py | 47 +------------------ .../utils/huggingface_generation_util.py | 1 - benchmark/torch/RL4LMs/utils/rollout_util.py | 42 +++-------------- 7 files changed, 18 insertions(+), 116 deletions(-) diff --git a/benchmark/torch/RL4LMs/rl4lm_ppo.py b/benchmark/torch/RL4LMs/rl4lm_ppo.py index 47c2bda90..8022c230c 100644 --- a/benchmark/torch/RL4LMs/rl4lm_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lm_ppo.py @@ -1,11 +1,8 @@ import parl -from typing import Union, Optional, Dict, Any import torch from gym import spaces -from benchmark.torch.RL4LMs.utils import EvaluateActionsOutput from torch.nn import functional as F - from parl.algorithms.torch import PPO class RL4LMPPO(parl.Algorithm): @@ -60,9 +57,7 @@ def learn(self, rollout_buffer, log_info): actions = rollout_data.actions.long().flatten() - evaluation_output: EvaluateActionsOutput = self.model.evaluate_actions( - rollout_data.observations, actions) - values, log_prob, entropy = evaluation_output.values, evaluation_output.log_prob, evaluation_output.entropy + values, log_prob, entropy = self.model.evaluate_actions(rollout_data.observations, actions) values = values.flatten() # Normalize advantage advantages = rollout_data.advantages diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index fde2b85b1..5df88a4ef 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -10,8 +10,7 @@ from benchmark.torch.RL4LMs.utils import ( override_generation_routines, - GenerationInputs, PolicyOutput, RefPolicyOutput, ValueOutput, - EvaluateActionsOutput, GenerationOutputs, + GenerationInputs, GenerationOutputs, ) @@ -135,11 +134,7 @@ def forward_policy( dim=-1, ) - policy_output = PolicyOutput( - actions, log_prob, log_prob, entropy, past_model_kwargs - ) - - return policy_output + return actions, log_prob, entropy, past_model_kwargs def forward_value( self, @@ -199,23 +194,15 @@ def forward_value( (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), dim=-1, ) - - value_output = ValueOutput(values, past_model_kwargs) - return value_output + return values, past_model_kwargs def evaluate_actions( self, obs, actions ): - policy_outputs = self.forward_policy(obs=obs, actions=actions) - value_outputs = self.forward_value(obs) - - eval_outputs = EvaluateActionsOutput( - values=value_outputs.values, - log_prob=policy_outputs.log_probs, - entropy=policy_outputs.entropy, - ) - return eval_outputs + _, log_prob, entropy, _ = self.forward_policy(obs=obs, actions=actions) + values, _ = self.forward_value(obs) + return values, log_prob, entropy def to(self, device): if self._apply_model_parallel: @@ -279,10 +266,7 @@ def get_log_probs_ref_model( (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), dim=-1, ) - - ref_policy_output = RefPolicyOutput(log_prob, past_model_kwargs) - - return ref_policy_output + return log_prob, past_model_kwargs def get_policy_first_device(self): return ( diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/utils/__init__.py index 9a0576015..dc4965794 100644 --- a/benchmark/torch/RL4LMs/utils/__init__.py +++ b/benchmark/torch/RL4LMs/utils/__init__.py @@ -1,5 +1,4 @@ -from .data_wrapper import EvaluateActionsOutput, PolicyOutput, \ - RefPolicyOutput, ValueOutput, GenerationInputs, GenerationOutputs,\ +from .data_wrapper import RefPolicyOutput, GenerationInputs, GenerationOutputs,\ PolicyType, Sample, Observation, TransitionInfo diff --git a/benchmark/torch/RL4LMs/utils/buffer.py b/benchmark/torch/RL4LMs/utils/buffer.py index 2d27e2591..1ac097f33 100644 --- a/benchmark/torch/RL4LMs/utils/buffer.py +++ b/benchmark/torch/RL4LMs/utils/buffer.py @@ -1,7 +1,7 @@ import numpy as np import torch from gym import spaces -from .data_wrapper import MaskableDictRolloutBufferSamples +from .data_wrapper import DictRolloutBufferSamples try: # Check memory used by replay buffer when possible @@ -238,7 +238,7 @@ def to_torch(self, array, copy = True): def _get_samples(self, batch_inds): - return MaskableDictRolloutBufferSamples( + return DictRolloutBufferSamples( observations={key: self.to_torch(obs[batch_inds]) for ( key, obs) in self.observations.items()}, actions=self.to_torch(self.actions[batch_inds]), diff --git a/benchmark/torch/RL4LMs/utils/data_wrapper.py b/benchmark/torch/RL4LMs/utils/data_wrapper.py index e591e3091..37235b247 100644 --- a/benchmark/torch/RL4LMs/utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/utils/data_wrapper.py @@ -28,7 +28,7 @@ class TransitionInfo: info: Dict[str, Any] -class MaskableDictRolloutBufferSamples(NamedTuple): +class DictRolloutBufferSamples(NamedTuple): observations: TensorDict actions: torch.Tensor old_values: torch.Tensor @@ -51,39 +51,6 @@ class PolicyType(Enum): SEQ2SEQ = 1 -@dataclass -class EvaluateActionsOutput: - """ - Dataclass for the output of the method policy.evaluate_actions(). - This is invoked during training phase for each mini-batch in the rollout buffer - """ - - # values of the given state - values: torch.tensor - # log prob of chosen actions - log_prob: torch.tensor - # entropy of action dist - entropy: torch.tensor - - -@dataclass -class PolicyOutput: - """ - Dataclass for the output of the method policy.foward_policy() - """ - - # chosen actions by policy - actions: torch.tensor - # raw log probs corresponding to chosen actions - raw_log_probs: torch.tensor - # processed log probs (eg: after action masking) for chosen actions - log_probs: torch.tensor - # entropy of action dist - entropy: torch.tensor - # cached policy activations for sequential forward passes - past_model_kwargs: torch.tensor - - @dataclass class RefPolicyOutput: """ @@ -96,18 +63,6 @@ class RefPolicyOutput: past_model_kwargs: torch.tensor -@dataclass -class ValueOutput: - """ - Dataclass for the output of the method policy.forward_value() - """ - - # values corresponding to given state - values: torch.tensor - # cached value activations for sequential forward passes - past_model_kwargs: Dict[str, torch.tensor] - - @dataclass class GenerationInputs: # prompt inputs diff --git a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py index d082f8cfb..2666846e9 100644 --- a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py @@ -1293,7 +1293,6 @@ def sample( return input_ids - def override_generation_routines(cls): bases = list(cls.__bases__) for base_ix in range(len(bases)): diff --git a/benchmark/torch/RL4LMs/utils/rollout_util.py b/benchmark/torch/RL4LMs/utils/rollout_util.py index 0c392cfe2..70510dd81 100644 --- a/benchmark/torch/RL4LMs/utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/utils/rollout_util.py @@ -123,52 +123,22 @@ def _generate_batch( "actions": actions_tensor, } - policy_outputs = agent.forward_policy( - **policy_kwargs - ) - raw_log_probs, log_probs, policy_past_state = ( - policy_outputs.raw_log_probs, - policy_outputs.log_probs, - policy_outputs.past_model_kwargs, - ) + _, log_probs, _, _ = agent.forward_policy(**policy_kwargs) # sanity check - assert torch.all( - torch.isfinite(log_probs) - ), "Infinite values in log probs" - - # sanity check - assert torch.all( - torch.isfinite(raw_log_probs) - ), "Infinite values in log probs" + assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" # get values - value_outputs = agent.forward_value( - obs_tensor - ) - values, value_past_state = ( - value_outputs.values, - value_outputs.past_model_kwargs, - ) + values, _ = agent.forward_value(obs_tensor) # get reference log probs - ref_policy_outputs = ( - agent.get_log_probs_ref_model( - obs_tensor, actions_tensor - ) - ) - ref_log_probs, ref_past_state = ( - ref_policy_outputs.log_probs, - ref_policy_outputs.past_model_kwargs, - ) + ref_log_probs, _ = agent.get_log_probs_ref_model(obs_tensor, actions_tensor) # sanity check - assert torch.all( - torch.isfinite(ref_log_probs) - ), "Infinite values in log probs" + assert torch.all(torch.isfinite(ref_log_probs)), "Infinite values in log probs" # compute KL rewards - kl_div = raw_log_probs - ref_log_probs + kl_div = log_probs - ref_log_probs kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div # step into env to get rewards From bbdd102684971d13d40e91189a0db1f1037af10d Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Wed, 8 Mar 2023 18:03:15 +0800 Subject: [PATCH 09/34] use Reviewer and ReviewerGroup instead of Env --- benchmark/torch/RL4LMs/env/__init__.py | 2 - benchmark/torch/RL4LMs/env/text_gen_env.py | 165 --------- benchmark/torch/RL4LMs/env/vec_env.py | 169 ---------- benchmark/torch/RL4LMs/reviewer.py | 338 +++++++++++++++++++ benchmark/torch/RL4LMs/rl4lm_ppo.py | 2 + benchmark/torch/RL4LMs/t5_ppo.yml | 3 +- benchmark/torch/RL4LMs/train.py | 28 +- benchmark/torch/RL4LMs/utils/rollout_util.py | 117 +------ 8 files changed, 373 insertions(+), 451 deletions(-) delete mode 100644 benchmark/torch/RL4LMs/env/__init__.py delete mode 100644 benchmark/torch/RL4LMs/env/text_gen_env.py delete mode 100644 benchmark/torch/RL4LMs/env/vec_env.py create mode 100644 benchmark/torch/RL4LMs/reviewer.py diff --git a/benchmark/torch/RL4LMs/env/__init__.py b/benchmark/torch/RL4LMs/env/__init__.py deleted file mode 100644 index 09abf2026..000000000 --- a/benchmark/torch/RL4LMs/env/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .text_gen_env import TextGenEnv -from .vec_env import make_vec_env \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/env/text_gen_env.py b/benchmark/torch/RL4LMs/env/text_gen_env.py deleted file mode 100644 index 79287e340..000000000 --- a/benchmark/torch/RL4LMs/env/text_gen_env.py +++ /dev/null @@ -1,165 +0,0 @@ -from cmath import inf -from typing import Dict, Tuple, Optional, List - -from gym import Env, spaces -from gym.spaces.dict import Dict as DictSpace -from gym.spaces.discrete import Discrete -from benchmark.torch.RL4LMs.utils import Sample, Observation -from collections import deque -import numpy as np - - -class TextGenEnv(Env): - def __init__( - self, - tokenizer, - reward_function, - samples, - max_episode_length = 512, - max_prompt_length = None, - terminate_on_eos = False, - context_start_token = None, - prompt_truncation_side = "left", - ): - - """ - A generic RL environment to generate textual sequences. - For eg: text generation, summarization, machine translation, text simplification - Args: - tokenizer (AutoTokenizer): pre-trained tokenizer - reward_function (RewardFunction): reward functiom - samples (Tuple[List[Sample], float]): list of samples - max_episode_length (int, optional): Max steps to the model Defaults to 512. - max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. - terminate_on_eos (bool, optional): whether to terminate on EOS. Defaults to False. - context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) - prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") - """ - self.tokenizer = tokenizer - self.reward_function = reward_function - self.max_steps = max_episode_length - self._max_text_length = ( - max_prompt_length if max_prompt_length else tokenizer.model_max_length - ) - self._terminate_on_eos = terminate_on_eos - self._context_start_token = context_start_token - self._prompt_truncation_side = prompt_truncation_side - super().__init__() - - # set the observation and action space here - self._vocab_size = tokenizer.vocab_size - self.observation_space = DictSpace( - { - # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited - # while creating rollout buffers, observations are concatenated for each key - "prompt_or_input_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self._max_text_length,) - ), - "prompt_or_input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length,) - ), - "context_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self.max_steps,) - ), - "context_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self.max_steps,) - ), - "input_encoded_pt": spaces.Box( - low=0, - high=self._vocab_size, - shape=(self._max_text_length + self.max_steps,), - ), - "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length + self.max_steps,) - ), - } - ) - self.action_space = Discrete(n=self._vocab_size) - # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency - if 'mt5' in self.tokenizer.name_or_path: - n = 250112 - self.action_space = Discrete(n=n) - elif 't5' in self.tokenizer.name_or_path: - n = 32128 - self.action_space = Discrete(n=n) - self.samples_for_replaying = deque() - for sample, weight in samples: - self.samples_for_replaying.append(sample) - - # check the tokenizer and add padding tokens - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - self.tokenizer.padding_side = "left" # TBD: configure this - self.tokenizer.truncation_side = "left" # TBD: configure this - - # init tracking variables - self.__current_sample = None - self.__current_obs = None - self.__time_step = None - - def step(self, action): - self.__time_step += 1 - - # previous obs - previous_obs = self.__current_obs - - # just update the context tensor and gets the new observation - self.__current_obs = self.__current_obs.update(action, self.tokenizer) - - # decide if the episode is finished or not - done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or ( - self.__time_step == self.max_steps - ) - - # compute reward - reward = self.reward_function( - previous_obs, - action, - self.__current_obs, - done, - self.__current_obs.meta_info, - ) - - # populate additional info - info = { - "output": self.__current_obs.context_text, - "action_history": self.__current_obs.action_history, - "reference_text": self.__current_obs.target_or_reference_texts, - "prompt_text": self.__current_obs.prompt_or_input_text, - "prev_output": previous_obs.context_text, - "meta_info": previous_obs.meta_info, - } - - return self.__current_obs.to_dict(), reward, done, info - - def reset(self, sample = None): - """ - Resets the environment and starts a new episode - """ - # gets a new sample if not provided - if sample is None: - sample = np.random.choice(a=self.samples_for_replaying, size=min(len(self.samples_for_replaying), 1))[0] - self.__current_sample = sample - - # init the observation - self.__current_obs = Observation.init_from_sample( - sample, - self.tokenizer, - self._max_text_length, - self.max_steps, - self._prompt_truncation_side, - self._context_start_token, - sample.meta_data, - ) - - # start the time step counter - self.__time_step = 0 - - dict_observation = self.__current_obs.to_dict() - return dict_observation - - def render(self): - pass - - def close(self): - pass diff --git a/benchmark/torch/RL4LMs/env/vec_env.py b/benchmark/torch/RL4LMs/env/vec_env.py deleted file mode 100644 index d6a21ee9a..000000000 --- a/benchmark/torch/RL4LMs/env/vec_env.py +++ /dev/null @@ -1,169 +0,0 @@ -import numpy as np -import cloudpickle -import gym -from collections import OrderedDict -import multiprocessing as mp - - -class CloudpickleWrapper: - def __init__(self, var): - self.var = var - - def __getstate__(self): - return cloudpickle.dumps(self.var) - - def __setstate__(self, var): - self.var = cloudpickle.loads(var) - -def _flatten_obs(obs, space): - assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" - assert len(obs) > 0, "need observations from at least one environment" - - if isinstance(space, gym.spaces.Dict): - assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" - assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" - return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) - elif isinstance(space, gym.spaces.Tuple): - assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" - obs_len = len(space.spaces) - return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) - else: - return np.stack(obs) - -def _worker( - remote, parent_remote, env_fn_wrapper -): - # Import here to avoid a circular import - - parent_remote.close() - env = env_fn_wrapper.var() - while True: - try: - cmd, data = remote.recv() - if cmd == "step": - observation, reward, done, info = env.step(data) - if done: - # save final observation where user can get it, then reset - info["terminal_observation"] = observation - observation = env.reset() - remote.send((observation, reward, done, info)) - elif cmd == "seed": - remote.send(env.seed(data)) - elif cmd == "reset": - observation = env.reset() - remote.send(observation) - elif cmd == "close": - env.close() - remote.close() - break - elif cmd == "get_spaces": - remote.send((env.observation_space, env.action_space)) - else: - raise NotImplementedError(f"`{cmd}` is not implemented in the worker") - except EOFError: - break - -class LocalParallelVecEnv: - - def __init__(self, env_fns, tokenizer=None, start_method = None): - self.waiting = False - self.closed = False - n_envs = len(env_fns) - self.tokenizer = tokenizer - - if start_method is None: - # Fork is not a thread safe method (see issue #217) - # but is more user friendly (does not require to wrap the code in - # a `if __name__ == "__main__":`) - forkserver_available = "forkserver" in mp.get_all_start_methods() - start_method = "forkserver" if forkserver_available else "spawn" - ctx = mp.get_context(start_method) - - self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) - self.processes = [] - for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): - args = (work_remote, remote, CloudpickleWrapper(env_fn)) - # daemon=True: if the main process crashes, we should not cause things to hang - process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error - process.start() - self.processes.append(process) - work_remote.close() - - self.remotes[0].send(("get_spaces", None)) - observation_space, action_space = self.remotes[0].recv() - self.num_envs = len(env_fns) - self.observation_space = observation_space - self.action_space = action_space - - def step_async(self, actions): - for remote, action in zip(self.remotes, actions): - remote.send(("step", action)) - self.waiting = True - - def step_wait(self): - results = [remote.recv() for remote in self.remotes] - self.waiting = False - obs, rews, dones, infos = zip(*results) - return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos - - def seed(self, seed = None): - if seed is None: - seed = np.random.randint(0, 2**32 - 1) - for idx, remote in enumerate(self.remotes): - remote.send(("seed", seed + idx)) - return [remote.recv() for remote in self.remotes] - - def reset(self): - for remote in self.remotes: - remote.send(("reset", None)) - obs = [remote.recv() for remote in self.remotes] - return _flatten_obs(obs, self.observation_space) - - def close(self): - if self.closed: - return - if self.waiting: - for remote in self.remotes: - remote.recv() - for remote in self.remotes: - remote.send(("close", None)) - for process in self.processes: - process.join() - self.closed = True - - def step(self, actions): - """ - Step the environments with the given action - - :param actions: the action - :return: observation, reward, done, information - """ - self.step_async(actions) - return self.step_wait() - -def make_vec_env( - env_id, - seed = None, - start_index = 0, - env_config = None, - reward_fn = None, - tokenizer = None, - train_samples = None -): - n_envs = env_config["n_envs"] - env_kwargs = { - "reward_function": reward_fn, - "tokenizer": tokenizer, - "samples": train_samples, - } - env_kwargs = {**env_kwargs, **env_config.get("args", {})} - def make_env(rank): - def _init(): - env = env_id(**env_kwargs) - if seed is not None: - env.seed(seed + rank) - env.action_space.seed(seed + rank) - return env - return _init - - return LocalParallelVecEnv([make_env(i + start_index) for i in range(n_envs)], tokenizer=tokenizer) diff --git a/benchmark/torch/RL4LMs/reviewer.py b/benchmark/torch/RL4LMs/reviewer.py new file mode 100644 index 000000000..c970a722b --- /dev/null +++ b/benchmark/torch/RL4LMs/reviewer.py @@ -0,0 +1,338 @@ +import gym +from collections import OrderedDict +import torch +from utils import TransitionInfo, Sample, Observation +from gym import Env, spaces +from gym.spaces.dict import Dict as DictSpace +from gym.spaces.discrete import Discrete +import parl +from collections import deque +import numpy as np + +def _flatten_obs(obs, space): + assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" + assert len(obs) > 0, "need observations from at least one environment" + + if isinstance(space, gym.spaces.Dict): + assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" + return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + elif isinstance(space, gym.spaces.Tuple): + assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" + obs_len = len(space.spaces) + return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) + else: + return np.stack(obs) + +def dict_to_tensor(obs, device): + return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + +def unpack_observations(obs_tensor, n_envs): + """ + Unpacks vectorized dict observations into separate dict observations + """ + unpacked_obs = [] + keys = obs_tensor.keys() + for env_ix in range(n_envs): + obs_dict = {} + for key in keys: + obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() + unpacked_obs.append(obs_dict) + return unpacked_obs + + +# @parl.remote_class(wait=False) +class Reviewer: + def __init__( + self, + tokenizer, + reward_function, + samples, + max_episode_length = 512, + max_prompt_length = None, + terminate_on_eos = False, + context_start_token = None, + prompt_truncation_side = "left", + ): + + """ + A generic RL environment to generate textual sequences. + For eg: text generation, summarization, machine translation, text simplification + Args: + tokenizer (AutoTokenizer): pre-trained tokenizer + reward_function (RewardFunction): reward functiom + samples (Tuple[List[Sample], float]): list of samples + max_episode_length (int, optional): Max steps to the model Defaults to 512. + max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. + terminate_on_eos (bool, optional): whether to terminate on EOS. Defaults to False. + context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) + prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") + """ + self.tokenizer = tokenizer + self.reward_function = reward_function + self.max_steps = max_episode_length + self._max_text_length = ( + max_prompt_length if max_prompt_length else tokenizer.model_max_length + ) + self._terminate_on_eos = terminate_on_eos + self._context_start_token = context_start_token + self._prompt_truncation_side = prompt_truncation_side + super().__init__() + + # set the observation and action space here + self._vocab_size = tokenizer.vocab_size + self.observation_space = DictSpace( + { + # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited + # while creating rollout buffers, observations are concatenated for each key + "prompt_or_input_encoded_pt": spaces.Box( + low=0, high=self._vocab_size, shape=(self._max_text_length,) + ), + "prompt_or_input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self._max_text_length,) + ), + "context_encoded_pt": spaces.Box( + low=0, high=self._vocab_size, shape=(self.max_steps,) + ), + "context_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self.max_steps,) + ), + "input_encoded_pt": spaces.Box( + low=0, + high=self._vocab_size, + shape=(self._max_text_length + self.max_steps,), + ), + "input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self._max_text_length + self.max_steps,) + ), + } + ) + self.action_space = Discrete(n=self._vocab_size) + # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency + if 'mt5' in self.tokenizer.name_or_path: + n = 250112 + self.action_space = Discrete(n=n) + elif 't5' in self.tokenizer.name_or_path: + n = 32128 + self.action_space = Discrete(n=n) + self.samples_for_replaying = deque() + for sample, weight in samples: + self.samples_for_replaying.append(sample) + + # check the tokenizer and add padding tokens + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" # TBD: configure this + self.tokenizer.truncation_side = "left" # TBD: configure this + + # init tracking variables + self.__current_sample = None + self.__current_obs = None + self.__time_step = None + + def get_new_obs_and_feedback_one_step(self, action): + self.__time_step += 1 + + # previous obs + previous_obs = self.__current_obs + + # just update the context tensor and gets the new observation + self.__current_obs = self.__current_obs.update(action, self.tokenizer) + + # decide if the episode is finished or not + done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or ( + self.__time_step == self.max_steps + ) + + # compute reward + reward = self.reward_function( + previous_obs, + action, + self.__current_obs, + done, + self.__current_obs.meta_info, + ) + + # populate additional info + info = { + "output": self.__current_obs.context_text, + "action_history": self.__current_obs.action_history, + "reference_text": self.__current_obs.target_or_reference_texts, + "prompt_text": self.__current_obs.prompt_or_input_text, + "prev_output": previous_obs.context_text, + "meta_info": previous_obs.meta_info, + } + + if done: + # save final observation where user can get it, then reset + info["terminal_observation"] = self.__current_obs.to_dict() + observation = self.ask() + return (observation, reward, done, info) + else: + return (self.__current_obs.to_dict(), reward, done, info) + + def ask(self, sample = None): + """ + Resets the environment and starts a new episode + """ + # gets a new sample if not provided + if sample is None: + sample = np.random.choice(a=self.samples_for_replaying, size=min(len(self.samples_for_replaying), 1))[0] + self.__current_sample = sample + + # init the observation + self.__current_obs = Observation.init_from_sample( + sample, + self.tokenizer, + self._max_text_length, + self.max_steps, + self._prompt_truncation_side, + self._context_start_token, + sample.meta_data, + ) + + # start the time step counter + self.__time_step = 0 + + dict_observation = self.__current_obs.to_dict() + return dict_observation + + def get_obs_and_action_space(self): + return self.observation_space, self.action_space + + +class ReviewerGroup: + def __init__(self, + reviewer_config=None, + reward_fn=None, + tokenizer=None, + question_samples=None, + seed = None, + start_index = 0, + ): + self.n_reviewers = reviewer_config["n_reviewers"] + reviewer_kwargs = { + "reward_function": reward_fn, + "tokenizer": tokenizer, + "samples": question_samples, + } + reviewer_kwargs = {**reviewer_kwargs, **reviewer_config.get("args", {})} + self.tokenizer = tokenizer + self._remote_reviewers = self._create_reviewers(reviewer_kwargs) + tem_future_object_ids = self._remote_reviewers[0].get_obs_and_action_space() + # self.observation_space, self.action_space = tem_future_object_ids.get() + self.observation_space, self.action_space = tem_future_object_ids + + def ask(self): + future_object_ids = [ + remote_reviewer.ask() for remote_reviewer in self._remote_reviewers + ] + # sample_questions = [ + # future_object.get() for future_object in future_object_ids + # ] + sample_questions = future_object_ids + return _flatten_obs(sample_questions, self.observation_space) + + def feedback(self, current_obs, gen_output, kl_criterion, agent, device): + review_times = 0 + episode_starts = np.ones((self.n_reviewers,), dtype=bool) + # process them one step at a time to collect rollout info + episode_wise_transitions = [[] for _ in range(self.n_reviewers)] + ep_terminated = np.zeros((self.n_reviewers,), dtype=bool) + + for actions_tensor, _ in zip( + gen_output.step_wise_actions, gen_output.step_wise_logprobs + ): + # if all episodes are done, just break and do not continue + if np.all(ep_terminated): + break + + # evaluate actions with actions from rollout + with torch.no_grad(): + obs_tensor = dict_to_tensor(current_obs, device) + + # get log probs (TBD: generalize this a bit) + policy_kwargs = { + "obs": obs_tensor, + "actions": actions_tensor, + } + + _, log_probs, _, _ = agent.forward_policy(**policy_kwargs) + + # sanity check + assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" + + # get values + values, _ = agent.forward_value(obs_tensor) + + # get reference log probs + ref_log_probs, _ = agent.get_log_probs_ref_model(obs_tensor, actions_tensor) + + # sanity check + assert torch.all(torch.isfinite(ref_log_probs)), "Infinite values in log probs" + + # compute KL rewards + kl_div = log_probs - ref_log_probs + kl_rewards = -1 * kl_criterion.kl_coeff * kl_div + + # step into env to get rewards + actions = actions_tensor.cpu().numpy() + new_obs, rewards, dones, infos = self._feedback_one_step(actions) + + review_times += self.n_reviewers + + # compute total rewards + total_rewards = rewards + kl_rewards.cpu().numpy() + + # unpack individual observations + unpacked_obs = unpack_observations(obs_tensor, self.n_reviewers) + + # store episode wise transitions separately + for env_ix in range(self.n_reviewers): + # only if not terminated already + if not ep_terminated[env_ix]: + transtion = TransitionInfo( + observation=unpacked_obs[env_ix], + action=actions[env_ix], + task_reward=rewards[env_ix], + total_reward=total_rewards[env_ix], + kl_div=kl_div.cpu().numpy()[env_ix], + episode_start=episode_starts[env_ix], + value=values[env_ix].cpu(), + log_prob=log_probs[env_ix].cpu(), + done=dones[env_ix], + ref_log_prob=ref_log_probs[env_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[env_ix], + info=infos[env_ix], + ) + + episode_wise_transitions[env_ix].append(transtion) + + # mark this episode to terminated if done occurs once + if dones[env_ix]: + ep_terminated[env_ix] = True + + episode_starts = np.zeros((self.n_reviewers,), dtype=bool) + current_obs = new_obs + return episode_wise_transitions, review_times + + def _feedback_one_step(self, actions): + future_object_ids = [ + self._remote_reviewers[i].get_new_obs_and_feedback_one_step( + actions[i]) for i in range(self.n_reviewers) + ] + # feedback_res = [ + # future_object.get() for future_object in future_object_ids + # ] + feedback_res = future_object_ids + obs, rews, dones, infos = zip(*feedback_res) + return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos + + + def _create_reviewers(self, reviewer_kwargs): + # parl.connect(reviewer_kwargs['parl_master_address']) + return [Reviewer(**reviewer_kwargs) for _ in range(self.n_reviewers)] + + + + diff --git a/benchmark/torch/RL4LMs/rl4lm_ppo.py b/benchmark/torch/RL4LMs/rl4lm_ppo.py index 8022c230c..9a365d222 100644 --- a/benchmark/torch/RL4LMs/rl4lm_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lm_ppo.py @@ -39,6 +39,8 @@ def __init__(self, self.target_kl = target_kl self.seed = seed self.device = device + for param_group in self.model.optimizer.param_groups: + param_group["lr"] = self.learning_rate def learn(self, rollout_buffer, log_info): entropy_losses = log_info["entropy_losses"] diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml index 75e18162d..cabc697ed 100644 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -19,7 +19,8 @@ datapool: env: ## CHANGE FOR DEBUG ## # n_envs: 10 - n_envs: 2 + parl_master_address: "localhost:8811" + n_reviewers: 2 ## CHANGE FOR DEBUG ## args: max_prompt_length: 512 diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 81f9372ac..14d7ad7f5 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -11,7 +11,7 @@ # env and reward function from utils import build_reward_fn -from env import TextGenEnv, make_vec_env +from reviewer import ReviewerGroup # evaluation, metrics, tokenizer & dataset from utils import build_metrics, build_tokenizer, build_datapool @@ -48,15 +48,15 @@ def main(config): # datapool samples_by_split = build_datapool(config["datapool"]) - env = make_vec_env(env_id=TextGenEnv, - env_config=config["env"], - reward_fn=reward_fn, - tokenizer=tokenizer, - train_samples= samples_by_split["train"]) + + reviewer_group = ReviewerGroup(reviewer_config=config["env"], + reward_fn=reward_fn, + tokenizer=tokenizer, + question_samples=samples_by_split["train"]) rl4lms_model = Seq2SeqLMModel( - observation_space = env.observation_space, - action_space= env.action_space, + observation_space = reviewer_group.observation_space, + action_space= reviewer_group.action_space, device=device, **config["alg"]["model"]["args"] ) @@ -64,17 +64,17 @@ def main(config): agent = RL4LMsAgent(rl4lm_alg, config["alg"]) rollout_buffer = DictRolloutBuffer( - buffer_size=agent.alg.n_steps * env.num_envs, - observation_space=env.observation_space, - action_space=env.action_space, + buffer_size=agent.alg.n_steps * reviewer_group.n_reviewers, + observation_space=reviewer_group.observation_space, + action_space=reviewer_group.action_space, device=device, gamma=agent.alg.gamma, gae_lambda=agent.alg.gae_lambda, ) - rollout_util = RolloutUtil(config["alg"]["kl_div"]) + rollout_util = RolloutUtil(config["alg"]["kl_div"], reviewer_group) n_iters = int(config["train_evaluation"]["n_iters"]) - n_steps_per_iter = env.num_envs * agent.alg.n_steps + n_steps_per_iter = reviewer_group.n_reviewers * agent.alg.n_steps max_prompt_length = config["env"]["args"]["max_prompt_length"] @@ -104,7 +104,7 @@ def main(config): num_timesteps = 0 while num_timesteps < n_steps_per_iter: - run_timesteps = rollout_util.collect_rollouts(agent, env, rollout_buffer, device) + run_timesteps = rollout_util.collect_rollouts(agent, reviewer_group, rollout_buffer, device) num_timesteps += run_timesteps agent.learn(rollout_buffer) diff --git a/benchmark/torch/RL4LMs/utils/rollout_util.py b/benchmark/torch/RL4LMs/utils/rollout_util.py index 70510dd81..58e1efad6 100644 --- a/benchmark/torch/RL4LMs/utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/utils/rollout_util.py @@ -1,25 +1,13 @@ import torch import numpy as np -from .data_wrapper import TransitionInfo + from .kl_controller import KLController +from parl.utils import logger def dict_to_tensor(obs, device): return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} -def unpack_observations(obs_tensor, n_envs): - """ - Unpacks vectorized dict observations into separate dict observations - """ - unpacked_obs = [] - keys = obs_tensor.keys() - for env_ix in range(n_envs): - obs_dict = {} - for key in keys: - obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() - unpacked_obs.append(obs_dict) - return unpacked_obs - def add_to_buffer( rollout_buffer, episode_wise_transitions, rollout_info @@ -71,27 +59,26 @@ def add_to_buffer( class RolloutUtil: - def __init__(self, kl_args): + def __init__(self, kl_args, reviewer_group): self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) def _generate_batch( self, agent=None, - env=None, + reviewer_group=None, rollout_buffer=None, tokenizer=None, rollout_info=None, device=None ): - num_timesteps = 0 # if rollout buffer is already full, do not continue if rollout_buffer.full: return # start parallel episodes - current_obs = env.reset() - episode_starts = np.ones((env.num_envs,), dtype=bool) + current_obs = reviewer_group.ask() + # generate text using the model obs_tensor = dict_to_tensor(current_obs, device) @@ -102,84 +89,11 @@ def _generate_batch( tokenizer=tokenizer, ) - # process them one step at a time to collect rollout info - episode_wise_transitions = [[] for _ in range(env.num_envs)] - ep_terminated = np.zeros((env.num_envs,), dtype=bool) - - for actions_tensor, _ in zip( - gen_output.step_wise_actions, gen_output.step_wise_logprobs - ): - # if all episodes are done, just break and do not continue - if np.all(ep_terminated): - break - - # evaluate actions with actions from rollout - with torch.no_grad(): - obs_tensor = dict_to_tensor(current_obs, device) - - # get log probs (TBD: generalize this a bit) - policy_kwargs = { - "obs": obs_tensor, - "actions": actions_tensor, - } - - _, log_probs, _, _ = agent.forward_policy(**policy_kwargs) - - # sanity check - assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" - - # get values - values, _ = agent.forward_value(obs_tensor) - - # get reference log probs - ref_log_probs, _ = agent.get_log_probs_ref_model(obs_tensor, actions_tensor) - - # sanity check - assert torch.all(torch.isfinite(ref_log_probs)), "Infinite values in log probs" - - # compute KL rewards - kl_div = log_probs - ref_log_probs - kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div - - # step into env to get rewards - actions = actions_tensor.cpu().numpy() - new_obs, rewards, dones, infos = env.step(actions) - - num_timesteps += env.num_envs - - # compute total rewards - total_rewards = rewards + kl_rewards.cpu().numpy() - - # unpack individual observations - unpacked_obs = unpack_observations(obs_tensor, env.num_envs) - - # store episode wise transitions separately - for env_ix in range(env.num_envs): - # only if not terminated already - if not ep_terminated[env_ix]: - transtion = TransitionInfo( - observation=unpacked_obs[env_ix], - action=actions[env_ix], - task_reward=rewards[env_ix], - total_reward=total_rewards[env_ix], - kl_div=kl_div.cpu().numpy()[env_ix], - episode_start=episode_starts[env_ix], - value=values[env_ix].cpu(), - log_prob=log_probs[env_ix].cpu(), - done=dones[env_ix], - ref_log_prob=ref_log_probs[env_ix].cpu(), - kl_reward=kl_rewards.cpu().numpy()[env_ix], - info=infos[env_ix], - ) - - episode_wise_transitions[env_ix].append(transtion) - - # mark this episode to terminated if done occurs once - if dones[env_ix]: - ep_terminated[env_ix] = True - - episode_starts = np.zeros((env.num_envs,), dtype=bool) - current_obs = new_obs + episode_wise_transitions, num_timesteps = reviewer_group.feedback(current_obs=current_obs, + gen_output=gen_output, + kl_criterion=self._kl_controller, + agent=agent, + device=device) # now we flush all episode wise info to the 1-D buffer rollout_info = add_to_buffer( @@ -191,12 +105,12 @@ def _generate_batch( def collect_rollouts( self, agent, - env, + reviewer_group, rollout_buffer, device ): # get tokenizer - tokenizer = env.tokenizer + tokenizer = reviewer_group.tokenizer # Switch to eval mode # self._agent.alg.model.set_training_mode(False) @@ -220,7 +134,7 @@ def collect_rollouts( # generate batch of rollouts rollout_info, run_timesteps = self._generate_batch( agent=agent, - env=env, + reviewer_group=reviewer_group, rollout_buffer=rollout_buffer, tokenizer=tokenizer, rollout_info=rollout_info, @@ -237,6 +151,9 @@ def collect_rollouts( "rollout_info/kl_coeff" ] = self._kl_controller.kl_coeff + logger.info(f"Rollout Info: {aggregated_rollout_info}") + + # adapt the KL coeff self._kl_controller.step( torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) From a9aef6b507f8bc0adc29a1e14c94751a893533a7 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Thu, 9 Mar 2023 10:14:23 +0800 Subject: [PATCH 10/34] use Reviewer and ReviewerGroup instead of Env (parl parallel) --- benchmark/torch/RL4LMs/reviewer.py | 32 +++++++++---------- .../{utils => rl4lms_utils}/__init__.py | 0 .../RL4LMs/{utils => rl4lms_utils}/buffer.py | 0 .../component_build_util.py | 0 .../{utils => rl4lms_utils}/data_pool.py | 0 .../{utils => rl4lms_utils}/data_wrapper.py | 0 .../evaluation_util.py | 0 .../huggingface_generation_util.py | 0 .../{utils => rl4lms_utils}/kl_controller.py | 0 .../{utils => rl4lms_utils}/metric_util.py | 0 .../{utils => rl4lms_utils}/reward_util.py | 0 .../{utils => rl4lms_utils}/rollout_util.py | 0 benchmark/torch/RL4LMs/seq2seq_model.py | 2 +- benchmark/torch/RL4LMs/t5_ppo.yml | 5 +-- benchmark/torch/RL4LMs/train.py | 12 +++---- 15 files changed, 26 insertions(+), 25 deletions(-) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/__init__.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/buffer.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/component_build_util.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/data_pool.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/data_wrapper.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/evaluation_util.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/huggingface_generation_util.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/kl_controller.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/metric_util.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/reward_util.py (100%) rename benchmark/torch/RL4LMs/{utils => rl4lms_utils}/rollout_util.py (100%) diff --git a/benchmark/torch/RL4LMs/reviewer.py b/benchmark/torch/RL4LMs/reviewer.py index c970a722b..64b32b6bc 100644 --- a/benchmark/torch/RL4LMs/reviewer.py +++ b/benchmark/torch/RL4LMs/reviewer.py @@ -1,7 +1,7 @@ import gym from collections import OrderedDict import torch -from utils import TransitionInfo, Sample, Observation +from rl4lms_utils import TransitionInfo, Sample, Observation from gym import Env, spaces from gym.spaces.dict import Dict as DictSpace from gym.spaces.discrete import Discrete @@ -41,7 +41,7 @@ def unpack_observations(obs_tensor, n_envs): return unpacked_obs -# @parl.remote_class(wait=False) +@parl.remote_class(wait=False) class Reviewer: def __init__( self, @@ -198,7 +198,7 @@ def ask(self, sample = None): return dict_observation def get_obs_and_action_space(self): - return self.observation_space, self.action_space + return (self.observation_space, self.action_space) class ReviewerGroup: @@ -218,19 +218,19 @@ def __init__(self, } reviewer_kwargs = {**reviewer_kwargs, **reviewer_config.get("args", {})} self.tokenizer = tokenizer - self._remote_reviewers = self._create_reviewers(reviewer_kwargs) + self._remote_reviewers = self._create_reviewers(reviewer_kwargs, reviewer_config["parl_master_address"]) tem_future_object_ids = self._remote_reviewers[0].get_obs_and_action_space() - # self.observation_space, self.action_space = tem_future_object_ids.get() - self.observation_space, self.action_space = tem_future_object_ids + self.observation_space, self.action_space = tem_future_object_ids.get() + # self.observation_space, self.action_space = tem_future_object_ids def ask(self): future_object_ids = [ remote_reviewer.ask() for remote_reviewer in self._remote_reviewers ] - # sample_questions = [ - # future_object.get() for future_object in future_object_ids - # ] - sample_questions = future_object_ids + sample_questions = [ + future_object.get() for future_object in future_object_ids + ] + # sample_questions = future_object_ids return _flatten_obs(sample_questions, self.observation_space) def feedback(self, current_obs, gen_output, kl_criterion, agent, device): @@ -321,16 +321,16 @@ def _feedback_one_step(self, actions): self._remote_reviewers[i].get_new_obs_and_feedback_one_step( actions[i]) for i in range(self.n_reviewers) ] - # feedback_res = [ - # future_object.get() for future_object in future_object_ids - # ] - feedback_res = future_object_ids + feedback_res = [ + future_object.get() for future_object in future_object_ids + ] + # feedback_res = future_object_ids obs, rews, dones, infos = zip(*feedback_res) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos - def _create_reviewers(self, reviewer_kwargs): - # parl.connect(reviewer_kwargs['parl_master_address']) + def _create_reviewers(self, reviewer_kwargs, parl_port=None): + parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"]) return [Reviewer(**reviewer_kwargs) for _ in range(self.n_reviewers)] diff --git a/benchmark/torch/RL4LMs/utils/__init__.py b/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/__init__.py rename to benchmark/torch/RL4LMs/rl4lms_utils/__init__.py diff --git a/benchmark/torch/RL4LMs/utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/buffer.py rename to benchmark/torch/RL4LMs/rl4lms_utils/buffer.py diff --git a/benchmark/torch/RL4LMs/utils/component_build_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/component_build_util.py rename to benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py diff --git a/benchmark/torch/RL4LMs/utils/data_pool.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/data_pool.py rename to benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py diff --git a/benchmark/torch/RL4LMs/utils/data_wrapper.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/data_wrapper.py rename to benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py diff --git a/benchmark/torch/RL4LMs/utils/evaluation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/evaluation_util.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/evaluation_util.py rename to benchmark/torch/RL4LMs/rl4lms_utils/evaluation_util.py diff --git a/benchmark/torch/RL4LMs/utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/huggingface_generation_util.py rename to benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py diff --git a/benchmark/torch/RL4LMs/utils/kl_controller.py b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/kl_controller.py rename to benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py diff --git a/benchmark/torch/RL4LMs/utils/metric_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/metric_util.py rename to benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py diff --git a/benchmark/torch/RL4LMs/utils/reward_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/reward_util.py rename to benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py diff --git a/benchmark/torch/RL4LMs/utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py similarity index 100% rename from benchmark/torch/RL4LMs/utils/rollout_util.py rename to benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index 5df88a4ef..5be6782ad 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -7,7 +7,7 @@ from transformers.modeling_utils import unwrap_model import parl -from benchmark.torch.RL4LMs.utils import ( +from rl4lms_utils import ( override_generation_routines, GenerationInputs, GenerationOutputs, diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml index cabc697ed..72fff2360 100644 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -16,13 +16,14 @@ datapool: prompt_prefix: "Summarize: " -env: +reviewer: + parl_master_address: "localhost:8811" ## CHANGE FOR DEBUG ## # n_envs: 10 - parl_master_address: "localhost:8811" n_reviewers: 2 ## CHANGE FOR DEBUG ## args: + max_prompt_length: 512 max_episode_length: 100 terminate_on_eos: True diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 14d7ad7f5..4487900c4 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -10,15 +10,15 @@ import time # env and reward function -from utils import build_reward_fn +from rl4lms_utils import build_reward_fn from reviewer import ReviewerGroup # evaluation, metrics, tokenizer & dataset -from utils import build_metrics, build_tokenizer, build_datapool -from utils import evaluate_on_samples +from rl4lms_utils import build_metrics, build_tokenizer, build_datapool +from rl4lms_utils import evaluate_on_samples # rollout -from utils import DictRolloutBuffer, RolloutUtil +from rl4lms_utils import DictRolloutBuffer, RolloutUtil # agent, algorithm and model from rl4lm_ppo import RL4LMPPO @@ -49,7 +49,7 @@ def main(config): samples_by_split = build_datapool(config["datapool"]) - reviewer_group = ReviewerGroup(reviewer_config=config["env"], + reviewer_group = ReviewerGroup(reviewer_config=config["reviewer"], reward_fn=reward_fn, tokenizer=tokenizer, question_samples=samples_by_split["train"]) @@ -76,7 +76,7 @@ def main(config): n_iters = int(config["train_evaluation"]["n_iters"]) n_steps_per_iter = reviewer_group.n_reviewers * agent.alg.n_steps - max_prompt_length = config["env"]["args"]["max_prompt_length"] + max_prompt_length = config["reviewer"]["args"]["max_prompt_length"] # gen kwargs for evaluation eval_gen_kwargs = config["train_evaluation"]["generation_kwargs"] From bf3c625f1fe4cbf2ed1121bbd477d5af53bd9250 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Thu, 9 Mar 2023 15:04:20 +0800 Subject: [PATCH 11/34] use Reviewer and ReviewerGroup instead of Env (parl parallel version) --- benchmark/torch/RL4LMs/README.md | 1 + benchmark/torch/RL4LMs/reviewer.py | 65 ++++++++++++++++--- .../rl4lms_utils/component_build_util.py | 20 +++++- benchmark/torch/RL4LMs/train.py | 9 ++- 4 files changed, 79 insertions(+), 16 deletions(-) diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md index 172025756..688e97f39 100644 --- a/benchmark/torch/RL4LMs/README.md +++ b/benchmark/torch/RL4LMs/README.md @@ -11,6 +11,7 @@ - Change from **\{ trainer: \{ ppo: \{ env, rollout_buffer, policy/model \} \} \}** to **\{trainer: \{env, rollout_buffer, agent: \{ ppo: \{ model \} \} \} \}** according to PARL architecture. +- Use Parl parallel Training ### Running command diff --git a/benchmark/torch/RL4LMs/reviewer.py b/benchmark/torch/RL4LMs/reviewer.py index 64b32b6bc..511576917 100644 --- a/benchmark/torch/RL4LMs/reviewer.py +++ b/benchmark/torch/RL4LMs/reviewer.py @@ -8,6 +8,7 @@ import parl from collections import deque import numpy as np +from rl4lms_utils import build_datapool, build_tokenizer, build_reward_fn def _flatten_obs(obs, space): assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" @@ -45,16 +46,18 @@ def unpack_observations(obs_tensor, n_envs): class Reviewer: def __init__( self, - tokenizer, - reward_function, - samples, + tokenizer=None, + reward_function=None, + samples=None, + reward_config=None, + tokenizer_config=None, + datapool_config=None, max_episode_length = 512, max_prompt_length = None, terminate_on_eos = False, context_start_token = None, prompt_truncation_side = "left", ): - """ A generic RL environment to generate textual sequences. For eg: text generation, summarization, machine translation, text simplification @@ -68,6 +71,12 @@ def __init__( context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") """ + if tokenizer is None: + tokenizer = build_tokenizer(tokenizer_config) + if samples is None: + samples = build_datapool(datapool_config, remote_train=True)["train"] + if reward_function is None: + reward_function = build_reward_fn(reward_config) self.tokenizer = tokenizer self.reward_function = reward_function self.max_steps = max_episode_length @@ -204,25 +213,61 @@ def get_obs_and_action_space(self): class ReviewerGroup: def __init__(self, reviewer_config=None, - reward_fn=None, + reward_config=None, tokenizer=None, + datapool_config=None, + tokenizer_config=None, + reward_fn=None, question_samples=None, seed = None, start_index = 0, ): self.n_reviewers = reviewer_config["n_reviewers"] reviewer_kwargs = { - "reward_function": reward_fn, - "tokenizer": tokenizer, - "samples": question_samples, + # "reward_function": reward_fn, + "reward_config": reward_config, + # "tokenizer": tokenizer, + "tokenizer_config": tokenizer_config, + # "samples": question_samples, + "datapool_config": datapool_config } reviewer_kwargs = {**reviewer_kwargs, **reviewer_config.get("args", {})} self.tokenizer = tokenizer self._remote_reviewers = self._create_reviewers(reviewer_kwargs, reviewer_config["parl_master_address"]) - tem_future_object_ids = self._remote_reviewers[0].get_obs_and_action_space() - self.observation_space, self.action_space = tem_future_object_ids.get() + # tem_future_object_ids = self._remote_reviewers[0].get_obs_and_action_space() + # self.observation_space, self.action_space = tem_future_object_ids.get() # self.observation_space, self.action_space = tem_future_object_ids + # due to serialization, build obs space and action space here + self._vocab_size = tokenizer.vocab_size + self.observation_space = DictSpace( + { + # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited + # while creating rollout buffers, observations are concatenated for each key + "prompt_or_input_encoded_pt": spaces.Box( + low=0, high=self._vocab_size, shape=(reviewer_kwargs["max_prompt_length"],) + ), + "prompt_or_input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(reviewer_kwargs["max_prompt_length"],) + ), + "context_encoded_pt": spaces.Box( + low=0, high=self._vocab_size, shape=(reviewer_kwargs["max_episode_length"],) + ), + "context_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(reviewer_kwargs["max_episode_length"],) + ), + "input_encoded_pt": spaces.Box( + low=0, + high=self._vocab_size, + shape=(reviewer_kwargs["max_prompt_length"] + reviewer_kwargs["max_episode_length"],), + ), + "input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(reviewer_kwargs["max_prompt_length"] + reviewer_kwargs["max_episode_length"],) + ), + } + ) + self.action_space = Discrete(n=self._vocab_size) + def ask(self): future_object_ids = [ remote_reviewer.ask() for remote_reviewer in self._remote_reviewers diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py index 79dd20809..5ba3cf537 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py @@ -6,8 +6,13 @@ def build_tokenizer(tokenizer_config): logger.info(f"loading tokenizer of [{tokenizer_config['model_name']}] model") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_config["model_name"]) + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_config["model_name"]) + except Exception: + logger.info(f"trying to use local_files to load tokenizer of [{tokenizer_config['model_name']}] model") + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_config["model_name"], local_files_only=True) if tokenizer.pad_token is None and tokenizer_config.get("pad_token_as_eos_token", True): tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = tokenizer_config.get( @@ -29,15 +34,24 @@ def build_metrics(metric_configs): return metrics -def build_datapool(datapool_config): +def build_datapool(datapool_config, remote_train=False): def _get_datapool_by_split(split): kwargs = datapool_config.get("args", {}) kwargs["split"] = split logger.info(f"loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") dp_split = CNNDailyMail.prepare(**kwargs) + logger.info(f"finish loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") return dp_split train_datapool = _get_datapool_by_split("train") + + if remote_train: + samples_by_split = { + "train": [(sample, weight) + for sample, weight in train_datapool], + } + return samples_by_split + val_datapool = _get_datapool_by_split("val") test_datapool = _get_datapool_by_split("test") diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 4487900c4..175b2e810 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -42,7 +42,7 @@ def main(config): tokenizer = build_tokenizer(config["tokenizer"]) # reward function & metrics - reward_fn = build_reward_fn(config["reward_fn"]) + # reward_fn = build_reward_fn(config["reward_fn"]) # build reward_fn in reviewer metrics = build_metrics(config["train_evaluation"]["metrics"]) # datapool @@ -50,9 +50,12 @@ def main(config): reviewer_group = ReviewerGroup(reviewer_config=config["reviewer"], - reward_fn=reward_fn, + reward_config=config["reward_fn"], tokenizer=tokenizer, - question_samples=samples_by_split["train"]) + tokenizer_config=config["tokenizer"], + datapool_config=config["datapool"],) + # reward_fn=reward_fn, + # question_samples=samples_by_split["train"]) rl4lms_model = Seq2SeqLMModel( observation_space = reviewer_group.observation_space, From d452685bc6d67b76c1e275db0cdbe3b8fddc308b Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Fri, 10 Mar 2023 12:38:49 +0800 Subject: [PATCH 12/34] review using sentence (parl parallel version) --- benchmark/torch/RL4LMs/reviewer.py | 133 +++++------------- .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 112 +++++++++++++-- benchmark/torch/RL4LMs/train.py | 2 +- 3 files changed, 137 insertions(+), 110 deletions(-) diff --git a/benchmark/torch/RL4LMs/reviewer.py b/benchmark/torch/RL4LMs/reviewer.py index 511576917..4c2c94dcc 100644 --- a/benchmark/torch/RL4LMs/reviewer.py +++ b/benchmark/torch/RL4LMs/reviewer.py @@ -10,36 +10,23 @@ import numpy as np from rl4lms_utils import build_datapool, build_tokenizer, build_reward_fn -def _flatten_obs(obs, space): +def _flatten_obs(obs, space, n_reviewer=None): assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" assert len(obs) > 0, "need observations from at least one environment" if isinstance(space, gym.spaces.Dict): assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" + if n_reviewer is not None: + return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_reviewer, -1, len(obs[0][k])))) for k in space.spaces.keys()]) return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) - elif isinstance(space, gym.spaces.Tuple): - assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" - obs_len = len(space.spaces) - return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) else: - return np.stack(obs) + raise NotImplementedError def dict_to_tensor(obs, device): return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} -def unpack_observations(obs_tensor, n_envs): - """ - Unpacks vectorized dict observations into separate dict observations - """ - unpacked_obs = [] - keys = obs_tensor.keys() - for env_ix in range(n_envs): - obs_dict = {} - for key in keys: - obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() - unpacked_obs.append(obs_dict) - return unpacked_obs + @parl.remote_class(wait=False) @@ -180,6 +167,13 @@ def get_new_obs_and_feedback_one_step(self, action): else: return (self.__current_obs.to_dict(), reward, done, info) + def get_new_obs_and_feedback_sentence(self, sentence): + res = [] + for token in sentence: + one_step_res = self.get_new_obs_and_feedback_one_step(token) + res.append(one_step_res) + return res + def ask(self, sample = None): """ Resets the environment and starts a new episode @@ -278,88 +272,27 @@ def ask(self): # sample_questions = future_object_ids return _flatten_obs(sample_questions, self.observation_space) - def feedback(self, current_obs, gen_output, kl_criterion, agent, device): - review_times = 0 - episode_starts = np.ones((self.n_reviewers,), dtype=bool) - # process them one step at a time to collect rollout info - episode_wise_transitions = [[] for _ in range(self.n_reviewers)] - ep_terminated = np.zeros((self.n_reviewers,), dtype=bool) - - for actions_tensor, _ in zip( - gen_output.step_wise_actions, gen_output.step_wise_logprobs - ): - # if all episodes are done, just break and do not continue - if np.all(ep_terminated): - break - - # evaluate actions with actions from rollout - with torch.no_grad(): - obs_tensor = dict_to_tensor(current_obs, device) - - # get log probs (TBD: generalize this a bit) - policy_kwargs = { - "obs": obs_tensor, - "actions": actions_tensor, - } - - _, log_probs, _, _ = agent.forward_policy(**policy_kwargs) - - # sanity check - assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" - - # get values - values, _ = agent.forward_value(obs_tensor) - - # get reference log probs - ref_log_probs, _ = agent.get_log_probs_ref_model(obs_tensor, actions_tensor) - - # sanity check - assert torch.all(torch.isfinite(ref_log_probs)), "Infinite values in log probs" - - # compute KL rewards - kl_div = log_probs - ref_log_probs - kl_rewards = -1 * kl_criterion.kl_coeff * kl_div - - # step into env to get rewards - actions = actions_tensor.cpu().numpy() - new_obs, rewards, dones, infos = self._feedback_one_step(actions) - - review_times += self.n_reviewers - - # compute total rewards - total_rewards = rewards + kl_rewards.cpu().numpy() - - # unpack individual observations - unpacked_obs = unpack_observations(obs_tensor, self.n_reviewers) - - # store episode wise transitions separately - for env_ix in range(self.n_reviewers): - # only if not terminated already - if not ep_terminated[env_ix]: - transtion = TransitionInfo( - observation=unpacked_obs[env_ix], - action=actions[env_ix], - task_reward=rewards[env_ix], - total_reward=total_rewards[env_ix], - kl_div=kl_div.cpu().numpy()[env_ix], - episode_start=episode_starts[env_ix], - value=values[env_ix].cpu(), - log_prob=log_probs[env_ix].cpu(), - done=dones[env_ix], - ref_log_prob=ref_log_probs[env_ix].cpu(), - kl_reward=kl_rewards.cpu().numpy()[env_ix], - info=infos[env_ix], - ) - - episode_wise_transitions[env_ix].append(transtion) - - # mark this episode to terminated if done occurs once - if dones[env_ix]: - ep_terminated[env_ix] = True - - episode_starts = np.zeros((self.n_reviewers,), dtype=bool) - current_obs = new_obs - return episode_wise_transitions, review_times + def feedback_sentense(self, gen_output): + sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = \ + self._reviewers_feedback_sentence(gen_output.step_wise_actions) + + return sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos + + + def _reviewers_feedback_sentence(self, all_sentences): + all_sentences = torch.stack(all_sentences).cpu().numpy().transpose(1, 0) + future_object_ids = [ + self._remote_reviewers[i].get_new_obs_and_feedback_sentence( + all_sentences[i]) for i in range(self.n_reviewers) + ] + + feedback_res = np.stack([future_object.get() for future_object in future_object_ids]) + + obs, rews, dones, infos = zip(*feedback_res.reshape(-1, 4)) + return _flatten_obs(obs, self.observation_space, self.n_reviewers), \ + np.stack(rews).reshape(self.n_reviewers, -1), np.stack(dones).reshape(self.n_reviewers, -1),\ + np.stack(infos).reshape(self.n_reviewers, -1) + def _feedback_one_step(self, actions): future_object_ids = [ diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index 58e1efad6..12b6aa400 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -3,11 +3,29 @@ from .kl_controller import KLController from parl.utils import logger +from collections import OrderedDict +from .data_wrapper import TransitionInfo def dict_to_tensor(obs, device): return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} +def get_one_token_obs(obs, idx, space): + return OrderedDict([(k, obs[k][:, idx, :]) for k in space.spaces.keys()]) + +def unpack_observations(obs_tensor, n_envs): + """ + Unpacks vectorized dict observations into separate dict observations + """ + unpacked_obs = [] + keys = obs_tensor.keys() + for env_ix in range(n_envs): + obs_dict = {} + for key in keys: + obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() + unpacked_obs.append(obs_dict) + return unpacked_obs + def add_to_buffer( rollout_buffer, episode_wise_transitions, rollout_info @@ -59,9 +77,8 @@ def add_to_buffer( class RolloutUtil: - def __init__(self, kl_args, reviewer_group): - self._kl_controller = KLController(kl_args["coeff"], - kl_args["target_kl"]) + def __init__(self, kl_args): + self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) def _generate_batch( self, @@ -89,17 +106,94 @@ def _generate_batch( tokenizer=tokenizer, ) - episode_wise_transitions, num_timesteps = reviewer_group.feedback(current_obs=current_obs, - gen_output=gen_output, - kl_criterion=self._kl_controller, - agent=agent, - device=device) + review_times = 0 + episode_starts = np.ones((reviewer_group.n_reviewers,), dtype=bool) + # process them one step at a time to collect rollout info + episode_wise_transitions = [[] for _ in range(reviewer_group.n_reviewers)] + ep_terminated = np.zeros((reviewer_group.n_reviewers,), dtype=bool) + + sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = reviewer_group.feedback_sentense( + gen_output=gen_output) + + for idx, actions_tensor in enumerate(gen_output.step_wise_actions): + if np.all(ep_terminated): + break + + # evaluate actions with actions from rollout + with torch.no_grad(): + obs_tensor = dict_to_tensor(current_obs, device) + + # get log probs (TBD: generalize this a bit) + policy_kwargs = { + "obs": obs_tensor, + "actions": actions_tensor, + } + + _, log_probs, _, _ = agent.forward_policy(**policy_kwargs) + + # sanity check + assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" + + # get values + values, _ = agent.forward_value(obs_tensor) + + # get reference log probs + ref_log_probs, _ = agent.get_log_probs_ref_model(obs_tensor, actions_tensor) + + # sanity check + assert torch.all(torch.isfinite(ref_log_probs)), "Infinite values in log probs" + + # compute KL rewards + kl_div = log_probs - ref_log_probs + kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div + + actions = actions_tensor.cpu().numpy() + rewards = sentence_rewards[:, idx] + dones = sentence_dones[:, idx] + new_obs = get_one_token_obs(sentence_new_obs, idx, reviewer_group.observation_space) + infos = sentence_infos[:, idx] + + review_times += reviewer_group.n_reviewers + + # compute total rewards + total_rewards = rewards + kl_rewards.cpu().numpy() + + # unpack individual observations + unpacked_obs = unpack_observations(obs_tensor, reviewer_group.n_reviewers) + + # store episode wise transitions separately + for env_ix in range(reviewer_group.n_reviewers): + # only if not terminated already + if not ep_terminated[env_ix]: + transtion = TransitionInfo( + observation=unpacked_obs[env_ix], + action=actions[env_ix], + task_reward=rewards[env_ix], + total_reward=total_rewards[env_ix], + kl_div=kl_div.cpu().numpy()[env_ix], + episode_start=episode_starts[env_ix], + value=values[env_ix].cpu(), + log_prob=log_probs[env_ix].cpu(), + done=dones[env_ix], + ref_log_prob=ref_log_probs[env_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[env_ix], + info=infos[env_ix], + ) + + episode_wise_transitions[env_ix].append(transtion) + + # mark this episode to terminated if done occurs once + if dones[env_ix]: + ep_terminated[env_ix] = True + + episode_starts = np.zeros((reviewer_group.n_reviewers,), dtype=bool) + current_obs = new_obs # now we flush all episode wise info to the 1-D buffer rollout_info = add_to_buffer( rollout_buffer, episode_wise_transitions, rollout_info ) - return rollout_info, num_timesteps + return rollout_info, review_times def collect_rollouts( diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 175b2e810..37d1dcf7c 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -74,7 +74,7 @@ def main(config): gamma=agent.alg.gamma, gae_lambda=agent.alg.gae_lambda, ) - rollout_util = RolloutUtil(config["alg"]["kl_div"], reviewer_group) + rollout_util = RolloutUtil(config["alg"]["kl_div"]) n_iters = int(config["train_evaluation"]["n_iters"]) n_steps_per_iter = reviewer_group.n_reviewers * agent.alg.n_steps From b943f1c7538eead4c5a4234e3dd0fae4548a4cb7 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Fri, 10 Mar 2023 16:00:20 +0800 Subject: [PATCH 13/34] remove some '**config' and change rollout util --- benchmark/torch/RL4LMs/reviewer.py | 70 +----- benchmark/torch/RL4LMs/rl4lms_agent.py | 1 + benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 13 +- .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 231 +++++++++--------- benchmark/torch/RL4LMs/seq2seq_model.py | 12 +- benchmark/torch/RL4LMs/t5_ppo.yml | 3 +- benchmark/torch/RL4LMs/train.py | 22 +- 7 files changed, 151 insertions(+), 201 deletions(-) diff --git a/benchmark/torch/RL4LMs/reviewer.py b/benchmark/torch/RL4LMs/reviewer.py index 4c2c94dcc..a02cdfd44 100644 --- a/benchmark/torch/RL4LMs/reviewer.py +++ b/benchmark/torch/RL4LMs/reviewer.py @@ -1,8 +1,8 @@ import gym from collections import OrderedDict import torch -from rl4lms_utils import TransitionInfo, Sample, Observation -from gym import Env, spaces +from rl4lms_utils import Observation +from gym import spaces from gym.spaces.dict import Dict as DictSpace from gym.spaces.discrete import Discrete import parl @@ -11,31 +11,17 @@ from rl4lms_utils import build_datapool, build_tokenizer, build_reward_fn def _flatten_obs(obs, space, n_reviewer=None): - assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" - assert len(obs) > 0, "need observations from at least one environment" - - if isinstance(space, gym.spaces.Dict): - assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" - assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" - if n_reviewer is not None: - return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_reviewer, -1, len(obs[0][k])))) for k in space.spaces.keys()]) - return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) - else: - raise NotImplementedError + if n_reviewer is not None: + return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_reviewer, -1, len(obs[0][k])))) for k in space.spaces.keys()]) + return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) def dict_to_tensor(obs, device): return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} - - - @parl.remote_class(wait=False) class Reviewer: def __init__( self, - tokenizer=None, - reward_function=None, - samples=None, reward_config=None, tokenizer_config=None, datapool_config=None, @@ -46,24 +32,17 @@ def __init__( prompt_truncation_side = "left", ): """ - A generic RL environment to generate textual sequences. - For eg: text generation, summarization, machine translation, text simplification + Reviewer who gives reward Args: - tokenizer (AutoTokenizer): pre-trained tokenizer - reward_function (RewardFunction): reward functiom - samples (Tuple[List[Sample], float]): list of samples max_episode_length (int, optional): Max steps to the model Defaults to 512. max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. terminate_on_eos (bool, optional): whether to terminate on EOS. Defaults to False. context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") """ - if tokenizer is None: - tokenizer = build_tokenizer(tokenizer_config) - if samples is None: - samples = build_datapool(datapool_config, remote_train=True)["train"] - if reward_function is None: - reward_function = build_reward_fn(reward_config) + tokenizer = build_tokenizer(tokenizer_config) + samples = build_datapool(datapool_config, remote_train=True)["train"] + reward_function = build_reward_fn(reward_config) self.tokenizer = tokenizer self.reward_function = reward_function self.max_steps = max_episode_length @@ -79,7 +58,6 @@ def __init__( self._vocab_size = tokenizer.vocab_size self.observation_space = DictSpace( { - # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited # while creating rollout buffers, observations are concatenated for each key "prompt_or_input_encoded_pt": spaces.Box( low=0, high=self._vocab_size, shape=(self._max_text_length,) @@ -176,7 +154,7 @@ def get_new_obs_and_feedback_sentence(self, sentence): def ask(self, sample = None): """ - Resets the environment and starts a new episode + Reset the reviewer and starts a new episode """ # gets a new sample if not provided if sample is None: @@ -211,32 +189,22 @@ def __init__(self, tokenizer=None, datapool_config=None, tokenizer_config=None, - reward_fn=None, - question_samples=None, - seed = None, - start_index = 0, ): self.n_reviewers = reviewer_config["n_reviewers"] + # remote reviewers need to use config to initialize due to serialization problem reviewer_kwargs = { - # "reward_function": reward_fn, "reward_config": reward_config, - # "tokenizer": tokenizer, "tokenizer_config": tokenizer_config, - # "samples": question_samples, "datapool_config": datapool_config } reviewer_kwargs = {**reviewer_kwargs, **reviewer_config.get("args", {})} self.tokenizer = tokenizer self._remote_reviewers = self._create_reviewers(reviewer_kwargs, reviewer_config["parl_master_address"]) - # tem_future_object_ids = self._remote_reviewers[0].get_obs_and_action_space() - # self.observation_space, self.action_space = tem_future_object_ids.get() - # self.observation_space, self.action_space = tem_future_object_ids - # due to serialization, build obs space and action space here + # due to serialization problem, build obs space and action space here self._vocab_size = tokenizer.vocab_size self.observation_space = DictSpace( { - # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited # while creating rollout buffers, observations are concatenated for each key "prompt_or_input_encoded_pt": spaces.Box( low=0, high=self._vocab_size, shape=(reviewer_kwargs["max_prompt_length"],) @@ -293,20 +261,6 @@ def _reviewers_feedback_sentence(self, all_sentences): np.stack(rews).reshape(self.n_reviewers, -1), np.stack(dones).reshape(self.n_reviewers, -1),\ np.stack(infos).reshape(self.n_reviewers, -1) - - def _feedback_one_step(self, actions): - future_object_ids = [ - self._remote_reviewers[i].get_new_obs_and_feedback_one_step( - actions[i]) for i in range(self.n_reviewers) - ] - feedback_res = [ - future_object.get() for future_object in future_object_ids - ] - # feedback_res = future_object_ids - obs, rews, dones, infos = zip(*feedback_res) - return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos - - def _create_reviewers(self, reviewer_kwargs, parl_port=None): parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"]) return [Reviewer(**reviewer_kwargs) for _ in range(self.n_reviewers)] diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index 4e9357f14..c34f9039e 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -99,6 +99,7 @@ def get_inputs_for_generation(self, obs_tensor): def predict(self, *args, **kwargs): + # only use sample pass def forward_value( diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index 1ac097f33..000d8eeb5 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -58,7 +58,6 @@ class DictRolloutBuffer: :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor - :param n_envs: Number of parallel environments """ def __init__( @@ -130,7 +129,7 @@ def add(self, for key in self.observations.keys(): obs_ = np.array(obs[key]).copy() - # Reshape needed when using multiple envs with discrete observations + # Reshape needed when using multiple reviewers with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space.spaces[key], spaces.Discrete): obs_ = obs_.reshape((1,) + self.obs_shape[key]) @@ -161,8 +160,8 @@ def compute_returns_and_advantage(self, last_values, dones): For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. - :param last_values: state value estimation for the last step (one for each env) - :param dones: if the last step was a terminal step (one bool for each env). + :param last_values: state value estimation for the last step (one for each reviewer) + :param dones: if the last step was a terminal step (one bool for each reviewer). """ # Convert to numpy last_values = last_values.clone().cpu().numpy().flatten() @@ -184,9 +183,9 @@ def compute_returns_and_advantage(self, last_values, dones): def swap_and_flatten(self, arr): """ - Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) - to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) - to [n_steps * n_envs, ...] (which maintain the order) + Swap and then flatten axes 0 (buffer_size) and 1 (n_reviewers) + to convert shape from [n_steps, n_reviewers, ...] (when ... is the shape of the features) + to [n_steps * n_reviewers, ...] (which maintain the order) :param arr: :return: diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index 12b6aa400..55c4e48b8 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -13,16 +13,16 @@ def dict_to_tensor(obs, device): def get_one_token_obs(obs, idx, space): return OrderedDict([(k, obs[k][:, idx, :]) for k in space.spaces.keys()]) -def unpack_observations(obs_tensor, n_envs): +def unpack_observations(obs_tensor, n_reviewers): """ Unpacks vectorized dict observations into separate dict observations """ unpacked_obs = [] keys = obs_tensor.keys() - for env_ix in range(n_envs): + for reviewer_ix in range(n_reviewers): obs_dict = {} for key in keys: - obs_dict[key] = obs_tensor[key][env_ix].reshape(1, -1).cpu() + obs_dict[key] = obs_tensor[key][reviewer_ix].reshape(1, -1).cpu() unpacked_obs.append(obs_dict) return unpacked_obs @@ -80,42 +80,108 @@ class RolloutUtil: def __init__(self, kl_args): self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) - def _generate_batch( + def collect_rollouts( self, - agent=None, - reviewer_group=None, - rollout_buffer=None, - tokenizer=None, - rollout_info=None, - device=None + agent, + reviewer_group, + rollout_buffer, + device ): - # if rollout buffer is already full, do not continue - if rollout_buffer.full: - return + # get tokenizer + tokenizer = reviewer_group.tokenizer + + # Switch to eval mode both training and testing + agent.eval_mode() + + # reset rollout buffer and stats + rollout_buffer.reset() + + # start the rollout process + rollout_info = { + "rollout_info/ep_rew": [], + "rollout_info/kl_div_mean": [], + "rollout_info/ep_lens": [], + "rollout_info/ep_kl_rew": [], + "rollout_info/log_prob": [], + "rollout_info/ref_log_prob": [], + "rollout_info/values": [], + } + num_timesteps = 0 + while not rollout_buffer.full: + # start parallel episodes + current_obs = reviewer_group.ask() + + # generate sentences using the model + obs_tensor = dict_to_tensor(current_obs, device) + generation_inputs = agent.get_inputs_for_generation(obs_tensor) + gen_output = agent.sample( + input_ids=generation_inputs.inputs, + attention_mask=generation_inputs.attention_masks, + tokenizer=tokenizer) + + # get episode state, reward, dones, infos from reviewers + sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = reviewer_group.feedback_sentense( + gen_output=gen_output) + + # generate batch of rollouts and add to buffer + rollout_info, run_timesteps = self._generate_transition_and_add_to_buffer( + gen_sentence=gen_output, + init_obs=current_obs, + agent=agent, + n_reviewers=reviewer_group.n_reviewers, + obs_space=reviewer_group.observation_space, + sentence_new_obs=sentence_new_obs, + sentence_rewards=sentence_rewards, + sentence_dones=sentence_dones, + sentence_infos=sentence_infos, + rollout_buffer=rollout_buffer, + rollout_info=rollout_info, + device=device, + ) + num_timesteps += run_timesteps - # start parallel episodes - current_obs = reviewer_group.ask() + # aggregate rollout info + aggregated_rollout_info = {} + for key, values in rollout_info.items(): + aggregated_rollout_info[key] = np.mean(values).item() + aggregated_rollout_info[f"{key}_std"] = np.std(values).item() + aggregated_rollout_info[ + "rollout_info/kl_coeff" + ] = self._kl_controller.kl_coeff + logger.info(f"Rollout Info: {aggregated_rollout_info}") - # generate text using the model - obs_tensor = dict_to_tensor(current_obs, device) - generation_inputs = agent.get_inputs_for_generation(obs_tensor) - gen_output = agent.sample( - input_ids=generation_inputs.inputs, - attention_mask=generation_inputs.attention_masks, - tokenizer=tokenizer, + # adapt the KL coeff + self._kl_controller.step( + torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) ) + return num_timesteps + + def _generate_transition_and_add_to_buffer( + self, + gen_sentence=None, + agent=None, + n_reviewers=None, + obs_space=None, + rollout_buffer=None, + rollout_info=None, + device=None, + sentence_new_obs=None, + sentence_rewards=None, + sentence_dones=None, + sentence_infos=None, + init_obs=None + ): + current_obs = init_obs review_times = 0 - episode_starts = np.ones((reviewer_group.n_reviewers,), dtype=bool) + episode_starts = np.ones((n_reviewers,), dtype=bool) # process them one step at a time to collect rollout info - episode_wise_transitions = [[] for _ in range(reviewer_group.n_reviewers)] - ep_terminated = np.zeros((reviewer_group.n_reviewers,), dtype=bool) + episode_wise_transitions = [[] for _ in range(n_reviewers)] + ep_terminated = np.zeros((n_reviewers,), dtype=bool) - sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = reviewer_group.feedback_sentense( - gen_output=gen_output) - for idx, actions_tensor in enumerate(gen_output.step_wise_actions): + for idx, actions_tensor in enumerate(gen_sentence.step_wise_actions): if np.all(ep_terminated): break @@ -123,13 +189,7 @@ def _generate_batch( with torch.no_grad(): obs_tensor = dict_to_tensor(current_obs, device) - # get log probs (TBD: generalize this a bit) - policy_kwargs = { - "obs": obs_tensor, - "actions": actions_tensor, - } - - _, log_probs, _, _ = agent.forward_policy(**policy_kwargs) + _, log_probs, _, _ = agent.forward_policy(obs=obs_tensor, actions=actions_tensor) # sanity check assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" @@ -150,43 +210,43 @@ def _generate_batch( actions = actions_tensor.cpu().numpy() rewards = sentence_rewards[:, idx] dones = sentence_dones[:, idx] - new_obs = get_one_token_obs(sentence_new_obs, idx, reviewer_group.observation_space) + new_obs = get_one_token_obs(sentence_new_obs, idx, obs_space) infos = sentence_infos[:, idx] - review_times += reviewer_group.n_reviewers + review_times += n_reviewers # compute total rewards total_rewards = rewards + kl_rewards.cpu().numpy() # unpack individual observations - unpacked_obs = unpack_observations(obs_tensor, reviewer_group.n_reviewers) + unpacked_obs = unpack_observations(obs_tensor, n_reviewers) # store episode wise transitions separately - for env_ix in range(reviewer_group.n_reviewers): + for reviewer_ix in range(n_reviewers): # only if not terminated already - if not ep_terminated[env_ix]: + if not ep_terminated[reviewer_ix]: transtion = TransitionInfo( - observation=unpacked_obs[env_ix], - action=actions[env_ix], - task_reward=rewards[env_ix], - total_reward=total_rewards[env_ix], - kl_div=kl_div.cpu().numpy()[env_ix], - episode_start=episode_starts[env_ix], - value=values[env_ix].cpu(), - log_prob=log_probs[env_ix].cpu(), - done=dones[env_ix], - ref_log_prob=ref_log_probs[env_ix].cpu(), - kl_reward=kl_rewards.cpu().numpy()[env_ix], - info=infos[env_ix], + observation=unpacked_obs[reviewer_ix], + action=actions[reviewer_ix], + task_reward=rewards[reviewer_ix], + total_reward=total_rewards[reviewer_ix], + kl_div=kl_div.cpu().numpy()[reviewer_ix], + episode_start=episode_starts[reviewer_ix], + value=values[reviewer_ix].cpu(), + log_prob=log_probs[reviewer_ix].cpu(), + done=dones[reviewer_ix], + ref_log_prob=ref_log_probs[reviewer_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[reviewer_ix], + info=infos[reviewer_ix], ) - episode_wise_transitions[env_ix].append(transtion) + episode_wise_transitions[reviewer_ix].append(transtion) # mark this episode to terminated if done occurs once - if dones[env_ix]: - ep_terminated[env_ix] = True + if dones[reviewer_ix]: + ep_terminated[reviewer_ix] = True - episode_starts = np.zeros((reviewer_group.n_reviewers,), dtype=bool) + episode_starts = np.zeros((n_reviewers,), dtype=bool) current_obs = new_obs # now we flush all episode wise info to the 1-D buffer @@ -194,62 +254,3 @@ def _generate_batch( rollout_buffer, episode_wise_transitions, rollout_info ) return rollout_info, review_times - - - def collect_rollouts( - self, - agent, - reviewer_group, - rollout_buffer, - device - ): - # get tokenizer - tokenizer = reviewer_group.tokenizer - - # Switch to eval mode - # self._agent.alg.model.set_training_mode(False) - agent.eval_mode() - - # reset rollout buffer and stats - rollout_buffer.reset() - - # start the rollout process - rollout_info = { - "rollout_info/ep_rew": [], - "rollout_info/kl_div_mean": [], - "rollout_info/ep_lens": [], - "rollout_info/ep_kl_rew": [], - "rollout_info/log_prob": [], - "rollout_info/ref_log_prob": [], - "rollout_info/values": [], - } - num_timesteps = 0 - while not rollout_buffer.full: - # generate batch of rollouts - rollout_info, run_timesteps = self._generate_batch( - agent=agent, - reviewer_group=reviewer_group, - rollout_buffer=rollout_buffer, - tokenizer=tokenizer, - rollout_info=rollout_info, - device=device - ) - num_timesteps += run_timesteps - - # aggregate rollout info - aggregated_rollout_info = {} - for key, values in rollout_info.items(): - aggregated_rollout_info[key] = np.mean(values).item() - aggregated_rollout_info[f"{key}_std"] = np.std(values).item() - aggregated_rollout_info[ - "rollout_info/kl_coeff" - ] = self._kl_controller.kl_coeff - - logger.info(f"Rollout Info: {aggregated_rollout_info}") - - - # adapt the KL coeff - self._kl_controller.step( - torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) - ) - return num_timesteps \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index 5be6782ad..05f6f4776 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -14,14 +14,12 @@ ) - class Seq2SeqLMModel(parl.Model): def __init__( self, observation_space, action_space, model_name, - optimizer_kwargs = {}, weight_decay = 1e-6, apply_model_parallel = True, optimizer_class = torch.optim.AdamW, @@ -30,21 +28,18 @@ def __init__( device = None, ): super(Seq2SeqLMModel, self).__init__() - if optimizer_kwargs is None: - optimizer_kwargs = {} self.observation_space = observation_space self.action_space = action_space self.optimizer_class = optimizer_class - self.optimizer_kwargs = optimizer_kwargs self.optimizer = None self.device = device self._action_space = action_space self._apply_model_parallel = apply_model_parallel self._build_model_heads(model_name) - self._setup_optimizer(optimizer_kwargs, weight_decay, optimizer_class) + self._setup_optimizer(weight_decay, optimizer_class) self._generation_kwargs = generation_kwargs self._prompt_truncation_side = prompt_truncation_side @@ -397,7 +392,6 @@ def save(self, path): def _setup_optimizer( self, - optimizer_kwargs, weight_decay, optimizer_class, ): @@ -414,9 +408,7 @@ def _setup_optimizer( "weight_decay": 0.0, }, ] - self.optimizer = optimizer_class( - optimizer_grouped_parameters, **optimizer_kwargs - ) + self.optimizer = optimizer_class(optimizer_grouped_parameters) diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml index 72fff2360..09375d66c 100644 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -19,11 +19,10 @@ datapool: reviewer: parl_master_address: "localhost:8811" ## CHANGE FOR DEBUG ## -# n_envs: 10 +# n_reviewers: 10 n_reviewers: 2 ## CHANGE FOR DEBUG ## args: - max_prompt_length: 512 max_episode_length: 100 terminate_on_eos: True diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 37d1dcf7c..9423fcb0f 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -9,8 +9,7 @@ import torch import time -# env and reward function -from rl4lms_utils import build_reward_fn +# reviewer and reward function from reviewer import ReviewerGroup # evaluation, metrics, tokenizer & dataset @@ -41,29 +40,34 @@ def main(config): tokenizer = build_tokenizer(config["tokenizer"]) - # reward function & metrics - # reward_fn = build_reward_fn(config["reward_fn"]) # build reward_fn in reviewer + # metrics metrics = build_metrics(config["train_evaluation"]["metrics"]) # datapool samples_by_split = build_datapool(config["datapool"]) - reviewer_group = ReviewerGroup(reviewer_config=config["reviewer"], reward_config=config["reward_fn"], tokenizer=tokenizer, tokenizer_config=config["tokenizer"], datapool_config=config["datapool"],) - # reward_fn=reward_fn, - # question_samples=samples_by_split["train"]) rl4lms_model = Seq2SeqLMModel( observation_space = reviewer_group.observation_space, action_space= reviewer_group.action_space, device=device, - **config["alg"]["model"]["args"] + model_name=config["alg"]["model"]["args"]["model_name"], + apply_model_parallel=config["alg"]["model"]["args"]["apply_model_parallel"], + prompt_truncation_side=config["alg"]["model"]["args"]["prompt_truncation_side"], + generation_kwargs=config["alg"]["model"]["args"]["generation_kwargs"] ) - rl4lm_alg = RL4LMPPO(model=rl4lms_model, device=device, **config["alg"]["args"]) + rl4lm_alg = RL4LMPPO(model=rl4lms_model, + device=device, + n_steps=config["alg"]["args"]["n_steps"], + batch_size=config["alg"]["args"]["batch_size"], + learning_rate=config["alg"]["args"]["learning_rate"], + n_epochs=config["alg"]["args"]["n_epochs"], + ent_coef=config["alg"]["args"]["ent_coef"]) agent = RL4LMsAgent(rl4lm_alg, config["alg"]) rollout_buffer = DictRolloutBuffer( From 086ce6fec9c962fea5b8cfca027968fb5830d3c9 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 13 Mar 2023 16:36:59 +0800 Subject: [PATCH 14/34] use instructor instead of reviewer, add examiner --- benchmark/torch/RL4LMs/README.md | 4 + .../RL4LMs/{reviewer.py => instructor.py} | 59 ++++---- .../torch/RL4LMs/rl4lms_utils/__init__.py | 3 +- benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 12 +- .../RL4LMs/rl4lms_utils/evaluation_util.py | 120 ----------------- .../torch/RL4LMs/rl4lms_utils/examiner.py | 126 ++++++++++++++++++ .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 68 +++++----- benchmark/torch/RL4LMs/t5_ppo.yml | 26 +--- benchmark/torch/RL4LMs/train.py | 75 +++++------ 9 files changed, 235 insertions(+), 258 deletions(-) rename benchmark/torch/RL4LMs/{reviewer.py => instructor.py} (81%) delete mode 100644 benchmark/torch/RL4LMs/rl4lms_utils/evaluation_util.py create mode 100644 benchmark/torch/RL4LMs/rl4lms_utils/examiner.py diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md index 688e97f39..8e76ee439 100644 --- a/benchmark/torch/RL4LMs/README.md +++ b/benchmark/torch/RL4LMs/README.md @@ -16,5 +16,9 @@ ### Running command ```bash +# start xparl +xparl start --port 8811 --cpu_num 10 + +# start training python train.py --config_path t5_ppo.yml ``` \ No newline at end of file diff --git a/benchmark/torch/RL4LMs/reviewer.py b/benchmark/torch/RL4LMs/instructor.py similarity index 81% rename from benchmark/torch/RL4LMs/reviewer.py rename to benchmark/torch/RL4LMs/instructor.py index a02cdfd44..e19e7ca1c 100644 --- a/benchmark/torch/RL4LMs/reviewer.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -1,4 +1,3 @@ -import gym from collections import OrderedDict import torch from rl4lms_utils import Observation @@ -10,16 +9,16 @@ import numpy as np from rl4lms_utils import build_datapool, build_tokenizer, build_reward_fn -def _flatten_obs(obs, space, n_reviewer=None): - if n_reviewer is not None: - return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_reviewer, -1, len(obs[0][k])))) for k in space.spaces.keys()]) +def _flatten_obs(obs, space, n_instructor=None): + if n_instructor is not None: + return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_instructor, -1, len(obs[0][k])))) for k in space.spaces.keys()]) return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) def dict_to_tensor(obs, device): return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} @parl.remote_class(wait=False) -class Reviewer: +class Instructor: def __init__( self, reward_config=None, @@ -32,7 +31,7 @@ def __init__( prompt_truncation_side = "left", ): """ - Reviewer who gives reward + Instructor who gives reward Args: max_episode_length (int, optional): Max steps to the model Defaults to 512. max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. @@ -154,7 +153,7 @@ def get_new_obs_and_feedback_sentence(self, sentence): def ask(self, sample = None): """ - Reset the reviewer and starts a new episode + Reset the instructor and starts a new episode """ # gets a new sample if not provided if sample is None: @@ -182,24 +181,24 @@ def get_obs_and_action_space(self): return (self.observation_space, self.action_space) -class ReviewerGroup: +class InstructorGroup: def __init__(self, - reviewer_config=None, + instructor_config=None, reward_config=None, tokenizer=None, datapool_config=None, tokenizer_config=None, ): - self.n_reviewers = reviewer_config["n_reviewers"] - # remote reviewers need to use config to initialize due to serialization problem - reviewer_kwargs = { + self.n_instructors = instructor_config["n_instructors"] + # remote instructors need to use config to initialize due to serialization problem + instructor_kwargs = { "reward_config": reward_config, "tokenizer_config": tokenizer_config, "datapool_config": datapool_config } - reviewer_kwargs = {**reviewer_kwargs, **reviewer_config.get("args", {})} + instructor_kwargs = {**instructor_kwargs, **instructor_config.get("args", {})} self.tokenizer = tokenizer - self._remote_reviewers = self._create_reviewers(reviewer_kwargs, reviewer_config["parl_master_address"]) + self._remote_instructors = self._create_instructors(instructor_kwargs, instructor_config["parl_master_address"]) # due to serialization problem, build obs space and action space here self._vocab_size = tokenizer.vocab_size @@ -207,24 +206,24 @@ def __init__(self, { # while creating rollout buffers, observations are concatenated for each key "prompt_or_input_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(reviewer_kwargs["max_prompt_length"],) + low=0, high=self._vocab_size, shape=(instructor_kwargs["max_prompt_length"],) ), "prompt_or_input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(reviewer_kwargs["max_prompt_length"],) + low=0, high=1, shape=(instructor_kwargs["max_prompt_length"],) ), "context_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(reviewer_kwargs["max_episode_length"],) + low=0, high=self._vocab_size, shape=(instructor_kwargs["max_episode_length"],) ), "context_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(reviewer_kwargs["max_episode_length"],) + low=0, high=1, shape=(instructor_kwargs["max_episode_length"],) ), "input_encoded_pt": spaces.Box( low=0, high=self._vocab_size, - shape=(reviewer_kwargs["max_prompt_length"] + reviewer_kwargs["max_episode_length"],), + shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"],), ), "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(reviewer_kwargs["max_prompt_length"] + reviewer_kwargs["max_episode_length"],) + low=0, high=1, shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"],) ), } ) @@ -232,7 +231,7 @@ def __init__(self, def ask(self): future_object_ids = [ - remote_reviewer.ask() for remote_reviewer in self._remote_reviewers + remote_instructor.ask() for remote_instructor in self._remote_instructors ] sample_questions = [ future_object.get() for future_object in future_object_ids @@ -242,28 +241,28 @@ def ask(self): def feedback_sentense(self, gen_output): sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = \ - self._reviewers_feedback_sentence(gen_output.step_wise_actions) + self._instructors_feedback_sentence(gen_output.step_wise_actions) return sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos - def _reviewers_feedback_sentence(self, all_sentences): + def _instructors_feedback_sentence(self, all_sentences): all_sentences = torch.stack(all_sentences).cpu().numpy().transpose(1, 0) future_object_ids = [ - self._remote_reviewers[i].get_new_obs_and_feedback_sentence( - all_sentences[i]) for i in range(self.n_reviewers) + self._remote_instructors[i].get_new_obs_and_feedback_sentence( + all_sentences[i]) for i in range(self.n_instructors) ] feedback_res = np.stack([future_object.get() for future_object in future_object_ids]) obs, rews, dones, infos = zip(*feedback_res.reshape(-1, 4)) - return _flatten_obs(obs, self.observation_space, self.n_reviewers), \ - np.stack(rews).reshape(self.n_reviewers, -1), np.stack(dones).reshape(self.n_reviewers, -1),\ - np.stack(infos).reshape(self.n_reviewers, -1) + return _flatten_obs(obs, self.observation_space, self.n_instructors), \ + np.stack(rews).reshape(self.n_instructors, -1), np.stack(dones).reshape(self.n_instructors, -1),\ + np.stack(infos).reshape(self.n_instructors, -1) - def _create_reviewers(self, reviewer_kwargs, parl_port=None): + def _create_instructors(self, instructor_kwargs, parl_port=None): parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"]) - return [Reviewer(**reviewer_kwargs) for _ in range(self.n_reviewers)] + return [Instructor(**instructor_kwargs) for _ in range(self.n_instructors)] diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py b/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py index dc4965794..e884625f0 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py @@ -1,14 +1,13 @@ from .data_wrapper import RefPolicyOutput, GenerationInputs, GenerationOutputs,\ PolicyType, Sample, Observation, TransitionInfo - from .huggingface_generation_util import override_generation_routines from .buffer import DictRolloutBuffer from .kl_controller import KLController -from .evaluation_util import evaluate_on_samples +from .examiner import Examiner from .data_pool import CNNDailyMail diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index 000d8eeb5..a033722b9 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -129,7 +129,7 @@ def add(self, for key in self.observations.keys(): obs_ = np.array(obs[key]).copy() - # Reshape needed when using multiple reviewers with discrete observations + # Reshape needed when using multiple instructors with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space.spaces[key], spaces.Discrete): obs_ = obs_.reshape((1,) + self.obs_shape[key]) @@ -160,8 +160,8 @@ def compute_returns_and_advantage(self, last_values, dones): For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. - :param last_values: state value estimation for the last step (one for each reviewer) - :param dones: if the last step was a terminal step (one bool for each reviewer). + :param last_values: state value estimation for the last step (one for each instructor) + :param dones: if the last step was a terminal step (one bool for each instructor). """ # Convert to numpy last_values = last_values.clone().cpu().numpy().flatten() @@ -183,9 +183,9 @@ def compute_returns_and_advantage(self, last_values, dones): def swap_and_flatten(self, arr): """ - Swap and then flatten axes 0 (buffer_size) and 1 (n_reviewers) - to convert shape from [n_steps, n_reviewers, ...] (when ... is the shape of the features) - to [n_steps * n_reviewers, ...] (which maintain the order) + Swap and then flatten axes 0 (buffer_size) and 1 (n_instructors) + to convert shape from [n_steps, n_instructors, ...] (when ... is the shape of the features) + to [n_steps * n_instructors, ...] (which maintain the order) :param arr: :return: diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/evaluation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/evaluation_util.py deleted file mode 100644 index 28602724e..000000000 --- a/benchmark/torch/RL4LMs/rl4lms_utils/evaluation_util.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Any, Dict, List - -from tqdm import tqdm -from transformers import AutoTokenizer - -from . import Sample -from parl.utils import logger - - -def get_batch(samples, batch_size): - current_ix = 0 - n_samples = len(samples) - while current_ix < n_samples: - current_batch = samples[current_ix : current_ix + batch_size] - yield current_batch - current_ix += batch_size - - -def evaluate_on_samples( - policy, - tokenizer, - samples, - batch_size, - max_prompt_length, - metrics, - epoch, - split_name, - dt_control_token = "", - gen_kwargs = None, -): - # generate text by batch - all_generated_texts = [] - all_ref_texts = [] - all_prompt_texts = [] - all_meta_infos = [] - ###########CHANGE FOR DEBUG############ - tem = [] - for i in range(100): - tem.append(samples[i]) - samples = tem - ###########CHANGE FOR DEBUG############ - - n_samples = len(samples) - for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"): - batch_generated_texts = generate_text( - policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs - ) - batch_ref_texts = [sample.references for sample in batch] - batch_prompt_texts = [sample.prompt_or_input_text for sample in batch] - batch_meta_infos = [sample.meta_data for sample in batch] - all_generated_texts.extend(batch_generated_texts) - all_ref_texts.extend(batch_ref_texts) - all_prompt_texts.extend(batch_prompt_texts) - all_meta_infos.extend(batch_meta_infos) - - # compute metrics - corpus_level_metrics = {} - sample_scores_by_metric = {} - if metrics is not None: - for metric in metrics: - metric_dict = metric.compute( - all_prompt_texts, - all_generated_texts, - all_ref_texts, - all_meta_infos, - policy.get_language_model(), - split_name, - ) - - for metric_key, (sample_scores, corpus_score) in metric_dict.items(): - if sample_scores is None: - sample_scores = ["n/a"] * n_samples - corpus_level_metrics[metric_key] = corpus_score - sample_scores_by_metric[metric_key] = sample_scores - - # aggregate sample metric scores - sample_predictions_dict = [] - for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate( - zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts) - ): - sample_prediction = { - "split_name": split_name, - "sample_id": sample.id, - "prompt_text": prompt_text, - "generated_text": generated_text, - "ref_text": "".join( - [ - f"" + ref_text + f"" - for ref_ix, ref_text in enumerate(ref_texts) - ] - ), - } - for metric_key, sample_scores in sample_scores_by_metric.items(): - sample_prediction[metric_key] = sample_scores[ix] - sample_predictions_dict.append(sample_prediction) - - metrics_dict_ = { - "epoch": epoch, - "metrics": corpus_level_metrics - } - - # logger - logger.info(f"{split_name} metrics: {metrics_dict_}") - - -def generate_text( - policy, - tokenizer, - samples, - max_prompt_length, - dt_control_token, - gen_kwargs, -): - prompt_texts = [ - dt_control_token + sample.prompt_or_input_text for sample in samples - ] - generated_texts = policy.sample( - tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs - ).gen_texts - return generated_texts diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py new file mode 100644 index 000000000..58bd60fdc --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py @@ -0,0 +1,126 @@ +from tqdm import tqdm +from parl.utils import logger + + +# class for results evaluation +class Examiner: + def __init__(self, + tokenizer, + eval_batch_size, + metrics, + eval_gen_kwargs, + samples_by_split, + max_prompt_length + ): + self._tokenizer = tokenizer + self._batch_size = eval_batch_size + self._metrics = metrics + self._gen_kwargs = eval_gen_kwargs + self._samples_by_split = samples_by_split + self._max_prompt_length = max_prompt_length + + def evaluate(self, policy, sample_name_list, epoch): + for split_name in sample_name_list: + self._evaluate_on_samples(policy=policy, + epoch=epoch, + split_name=split_name) + + def _evaluate_on_samples( + self, + policy, + epoch, + split_name, + dt_control_token = "", + ): + samples = self._samples_by_split[split_name] + # generate text by batch + all_generated_texts = [] + all_ref_texts = [] + all_prompt_texts = [] + all_meta_infos = [] + + n_samples = len(samples) + for batch in tqdm(list(self._get_batch(samples, self._batch_size)), desc="Evaluating"): + batch_generated_texts = self._generate_text( + policy, self._tokenizer, batch, self._max_prompt_length, dt_control_token + ) + batch_ref_texts = [sample.references for sample in batch] + batch_prompt_texts = [sample.prompt_or_input_text for sample in batch] + batch_meta_infos = [sample.meta_data for sample in batch] + all_generated_texts.extend(batch_generated_texts) + all_ref_texts.extend(batch_ref_texts) + all_prompt_texts.extend(batch_prompt_texts) + all_meta_infos.extend(batch_meta_infos) + + # compute metrics + corpus_level_metrics = {} + sample_scores_by_metric = {} + if self._metrics is not None: + for metric in self._metrics: + metric_dict = metric.compute( + all_prompt_texts, + all_generated_texts, + all_ref_texts, + all_meta_infos, + policy.get_language_model(), + split_name, + ) + + for metric_key, (sample_scores, corpus_score) in metric_dict.items(): + if sample_scores is None: + sample_scores = ["n/a"] * n_samples + corpus_level_metrics[metric_key] = corpus_score + sample_scores_by_metric[metric_key] = sample_scores + + # aggregate sample metric scores + sample_predictions_dict = [] + for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate( + zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts) + ): + sample_prediction = { + "split_name": split_name, + "sample_id": sample.id, + "prompt_text": prompt_text, + "generated_text": generated_text, + "ref_text": "".join( + [ + f"" + ref_text + f"" + for ref_ix, ref_text in enumerate(ref_texts) + ] + ), + } + for metric_key, sample_scores in sample_scores_by_metric.items(): + sample_prediction[metric_key] = sample_scores[ix] + sample_predictions_dict.append(sample_prediction) + + metrics_dict_ = { + "epoch": epoch, + "metrics": corpus_level_metrics + } + + # logger + logger.info(f"{split_name} metrics: {metrics_dict_}") + + def _get_batch(self, samples, batch_size): + current_ix = 0 + n_samples = len(samples) + while current_ix < n_samples: + current_batch = samples[current_ix: current_ix + batch_size] + yield current_batch + current_ix += batch_size + + def _generate_text( + self, + policy, + tokenizer, + samples, + max_prompt_length, + dt_control_token, + ): + prompt_texts = [ + dt_control_token + sample.prompt_or_input_text for sample in samples + ] + generated_texts = policy.sample( + tokenizer, prompt_texts, max_prompt_length, gen_kwargs=self._gen_kwargs + ).gen_texts + return generated_texts diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index 55c4e48b8..fa2d64fab 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -13,16 +13,16 @@ def dict_to_tensor(obs, device): def get_one_token_obs(obs, idx, space): return OrderedDict([(k, obs[k][:, idx, :]) for k in space.spaces.keys()]) -def unpack_observations(obs_tensor, n_reviewers): +def unpack_observations(obs_tensor, n_instructors): """ Unpacks vectorized dict observations into separate dict observations """ unpacked_obs = [] keys = obs_tensor.keys() - for reviewer_ix in range(n_reviewers): + for instructor_ix in range(n_instructors): obs_dict = {} for key in keys: - obs_dict[key] = obs_tensor[key][reviewer_ix].reshape(1, -1).cpu() + obs_dict[key] = obs_tensor[key][instructor_ix].reshape(1, -1).cpu() unpacked_obs.append(obs_dict) return unpacked_obs @@ -83,12 +83,12 @@ def __init__(self, kl_args): def collect_rollouts( self, agent, - reviewer_group, + instructor_group, rollout_buffer, device ): # get tokenizer - tokenizer = reviewer_group.tokenizer + tokenizer = instructor_group.tokenizer # Switch to eval mode both training and testing agent.eval_mode() @@ -109,7 +109,7 @@ def collect_rollouts( num_timesteps = 0 while not rollout_buffer.full: # start parallel episodes - current_obs = reviewer_group.ask() + current_obs = instructor_group.ask() # generate sentences using the model obs_tensor = dict_to_tensor(current_obs, device) @@ -119,8 +119,8 @@ def collect_rollouts( attention_mask=generation_inputs.attention_masks, tokenizer=tokenizer) - # get episode state, reward, dones, infos from reviewers - sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = reviewer_group.feedback_sentense( + # get episode state, reward, dones, infos from instructors + sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = instructor_group.feedback_sentense( gen_output=gen_output) # generate batch of rollouts and add to buffer @@ -128,8 +128,8 @@ def collect_rollouts( gen_sentence=gen_output, init_obs=current_obs, agent=agent, - n_reviewers=reviewer_group.n_reviewers, - obs_space=reviewer_group.observation_space, + n_instructors=instructor_group.n_instructors, + obs_space=instructor_group.observation_space, sentence_new_obs=sentence_new_obs, sentence_rewards=sentence_rewards, sentence_dones=sentence_dones, @@ -161,7 +161,7 @@ def _generate_transition_and_add_to_buffer( self, gen_sentence=None, agent=None, - n_reviewers=None, + n_instructors=None, obs_space=None, rollout_buffer=None, rollout_info=None, @@ -175,10 +175,10 @@ def _generate_transition_and_add_to_buffer( current_obs = init_obs review_times = 0 - episode_starts = np.ones((n_reviewers,), dtype=bool) + episode_starts = np.ones((n_instructors,), dtype=bool) # process them one step at a time to collect rollout info - episode_wise_transitions = [[] for _ in range(n_reviewers)] - ep_terminated = np.zeros((n_reviewers,), dtype=bool) + episode_wise_transitions = [[] for _ in range(n_instructors)] + ep_terminated = np.zeros((n_instructors,), dtype=bool) for idx, actions_tensor in enumerate(gen_sentence.step_wise_actions): @@ -213,40 +213,40 @@ def _generate_transition_and_add_to_buffer( new_obs = get_one_token_obs(sentence_new_obs, idx, obs_space) infos = sentence_infos[:, idx] - review_times += n_reviewers + review_times += n_instructors # compute total rewards total_rewards = rewards + kl_rewards.cpu().numpy() # unpack individual observations - unpacked_obs = unpack_observations(obs_tensor, n_reviewers) + unpacked_obs = unpack_observations(obs_tensor, n_instructors) # store episode wise transitions separately - for reviewer_ix in range(n_reviewers): + for instructor_ix in range(n_instructors): # only if not terminated already - if not ep_terminated[reviewer_ix]: + if not ep_terminated[instructor_ix]: transtion = TransitionInfo( - observation=unpacked_obs[reviewer_ix], - action=actions[reviewer_ix], - task_reward=rewards[reviewer_ix], - total_reward=total_rewards[reviewer_ix], - kl_div=kl_div.cpu().numpy()[reviewer_ix], - episode_start=episode_starts[reviewer_ix], - value=values[reviewer_ix].cpu(), - log_prob=log_probs[reviewer_ix].cpu(), - done=dones[reviewer_ix], - ref_log_prob=ref_log_probs[reviewer_ix].cpu(), - kl_reward=kl_rewards.cpu().numpy()[reviewer_ix], - info=infos[reviewer_ix], + observation=unpacked_obs[instructor_ix], + action=actions[instructor_ix], + task_reward=rewards[instructor_ix], + total_reward=total_rewards[instructor_ix], + kl_div=kl_div.cpu().numpy()[instructor_ix], + episode_start=episode_starts[instructor_ix], + value=values[instructor_ix].cpu(), + log_prob=log_probs[instructor_ix].cpu(), + done=dones[instructor_ix], + ref_log_prob=ref_log_probs[instructor_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[instructor_ix], + info=infos[instructor_ix], ) - episode_wise_transitions[reviewer_ix].append(transtion) + episode_wise_transitions[instructor_ix].append(transtion) # mark this episode to terminated if done occurs once - if dones[reviewer_ix]: - ep_terminated[reviewer_ix] = True + if dones[instructor_ix]: + ep_terminated[instructor_ix] = True - episode_starts = np.zeros((n_reviewers,), dtype=bool) + episode_starts = np.zeros((n_instructors,), dtype=bool) current_obs = new_obs # now we flush all episode wise info to the 1-D buffer diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml index 09375d66c..6d90889e7 100644 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -16,12 +16,9 @@ datapool: prompt_prefix: "Summarize: " -reviewer: +instructor: parl_master_address: "localhost:8811" - ## CHANGE FOR DEBUG ## -# n_reviewers: 10 - n_reviewers: 2 - ## CHANGE FOR DEBUG ## + n_instructors: 10 args: max_prompt_length: 512 max_episode_length: 100 @@ -30,13 +27,9 @@ reviewer: context_start_token: 0 alg: - args: -# n_steps: 512 - #####CHNAGE FOR DEBUG######## - n_steps: 5 - #####CHANGE FOR DEBUG######## + args: + n_steps: 512 batch_size: 32 -# verbose: 1 learning_rate: 0.000002 n_epochs: 5 ent_coef: 0.0 @@ -69,19 +62,8 @@ train_evaluation: - id: bert_score args: language: en - # - id: bleurt - # args: - # config_name: bleurt-large-512 - id: diversity args: {} - # - id: summaCZS - # args: - # granularity: sentence - # use_ent: True - # use_con: False - # - id: summaCConv - # args: - # granularity: sentence generation_kwargs: do_sample: True top_k: 0 diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 9423fcb0f..33f6ad720 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -9,12 +9,12 @@ import torch import time -# reviewer and reward function -from reviewer import ReviewerGroup +# instructor and reward function +from instructor import InstructorGroup # evaluation, metrics, tokenizer & dataset from rl4lms_utils import build_metrics, build_tokenizer, build_datapool -from rl4lms_utils import evaluate_on_samples +from rl4lms_utils import Examiner # rollout from rl4lms_utils import DictRolloutBuffer, RolloutUtil @@ -46,15 +46,15 @@ def main(config): # datapool samples_by_split = build_datapool(config["datapool"]) - reviewer_group = ReviewerGroup(reviewer_config=config["reviewer"], + instructor_group = InstructorGroup(instructor_config=config["instructor"], reward_config=config["reward_fn"], tokenizer=tokenizer, tokenizer_config=config["tokenizer"], datapool_config=config["datapool"],) rl4lms_model = Seq2SeqLMModel( - observation_space = reviewer_group.observation_space, - action_space= reviewer_group.action_space, + observation_space=instructor_group.observation_space, + action_space=instructor_group.action_space, device=device, model_name=config["alg"]["model"]["args"]["model_name"], apply_model_parallel=config["alg"]["model"]["args"]["apply_model_parallel"], @@ -71,9 +71,9 @@ def main(config): agent = RL4LMsAgent(rl4lm_alg, config["alg"]) rollout_buffer = DictRolloutBuffer( - buffer_size=agent.alg.n_steps * reviewer_group.n_reviewers, - observation_space=reviewer_group.observation_space, - action_space=reviewer_group.action_space, + buffer_size=agent.alg.n_steps * instructor_group.n_instructors, + observation_space=instructor_group.observation_space, + action_space=instructor_group.action_space, device=device, gamma=agent.alg.gamma, gae_lambda=agent.alg.gae_lambda, @@ -81,26 +81,27 @@ def main(config): rollout_util = RolloutUtil(config["alg"]["kl_div"]) n_iters = int(config["train_evaluation"]["n_iters"]) - n_steps_per_iter = reviewer_group.n_reviewers * agent.alg.n_steps + n_steps_per_iter = instructor_group.n_instructors * agent.alg.n_steps - max_prompt_length = config["reviewer"]["args"]["max_prompt_length"] + max_prompt_length = config["instructor"]["args"]["max_prompt_length"] # gen kwargs for evaluation eval_gen_kwargs = config["train_evaluation"]["generation_kwargs"] eval_batch_size = config["train_evaluation"]["eval_batch_size"] - eval_splits = ["val", "test"] + examiner = Examiner( + tokenizer=tokenizer, + eval_batch_size=eval_batch_size, + metrics=metrics, + eval_gen_kwargs=eval_gen_kwargs, + samples_by_split=samples_by_split, + max_prompt_length=max_prompt_length + ) iter_start = 0 - for sp in eval_splits: - evaluate_on_samples(policy=agent.alg.model, - tokenizer=tokenizer, - samples=samples_by_split[sp], - batch_size=eval_batch_size, - max_prompt_length=max_prompt_length, - metrics=metrics, - epoch=iter_start, - split_name=sp, - gen_kwargs=eval_gen_kwargs) + examiner.evaluate(policy=agent.alg.model, + sample_name_list=["val", "test"], + epoch=iter_start) + epoch = 0 for epoch in range(iter_start, n_iters): print("========== BEGIN ==========") @@ -111,7 +112,7 @@ def main(config): num_timesteps = 0 while num_timesteps < n_steps_per_iter: - run_timesteps = rollout_util.collect_rollouts(agent, reviewer_group, rollout_buffer, device) + run_timesteps = rollout_util.collect_rollouts(agent, instructor_group, rollout_buffer, device) num_timesteps += run_timesteps agent.learn(rollout_buffer) @@ -124,27 +125,13 @@ def main(config): # evaluate on val set in the given intervals if (epoch + 1) % config["train_evaluation"]["eval_every"] == 0: - evaluate_on_samples(policy=agent.alg.model, - tokenizer=tokenizer, - samples=samples_by_split["val"], - batch_size=eval_batch_size, - max_prompt_length=max_prompt_length, - metrics=metrics, - epoch=epoch, - split_name="val", - gen_kwargs=eval_gen_kwargs) - - - for sp in eval_splits: - evaluate_on_samples(policy=agent.alg.model, - tokenizer=tokenizer, - samples=samples_by_split[sp], - batch_size=eval_batch_size, - max_prompt_length=max_prompt_length, - metrics=metrics, - epoch=epoch, - split_name=sp, - gen_kwargs=eval_gen_kwargs) + examiner.evaluate(policy=agent.alg.model, + sample_name_list=["val"], + epoch=epoch) + + examiner.evaluate(policy=agent.alg.model, + sample_name_list=["val", "test"], + epoch=epoch) if __name__ == '__main__': From 3acf2c3e76905f25ffa3a030e0391a98a4ba0364 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 13 Mar 2023 16:38:45 +0800 Subject: [PATCH 15/34] add requirements.txt --- benchmark/torch/RL4LMs/requirements.txt | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 benchmark/torch/RL4LMs/requirements.txt diff --git a/benchmark/torch/RL4LMs/requirements.txt b/benchmark/torch/RL4LMs/requirements.txt new file mode 100644 index 000000000..f5daa46d1 --- /dev/null +++ b/benchmark/torch/RL4LMs/requirements.txt @@ -0,0 +1,12 @@ +parl==2.1.1 +datasets==2.10.1 +PyYAML==6.0 +torch==1.11.0 +torchvision==0.12.0 +transformers==4.18.0 +charset-normalizer==3.0.1 +gym==0.21.0 +cchardet==2.1.7 +nltk==3.7 +gem-metrics @ git+https://github.com/GEM-benchmark/GEM-metrics.git@431a8174bd6b3637e8d6118bfad2983e39e99733 +bert-score==0.3.11 \ No newline at end of file From 090b19018a938bdd8a9a0281d19a745add039602 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 13 Mar 2023 16:55:35 +0800 Subject: [PATCH 16/34] change code style --- benchmark/torch/RL4LMs/instructor.py | 170 +++-- benchmark/torch/RL4LMs/rl4lm_ppo.py | 97 ++- benchmark/torch/RL4LMs/rl4lms_agent.py | 65 +- benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 51 +- .../rl4lms_utils/component_build_util.py | 26 +- .../torch/RL4LMs/rl4lms_utils/data_pool.py | 35 +- .../torch/RL4LMs/rl4lms_utils/data_wrapper.py | 112 ++-- .../torch/RL4LMs/rl4lms_utils/examiner.py | 73 +-- .../huggingface_generation_util.py | 589 +++++++----------- .../RL4LMs/rl4lms_utils/kl_controller.py | 2 +- .../torch/RL4LMs/rl4lms_utils/metric_util.py | 86 ++- .../torch/RL4LMs/rl4lms_utils/reward_util.py | 27 +- .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 76 +-- benchmark/torch/RL4LMs/seq2seq_model.py | 202 ++---- benchmark/torch/RL4LMs/train.py | 67 +- 15 files changed, 686 insertions(+), 992 deletions(-) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index e19e7ca1c..5f6c2e144 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -9,26 +9,30 @@ import numpy as np from rl4lms_utils import build_datapool, build_tokenizer, build_reward_fn + def _flatten_obs(obs, space, n_instructor=None): if n_instructor is not None: - return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_instructor, -1, len(obs[0][k])))) for k in space.spaces.keys()]) + return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_instructor, -1, len(obs[0][k])))) + for k in space.spaces.keys()]) return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + def dict_to_tensor(obs, device): return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + @parl.remote_class(wait=False) class Instructor: def __init__( - self, - reward_config=None, - tokenizer_config=None, - datapool_config=None, - max_episode_length = 512, - max_prompt_length = None, - terminate_on_eos = False, - context_start_token = None, - prompt_truncation_side = "left", + self, + reward_config=None, + tokenizer_config=None, + datapool_config=None, + max_episode_length=512, + max_prompt_length=None, + terminate_on_eos=False, + context_start_token=None, + prompt_truncation_side="left", ): """ Instructor who gives reward @@ -45,9 +49,7 @@ def __init__( self.tokenizer = tokenizer self.reward_function = reward_function self.max_steps = max_episode_length - self._max_text_length = ( - max_prompt_length if max_prompt_length else tokenizer.model_max_length - ) + self._max_text_length = (max_prompt_length if max_prompt_length else tokenizer.model_max_length) self._terminate_on_eos = terminate_on_eos self._context_start_token = context_start_token self._prompt_truncation_side = prompt_truncation_side @@ -55,31 +57,25 @@ def __init__( # set the observation and action space here self._vocab_size = tokenizer.vocab_size - self.observation_space = DictSpace( - { - # while creating rollout buffers, observations are concatenated for each key - "prompt_or_input_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self._max_text_length,) - ), - "prompt_or_input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length,) - ), - "context_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self.max_steps,) - ), - "context_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self.max_steps,) - ), - "input_encoded_pt": spaces.Box( - low=0, - high=self._vocab_size, - shape=(self._max_text_length + self.max_steps,), - ), - "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length + self.max_steps,) - ), - } - ) + self.observation_space = DictSpace({ + # while creating rollout buffers, observations are concatenated for each key + "prompt_or_input_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(self._max_text_length, )), + "prompt_or_input_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(self._max_text_length, )), + "context_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(self.max_steps, )), + "context_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(self.max_steps, )), + "input_encoded_pt": + spaces.Box( + low=0, + high=self._vocab_size, + shape=(self._max_text_length + self.max_steps, ), + ), + "input_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(self._max_text_length + self.max_steps, )), + }) self.action_space = Discrete(n=self._vocab_size) # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency if 'mt5' in self.tokenizer.name_or_path: @@ -113,18 +109,17 @@ def get_new_obs_and_feedback_one_step(self, action): self.__current_obs = self.__current_obs.update(action, self.tokenizer) # decide if the episode is finished or not - done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or ( - self.__time_step == self.max_steps - ) + done = (action == self.tokenizer.eos_token_id + and self._terminate_on_eos) or (self.__time_step == self.max_steps) # compute reward reward = self.reward_function( - previous_obs, - action, - self.__current_obs, - done, - self.__current_obs.meta_info, - ) + previous_obs, + action, + self.__current_obs, + done, + self.__current_obs.meta_info, + ) # populate additional info info = { @@ -151,7 +146,7 @@ def get_new_obs_and_feedback_sentence(self, sentence): res.append(one_step_res) return res - def ask(self, sample = None): + def ask(self, sample=None): """ Reset the instructor and starts a new episode """ @@ -182,13 +177,14 @@ def get_obs_and_action_space(self): class InstructorGroup: - def __init__(self, - instructor_config=None, - reward_config=None, - tokenizer=None, - datapool_config=None, - tokenizer_config=None, - ): + def __init__( + self, + instructor_config=None, + reward_config=None, + tokenizer=None, + datapool_config=None, + tokenizer_config=None, + ): self.n_instructors = instructor_config["n_instructors"] # remote instructors need to use config to initialize due to serialization problem instructor_kwargs = { @@ -202,40 +198,33 @@ def __init__(self, # due to serialization problem, build obs space and action space here self._vocab_size = tokenizer.vocab_size - self.observation_space = DictSpace( - { - # while creating rollout buffers, observations are concatenated for each key - "prompt_or_input_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(instructor_kwargs["max_prompt_length"],) - ), - "prompt_or_input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(instructor_kwargs["max_prompt_length"],) - ), - "context_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(instructor_kwargs["max_episode_length"],) - ), - "context_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(instructor_kwargs["max_episode_length"],) - ), - "input_encoded_pt": spaces.Box( - low=0, - high=self._vocab_size, - shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"],), - ), - "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"],) - ), - } - ) + self.observation_space = DictSpace({ + # while creating rollout buffers, observations are concatenated for each key + "prompt_or_input_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(instructor_kwargs["max_prompt_length"], )), + "prompt_or_input_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(instructor_kwargs["max_prompt_length"], )), + "context_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(instructor_kwargs["max_episode_length"], )), + "context_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(instructor_kwargs["max_episode_length"], )), + "input_encoded_pt": + spaces.Box( + low=0, + high=self._vocab_size, + shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"], ), + ), + "input_attention_mask_pt": + spaces.Box( + low=0, + high=1, + shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"], )), + }) self.action_space = Discrete(n=self._vocab_size) def ask(self): - future_object_ids = [ - remote_instructor.ask() for remote_instructor in self._remote_instructors - ] - sample_questions = [ - future_object.get() for future_object in future_object_ids - ] + future_object_ids = [remote_instructor.ask() for remote_instructor in self._remote_instructors] + sample_questions = [future_object.get() for future_object in future_object_ids] # sample_questions = future_object_ids return _flatten_obs(sample_questions, self.observation_space) @@ -245,12 +234,11 @@ def feedback_sentense(self, gen_output): return sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos - def _instructors_feedback_sentence(self, all_sentences): all_sentences = torch.stack(all_sentences).cpu().numpy().transpose(1, 0) future_object_ids = [ - self._remote_instructors[i].get_new_obs_and_feedback_sentence( - all_sentences[i]) for i in range(self.n_instructors) + self._remote_instructors[i].get_new_obs_and_feedback_sentence(all_sentences[i]) + for i in range(self.n_instructors) ] feedback_res = np.stack([future_object.get() for future_object in future_object_ids]) @@ -263,7 +251,3 @@ def _instructors_feedback_sentence(self, all_sentences): def _create_instructors(self, instructor_kwargs, parl_port=None): parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"]) return [Instructor(**instructor_kwargs) for _ in range(self.n_instructors)] - - - - diff --git a/benchmark/torch/RL4LMs/rl4lm_ppo.py b/benchmark/torch/RL4LMs/rl4lm_ppo.py index 9a365d222..e5133b073 100644 --- a/benchmark/torch/RL4LMs/rl4lm_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lm_ppo.py @@ -3,27 +3,29 @@ from gym import spaces from torch.nn import functional as F -from parl.algorithms.torch import PPO +from parl.algorithms.torch import PPO + class RL4LMPPO(parl.Algorithm): - def __init__(self, - model, - learning_rate = 3e-4, - n_steps = 2048, - batch_size = 64, - n_epochs = 10, - gamma = 0.99, - gae_lambda = 0.95, - clip_range = 0.2, - normalize_advantage = True, - ent_coef = 0.0, - vf_coef = 0.5, - max_grad_norm = 0.5, - target_kl = None, - seed = None, - device = "auto", - _init_setup_model = True, - ): + def __init__( + self, + model, + learning_rate=3e-4, + n_steps=2048, + batch_size=64, + n_epochs=10, + gamma=0.99, + gae_lambda=0.95, + clip_range=0.2, + normalize_advantage=True, + ent_coef=0.0, + vf_coef=0.5, + max_grad_norm=0.5, + target_kl=None, + seed=None, + device="auto", + _init_setup_model=True, + ): super(RL4LMPPO, self).__init__(model=model) self.learning_rate = learning_rate self.n_steps = n_steps @@ -58,14 +60,12 @@ def learn(self, rollout_buffer, log_info): # Convert discrete action from float to long actions = rollout_data.actions.long().flatten() - values, log_prob, entropy = self.model.evaluate_actions(rollout_data.observations, actions) values = values.flatten() # Normalize advantage advantages = rollout_data.advantages if self.normalize_advantage: - advantages = (advantages - advantages.mean() - ) / (advantages.std() + 1e-8) + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = torch.exp(log_prob - rollout_data.old_log_prob) @@ -78,8 +78,7 @@ def learn(self, rollout_buffer, log_info): # Logging pg_losses.append(policy_loss.item()) - clip_fraction = torch.mean( - (torch.abs(ratio - 1) > self.clip_range).float()).item() + clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip_range).float()).item() clip_fractions.append(clip_fraction) # No clipping @@ -106,8 +105,7 @@ def learn(self, rollout_buffer, log_info): # and Schulman blog: http://joschu.net/blog/kl-approx.html with torch.no_grad(): log_ratio = log_prob - rollout_data.old_log_prob - approx_kl_div = torch.mean( - (torch.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_div = torch.mean((torch.exp(log_ratio) - 1) - log_ratio).cpu().numpy() approx_kl_divs.append(approx_kl_div) if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: @@ -118,8 +116,7 @@ def learn(self, rollout_buffer, log_info): self.model.optimizer.zero_grad() loss.backward() # Clip grad norm - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.max_grad_norm) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) self.model.optimizer.step() return continue_training, loss @@ -131,46 +128,44 @@ def value(self, obs): pass def forward_value( - self, - obs, + self, + obs, ): return self.model.forward_value(obs) def forward_policy( - self, - obs, - actions, + self, + obs, + actions, ): return self.model.forward_policy( - obs = obs, - actions = actions, + obs=obs, + actions=actions, ) - def get_log_probs_ref_model( - self, - obs, - action, + self, + obs, + action, ): return self.model.get_log_probs_ref_model(obs, action) def sample( - self, - tokenizer, - texts = None, - max_prompt_length = None, - input_ids = None, - attention_mask = None, - gen_kwargs = None, + self, + tokenizer, + texts=None, + max_prompt_length=None, + input_ids=None, + attention_mask=None, + gen_kwargs=None, ): return self.model.sample( input_ids=input_ids, attention_mask=attention_mask, tokenizer=tokenizer, - texts = texts, - max_prompt_length = max_prompt_length, - gen_kwargs = gen_kwargs - ) + texts=texts, + max_prompt_length=max_prompt_length, + gen_kwargs=gen_kwargs) def eval_mode(self): - self.model.eval() \ No newline at end of file + self.model.eval() diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index c34f9039e..fb127024c 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -4,6 +4,7 @@ import torch from parl.utils import logger + def explained_variance(y_pred, y_true): """ Computes fraction of variance that ypred explains about y. @@ -20,11 +21,12 @@ def explained_variance(y_pred, y_true): class RL4LMsAgent(parl.Agent): - def __init__(self, - algorithm, - alg_config, - norm_reward = False, - ): + def __init__( + self, + algorithm, + alg_config, + norm_reward=False, + ): super(RL4LMsAgent, self).__init__(algorithm) self.dataset = None self.config = alg_config @@ -49,16 +51,13 @@ def learn(self, rollout_buffer): # train for n_epochs epochs for epoch in range(self.n_epochs): - continue_training, loss = self.alg.learn(rollout_buffer=rollout_buffer, - log_info=log_info) + continue_training, loss = self.alg.learn(rollout_buffer=rollout_buffer, log_info=log_info) if not continue_training: - print( - f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}") + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}") break self._n_updates += self.n_epochs - explained_var = explained_variance( - rollout_buffer.values.flatten(), rollout_buffer.returns.flatten()) + explained_var = explained_variance(rollout_buffer.values.flatten(), rollout_buffer.returns.flatten()) # Logs train_info = { @@ -80,12 +79,12 @@ def learn(self, rollout_buffer): # self._n_updates, exclude="tensorboard") # self.logger.record("train/clip_range", clip_range) train_info["train/n_updates"] = self._n_updates - train_info["train/clip_range"] = self.alg.clip_range + train_info["train/clip_range"] = self.alg.clip_range logger.info(train_info) ppo_train_info = { - "ppo/entropy_loss": np.mean(entropy_losses).item(), + "ppo/entropy_loss": np.mean(entropy_losses).item(), "ppo/policy_gradient_loss": np.mean(pg_losses).item(), "ppo/value_loss": np.mean(value_losses).item(), "ppo/approx_kl": np.mean(approx_kl_divs).item(), @@ -93,47 +92,44 @@ def learn(self, rollout_buffer): logger.info(ppo_train_info) - def get_inputs_for_generation(self, obs_tensor): return self.alg.model.get_inputs_for_generation(obs_tensor) - def predict(self, *args, **kwargs): # only use sample pass def forward_value( - self, - obs, + self, + obs, ): return self.alg.forward_value(obs) def forward_policy( - self, - obs, - actions, + self, + obs, + actions, ): return self.alg.forward_policy( - obs = obs, - actions = actions, + obs=obs, + actions=actions, ) - def get_log_probs_ref_model( - self, - obs, - action, + self, + obs, + action, ): return self.alg.get_log_probs_ref_model(obs, action) def sample( - self, - tokenizer, - texts = None, - max_prompt_length = None, - input_ids = None, - attention_mask = None, - gen_kwargs = None, + self, + tokenizer, + texts=None, + max_prompt_length=None, + input_ids=None, + attention_mask=None, + gen_kwargs=None, ): return self.alg.sample( input_ids=input_ids, @@ -141,8 +137,7 @@ def sample( tokenizer=tokenizer, texts=texts, max_prompt_length=max_prompt_length, - gen_kwargs=gen_kwargs - ) + gen_kwargs=gen_kwargs) def eval_mode(self): self.alg.eval_mode() diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index a033722b9..b96da23a9 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -9,9 +9,8 @@ except ImportError: psutil = None -def get_obs_shape( - observation_space, -): + +def get_obs_shape(observation_space, ): """ Get the shape of the observation (useful for the buffers). @@ -22,13 +21,13 @@ def get_obs_shape( return observation_space.shape elif isinstance(observation_space, spaces.Discrete): # Observation is an int - return (1,) + return (1, ) elif isinstance(observation_space, spaces.MultiDiscrete): # Number of discrete features - return (int(len(observation_space.nvec)),) + return (int(len(observation_space.nvec)), ) elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features - return (int(observation_space.n),) + return (int(observation_space.n), ) elif isinstance(observation_space, spaces.Dict): return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} @@ -61,13 +60,13 @@ class DictRolloutBuffer: """ def __init__( - self, - buffer_size, - observation_space, - action_space, - device = "cpu", - gae_lambda = 1, - gamma = 0.99, + self, + buffer_size, + observation_space, + action_space, + device="cpu", + gae_lambda=1, + gamma=0.99, ): self.buffer_size = buffer_size self.observation_space = observation_space @@ -105,13 +104,15 @@ def reset(self): self.pos = 0 self.full = False - def add(self, + def add( + self, obs, action, reward, episode_start, value, - log_prob,): + log_prob, + ): """ :param obs: Observation :param action: Action @@ -132,7 +133,7 @@ def add(self, # Reshape needed when using multiple instructors with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space.spaces[key], spaces.Discrete): - obs_ = obs_.reshape((1,) + self.obs_shape[key]) + obs_ = obs_.reshape((1, ) + self.obs_shape[key]) self.observations[key][self.pos] = obs_ self.actions[self.pos] = np.array(action).copy() @@ -192,7 +193,7 @@ def swap_and_flatten(self, arr): """ shape = arr.shape if len(shape) < 3: - shape = shape + (1,) + shape = shape + (1, ) return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) def get(self, batch_size): @@ -204,12 +205,10 @@ def get(self, batch_size): for key, obs in self.observations.items(): self.observations[key] = self.swap_and_flatten(obs) - _tensor_names = ["actions", "values", "log_probs", - "advantages", "returns"] + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] for tensor in _tensor_names: - self.__dict__[tensor] = self.swap_and_flatten( - self.__dict__[tensor]) + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True # Return everything, don't create minibatches @@ -218,10 +217,10 @@ def get(self, batch_size): start_idx = 0 while start_idx < self.buffer_size * 1: - yield self._get_samples(indices[start_idx: start_idx + batch_size]) + yield self._get_samples(indices[start_idx:start_idx + batch_size]) start_idx += batch_size - def to_torch(self, array, copy = True): + def to_torch(self, array, copy=True): """ Convert a numpy array to a PyTorch tensor. Note: it copies the data by default @@ -238,11 +237,11 @@ def to_torch(self, array, copy = True): def _get_samples(self, batch_inds): return DictRolloutBufferSamples( - observations={key: self.to_torch(obs[batch_inds]) for ( - key, obs) in self.observations.items()}, + observations={key: self.to_torch(obs[batch_inds]) + for (key, obs) in self.observations.items()}, actions=self.to_torch(self.actions[batch_inds]), old_values=self.to_torch(self.values[batch_inds].flatten()), old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - ) \ No newline at end of file + ) diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py index 5ba3cf537..512244cdf 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py @@ -4,21 +4,18 @@ from .metric_util import MetricRegistry from .data_pool import CNNDailyMail + def build_tokenizer(tokenizer_config): logger.info(f"loading tokenizer of [{tokenizer_config['model_name']}] model") try: - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_config["model_name"]) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["model_name"]) except Exception: logger.info(f"trying to use local_files to load tokenizer of [{tokenizer_config['model_name']}] model") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_config["model_name"], local_files_only=True) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["model_name"], local_files_only=True) if tokenizer.pad_token is None and tokenizer_config.get("pad_token_as_eos_token", True): tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = tokenizer_config.get( - "padding_side", "left") - tokenizer.truncation_side = tokenizer_config.get( - "truncation_side", "left") + tokenizer.padding_side = tokenizer_config.get("padding_side", "left") + tokenizer.truncation_side = tokenizer_config.get("truncation_side", "left") return tokenizer @@ -29,8 +26,9 @@ def build_reward_fn(reward_config): def build_metrics(metric_configs): - metrics = [MetricRegistry.get(metric_config["id"], metric_config.get("args", {})) - for metric_config in metric_configs] + metrics = [ + MetricRegistry.get(metric_config["id"], metric_config.get("args", {})) for metric_config in metric_configs + ] return metrics @@ -47,8 +45,7 @@ def _get_datapool_by_split(split): if remote_train: samples_by_split = { - "train": [(sample, weight) - for sample, weight in train_datapool], + "train": [(sample, weight) for sample, weight in train_datapool], } return samples_by_split @@ -56,11 +53,8 @@ def _get_datapool_by_split(split): test_datapool = _get_datapool_by_split("test") samples_by_split = { - "train": [(sample, weight) - for sample, weight in train_datapool], + "train": [(sample, weight) for sample, weight in train_datapool], "val": [sample for sample, _ in val_datapool], "test": [sample for sample, _ in test_datapool] } return samples_by_split - - diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py index d149d0e42..1720fe5e2 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py @@ -5,9 +5,7 @@ from nltk.tokenize import word_tokenize - class CNNDailyMail: - def __init__(self, samples): self._samples = samples @@ -21,37 +19,26 @@ def __getitem__(self, ix): return sample, 1.0 @classmethod - def prepare(cls, - split, - prompt_suffix = "", - prompt_prefix = "", - truncate_article = None, - max_size = None): - split2name = { - "train": "train", - "val": "validation", - "test": "test" - } + def prepare(cls, split, prompt_suffix="", prompt_prefix="", truncate_article=None, max_size=None): + split2name = {"train": "train", "val": "validation", "test": "test"} dataset = load_dataset("cnn_dailymail", "3.0.0") dataset_split = split2name[split] samples = [] - for ix, item in tqdm(enumerate(dataset[dataset_split]), - desc="Tokenizing dataset", - total=len(dataset[dataset_split])): + for ix, item in tqdm( + enumerate(dataset[dataset_split]), desc="Tokenizing dataset", total=len(dataset[dataset_split])): if truncate_article is not None: tokens = word_tokenize(item["article"]) tokens = tokens[:truncate_article] item["article"] = " ".join(tokens) - sample = Sample(id=f"{split}_{ix}", - prompt_or_input_text=prompt_prefix + - item["article"] + prompt_suffix, - references=[item["highlights"]] - ) + sample = Sample( + id=f"{split}_{ix}", + prompt_or_input_text=prompt_prefix + item["article"] + prompt_suffix, + references=[item["highlights"]]) samples.append(sample) - if max_size is not None and ix == (max_size-1): + if max_size is not None and ix == (max_size - 1): break pool_instance = cls(samples) @@ -67,6 +54,6 @@ def split(self, split_ratios): for ratio in split_ratios: count = int(len(self) * ratio) end_ix = start_ix + count - pools.append(type(self)(self._samples[start_ix: end_ix])) + pools.append(type(self)(self._samples[start_ix:end_ix])) start_ix = end_ix - return pools \ No newline at end of file + return pools diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py index 37235b247..745bb6152 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py @@ -45,7 +45,6 @@ class Sample: meta_data: Dict[str, Any] = None - class PolicyType(Enum): CAUSAL = 0 SEQ2SEQ = 1 @@ -123,13 +122,11 @@ def to_dict(self): "context_attention_mask_pt": self.context_attention_mask_pt.numpy().flatten(), "input_encoded_pt": self.input_encoded_pt.numpy().flatten(), "input_attention_mask_pt": self.input_attention_mask_pt.numpy().flatten() - } return dict_obs @staticmethod - def _concat(prompt: torch.tensor, prompt_mask: torch.tensor, - context: torch.tensor, context_mask: torch.tensor, + def _concat(prompt: torch.tensor, prompt_mask: torch.tensor, context: torch.tensor, context_mask: torch.tensor, pad_token: int): prompt_ = prompt[:, prompt_mask.flatten().bool().tolist()] @@ -137,14 +134,11 @@ def _concat(prompt: torch.tensor, prompt_mask: torch.tensor, actual_size = prompt_.shape[1] + context_.shape[1] full_size = prompt.shape[1] + context.shape[1] - concatenated = torch.full( - (full_size,), fill_value=pad_token).reshape(1, -1) + concatenated = torch.full((full_size, ), fill_value=pad_token).reshape(1, -1) concatenated_mask = torch.zeros((1, full_size)).int() - concatenated[:, full_size - - actual_size:] = torch.cat((prompt_, context_), dim=1) - concatenated_mask[:, full_size - - actual_size:] = 1 + concatenated[:, full_size - actual_size:] = torch.cat((prompt_, context_), dim=1) + concatenated_mask[:, full_size - actual_size:] = 1 return concatenated, concatenated_mask def update(self, action: int, tokenizer: AutoTokenizer): @@ -158,45 +152,35 @@ def update(self, action: int, tokenizer: AutoTokenizer): # get the current context current_context = deepcopy(self.context_encoded_pt) - current_context_attention_mask = deepcopy( - self.context_attention_mask_pt) + current_context_attention_mask = deepcopy(self.context_attention_mask_pt) # just shift the context (also the attention mask) to left by 1 current_context[:, 0:-1] = current_context[:, 1:].clone() - current_context_attention_mask[:, 0:- - 1] = current_context_attention_mask[:, 1:].clone() + current_context_attention_mask[:, 0:-1] = current_context_attention_mask[:, 1:].clone() # add the action always at the end (assumes left padding) current_context[:, -1] = action current_context_attention_mask[:, -1] = 1 # decode the context - context_text = tokenizer.decode( - current_context.flatten(), skip_special_tokens=True) + context_text = tokenizer.decode(current_context.flatten(), skip_special_tokens=True) # concatenate and still keep the left padding input_encoded_pt, input_attention_mask_pt = Observation._concat( - self.prompt_or_input_encoded_pt, self.prompt_or_input_attention_mask_pt, - current_context, current_context_attention_mask, - tokenizer.pad_token_id) + self.prompt_or_input_encoded_pt, self.prompt_or_input_attention_mask_pt, current_context, + current_context_attention_mask, tokenizer.pad_token_id) # and create a new observation - obs = Observation(self.prompt_or_input_encoded_pt, - self.prompt_or_input_attention_mask_pt, - self.prompt_or_input_text, - current_context, - current_context_attention_mask, - context_text, - self.target_or_reference_texts, - input_encoded_pt, - input_attention_mask_pt, - current_action_history, - self.meta_info) + obs = Observation(self.prompt_or_input_encoded_pt, self.prompt_or_input_attention_mask_pt, + self.prompt_or_input_text, current_context, current_context_attention_mask, context_text, + self.target_or_reference_texts, input_encoded_pt, input_attention_mask_pt, + current_action_history, self.meta_info) return obs @classmethod - def init_from_sample(cls, sample: Sample, + def init_from_sample(cls, + sample: Sample, tokenizer: AutoTokenizer, max_input_length: int, max_context_length: int, @@ -207,49 +191,51 @@ def init_from_sample(cls, sample: Sample, # override truncation side for prompt prev_truncation_side = tokenizer.truncation_side tokenizer.truncation_side = prompt_truncation_side - prompt_outputs = tokenizer(sample.prompt_or_input_text, - padding="max_length", - max_length=max_input_length, - return_tensors="pt", - return_attention_mask=True, - truncation=True) + prompt_outputs = tokenizer( + sample.prompt_or_input_text, + padding="max_length", + max_length=max_input_length, + return_tensors="pt", + return_attention_mask=True, + truncation=True) tokenizer.truncation_side = prev_truncation_side # for seq2seq models, context should be initialized to start token if provided if context_start_token is not None: - context_outputs = tokenizer("", - padding="max_length", - max_length=max_context_length, - return_tensors="pt", - return_attention_mask=True) + context_outputs = tokenizer( + "", + padding="max_length", + max_length=max_context_length, + return_tensors="pt", + return_attention_mask=True) context_outputs.input_ids = torch.ones(1, max_context_length, dtype=torch.int32) * tokenizer.pad_token_id context_outputs.input_ids[:, -1] = context_start_token context_outputs.attention_mask = torch.zeros(1, max_context_length, dtype=torch.int32) context_outputs.attention_mask[:, -1] = 1 else: - context_outputs = tokenizer("", - padding="max_length", - max_length=max_context_length, - return_tensors="pt", - return_attention_mask=True) + context_outputs = tokenizer( + "", + padding="max_length", + max_length=max_context_length, + return_tensors="pt", + return_attention_mask=True) # concatenate input_encoded_pt, input_attention_mask_pt = Observation._concat( - prompt_outputs.input_ids, prompt_outputs.attention_mask, - context_outputs.input_ids, context_outputs.attention_mask, - tokenizer.pad_token_id) - - obs = Observation(prompt_or_input_encoded_pt=prompt_outputs.input_ids, - prompt_or_input_attention_mask_pt=prompt_outputs.attention_mask, - prompt_or_input_text=sample.prompt_or_input_text, - context_encoded_pt=context_outputs.input_ids, - context_attention_mask_pt=context_outputs.attention_mask, - input_encoded_pt=input_encoded_pt, - input_attention_mask_pt=input_attention_mask_pt, - context_text="", - target_or_reference_texts=sample.references, - action_history=[], - meta_info=meta_info) + prompt_outputs.input_ids, prompt_outputs.attention_mask, context_outputs.input_ids, + context_outputs.attention_mask, tokenizer.pad_token_id) + + obs = Observation( + prompt_or_input_encoded_pt=prompt_outputs.input_ids, + prompt_or_input_attention_mask_pt=prompt_outputs.attention_mask, + prompt_or_input_text=sample.prompt_or_input_text, + context_encoded_pt=context_outputs.input_ids, + context_attention_mask_pt=context_outputs.attention_mask, + input_encoded_pt=input_encoded_pt, + input_attention_mask_pt=input_attention_mask_pt, + context_text="", + target_or_reference_texts=sample.references, + action_history=[], + meta_info=meta_info) return obs - diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py index 58bd60fdc..c231c18a1 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py @@ -4,14 +4,7 @@ # class for results evaluation class Examiner: - def __init__(self, - tokenizer, - eval_batch_size, - metrics, - eval_gen_kwargs, - samples_by_split, - max_prompt_length - ): + def __init__(self, tokenizer, eval_batch_size, metrics, eval_gen_kwargs, samples_by_split, max_prompt_length): self._tokenizer = tokenizer self._batch_size = eval_batch_size self._metrics = metrics @@ -21,17 +14,15 @@ def __init__(self, def evaluate(self, policy, sample_name_list, epoch): for split_name in sample_name_list: - self._evaluate_on_samples(policy=policy, - epoch=epoch, - split_name=split_name) + self._evaluate_on_samples(policy=policy, epoch=epoch, split_name=split_name) def _evaluate_on_samples( self, policy, epoch, split_name, - dt_control_token = "", - ): + dt_control_token="", + ): samples = self._samples_by_split[split_name] # generate text by batch all_generated_texts = [] @@ -41,9 +32,8 @@ def _evaluate_on_samples( n_samples = len(samples) for batch in tqdm(list(self._get_batch(samples, self._batch_size)), desc="Evaluating"): - batch_generated_texts = self._generate_text( - policy, self._tokenizer, batch, self._max_prompt_length, dt_control_token - ) + batch_generated_texts = self._generate_text(policy, self._tokenizer, batch, self._max_prompt_length, + dt_control_token) batch_ref_texts = [sample.references for sample in batch] batch_prompt_texts = [sample.prompt_or_input_text for sample in batch] batch_meta_infos = [sample.meta_data for sample in batch] @@ -75,28 +65,26 @@ def _evaluate_on_samples( # aggregate sample metric scores sample_predictions_dict = [] for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate( - zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts) - ): + zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts)): sample_prediction = { - "split_name": split_name, - "sample_id": sample.id, - "prompt_text": prompt_text, - "generated_text": generated_text, - "ref_text": "".join( - [ - f"" + ref_text + f"" - for ref_ix, ref_text in enumerate(ref_texts) - ] - ), + "split_name": + split_name, + "sample_id": + sample.id, + "prompt_text": + prompt_text, + "generated_text": + generated_text, + "ref_text": + "".join([ + f"" + ref_text + f"" for ref_ix, ref_text in enumerate(ref_texts) + ]), } for metric_key, sample_scores in sample_scores_by_metric.items(): sample_prediction[metric_key] = sample_scores[ix] sample_predictions_dict.append(sample_prediction) - metrics_dict_ = { - "epoch": epoch, - "metrics": corpus_level_metrics - } + metrics_dict_ = {"epoch": epoch, "metrics": corpus_level_metrics} # logger logger.info(f"{split_name} metrics: {metrics_dict_}") @@ -105,22 +93,19 @@ def _get_batch(self, samples, batch_size): current_ix = 0 n_samples = len(samples) while current_ix < n_samples: - current_batch = samples[current_ix: current_ix + batch_size] + current_batch = samples[current_ix:current_ix + batch_size] yield current_batch current_ix += batch_size def _generate_text( - self, - policy, - tokenizer, - samples, - max_prompt_length, - dt_control_token, + self, + policy, + tokenizer, + samples, + max_prompt_length, + dt_control_token, ): - prompt_texts = [ - dt_control_token + sample.prompt_or_input_text for sample in samples - ] + prompt_texts = [dt_control_token + sample.prompt_or_input_text for sample in samples] generated_texts = policy.sample( - tokenizer, prompt_texts, max_prompt_length, gen_kwargs=self._gen_kwargs - ).gen_texts + tokenizer, prompt_texts, max_prompt_length, gen_kwargs=self._gen_kwargs).gen_texts return generated_texts diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py index 2666846e9..0c8a85825 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py @@ -52,9 +52,9 @@ ) from transformers.utils import ModelOutput, logging - logger = logging.get_logger(__name__) + @dataclass class SampleEncoderDecoderOutput(ModelOutput): """ @@ -98,7 +98,6 @@ class SampleEncoderDecoderOutput(ModelOutput): decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - class GenerationMixinWithRawScores: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. @@ -119,37 +118,31 @@ class GenerationMixinWithRawScores: """ def _prepare_model_inputs( - self, - inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: """ This function extracts the model-specific `inputs` for generation. """ # 1. retrieve all kwargs that are non-None or non-model input related. # some encoder-decoder models have different names for model and encoder - if ( - self.config.is_encoder_decoder - and hasattr(self, "encoder") - and self.encoder.main_input_name != self.main_input_name - ): + if (self.config.is_encoder_decoder and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name): input_name = self.encoder.main_input_name else: input_name = self.main_input_name - model_kwargs = {k: v for k, v in model_kwargs.items( - ) if v is not None or k != input_name} + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} # 2. check whether model_input_name is passed as kwarg # if yes and `inputs` is None use kwarg inputs inputs_kwarg = model_kwargs.pop(input_name, None) if inputs_kwarg is not None and inputs is not None: - raise ValueError( - f"`inputs`: {inputs}` were passed alongside " - f"{input_name} which is not allowed." - f"Make sure to either pass {inputs} or {input_name}=..." - ) + raise ValueError(f"`inputs`: {inputs}` were passed alongside " + f"{input_name} which is not allowed." + f"Make sure to either pass {inputs} or {input_name}=...") elif inputs_kwarg is not None: inputs = inputs_kwarg @@ -159,33 +152,27 @@ def _prepare_model_inputs( # 4. Only encoder-decoder models can have non `input_ids` input format if not self.config.is_encoder_decoder and input_name != "input_ids": - raise ValueError( - f"If {input_name} is passed as model-specific keyword " - "input then model has to be an encoder-decoder and not a " - f"{self.__class__.__name__}." - ) + raise ValueError(f"If {input_name} is passed as model-specific keyword " + "input then model has to be an encoder-decoder and not a " + f"{self.__class__.__name__}.") # 5. if `inputs` is still None, try to create `input_ids` from BOS token if inputs is None: - inputs = self._prepare_input_ids_for_generation( - bos_token_id, model_kwargs.get("encoder_outputs")) + inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) return inputs, input_name, model_kwargs - def _can_retrieve_inputs_from_name( - self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor] - ) -> torch.Tensor: + def _can_retrieve_inputs_from_name(self, inputs: Optional[torch.Tensor], name: str, + model_kwargs: Dict[str, torch.Tensor]) -> torch.Tensor: """ If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved from name """ can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( - inspect.signature(self.forward).parameters.keys() - ) + inspect.signature(self.forward).parameters.keys()) if can_retrieve_inputs and inputs is not None: - raise ValueError( - f"Cannot only pass one of {name} and {self.main_input_name}") + raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}") return can_retrieve_inputs @@ -195,41 +182,37 @@ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) - """ return {"input_ids": input_ids} - def _prepare_input_ids_for_generation( - self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] - ) -> torch.LongTensor: + def _prepare_input_ids_for_generation(self, bos_token_id: Optional[int], + encoder_outputs: Optional[ModelOutput]) -> torch.LongTensor: if self.config.is_encoder_decoder and encoder_outputs is not None: # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding shape = encoder_outputs.last_hidden_state.size()[:-1] return torch.ones(shape, dtype=torch.long, device=self.device) * -100 if bos_token_id is None: - raise ValueError( - "`bos_token_id` has to be defined when no `input_ids` are provided.") + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id def _prepare_attention_mask_for_generation( - self, - inputs: torch.Tensor, - pad_token_id: int, - eos_token_id: int, + self, + inputs: torch.Tensor, + pad_token_id: int, + eos_token_id: int, ) -> torch.LongTensor: - is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ - torch.int, torch.long] - is_pad_token_in_inputs = (pad_token_id is not None) and ( - pad_token_id in inputs) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( - (eos_token_id is not None) and (pad_token_id != eos_token_id) - ) + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] + is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ((eos_token_id is not None) and + (pad_token_id != eos_token_id)) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: return inputs.ne(pad_token_id).long() else: return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device) - def _prepare_encoder_decoder_kwargs_for_generation( - self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None - ) -> Dict[str, Any]: + def _prepare_encoder_decoder_kwargs_for_generation(self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None) -> Dict[str, Any]: # 1. get encoder encoder = self.get_encoder() @@ -237,99 +220,80 @@ def _prepare_encoder_decoder_kwargs_for_generation( irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value - for argument, value in model_kwargs.items() - if not any(argument.startswith(p) for p in irrelevant_prefix) + for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.main_input_name encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor - model_kwargs["encoder_outputs"]: ModelOutput = encoder( - **encoder_kwargs) + model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) return model_kwargs def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - decoder_start_token_id: int = None, - bos_token_id: int = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: if model_kwargs is not None and "decoder_input_ids" in model_kwargs: return model_kwargs.pop("decoder_input_ids") else: - decoder_start_token_id = self._get_decoder_start_token_id( - decoder_start_token_id, bos_token_id) + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: - decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id - ) + decoder_start_token_id = (decoder_start_token_id + if decoder_start_token_id is not None else self.config.decoder_start_token_id) bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "decoder_start_token_id") - and self.config.decoder.decoder_start_token_id is not None - ): + elif (hasattr(self.config, "decoder") and hasattr(self.config.decoder, "decoder_start_token_id") + and self.config.decoder.decoder_start_token_id is not None): return self.config.decoder.decoder_start_token_id elif bos_token_id is not None: return bos_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "bos_token_id") - and self.config.decoder.bos_token_id is not None - ): + elif (hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None): return self.config.decoder.bos_token_id - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." - ) + raise ValueError("`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation.") @staticmethod def _expand_inputs_for_generation( - input_ids: torch.LongTensor, - expand_size: int = 1, - is_encoder_decoder: bool = False, - attention_mask: Optional[torch.LongTensor] = None, - encoder_outputs: Optional[ModelOutput] = None, - **model_kwargs, + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: - expanded_return_idx = ( - torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, - expand_size).view(-1).to(input_ids.device) - ) + expanded_return_idx = (torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to( + input_ids.device)) input_ids = input_ids.index_select(0, expanded_return_idx) if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = token_type_ids.index_select( - 0, expanded_return_idx) + model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) if attention_mask is not None: - model_kwargs["attention_mask"] = attention_mask.index_select( - 0, expanded_return_idx) + model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) if is_encoder_decoder: if encoder_outputs is None: - raise ValueError( - "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( - 0, expanded_return_idx.to( - encoder_outputs.last_hidden_state.device) - ) + 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)) model_kwargs["encoder_outputs"] = encoder_outputs return input_ids, model_kwargs @staticmethod - def _update_model_kwargs_for_generation( - outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False - ) -> Dict[str, Any]: + def _update_model_kwargs_for_generation(outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False) -> Dict[str, Any]: # update past if "past_key_values" in outputs: model_kwargs["past"] = outputs.past_key_values @@ -343,16 +307,14 @@ def _update_model_kwargs_for_generation( # update token_type_ids with last value if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) # update attention mask if not is_encoder_decoder: if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) return model_kwargs @@ -362,12 +324,12 @@ def _reorder_cache(self, past, beam_idx): ) def _get_logits_warper( - self, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - temperature: Optional[float] = None, - num_beams: Optional[int] = None, + self, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + temperature: Optional[float] = None, + num_beams: Optional[int] = None, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances @@ -387,36 +349,33 @@ def _get_logits_warper( if temperature is not None and temperature != 1.0: warpers.append(TemperatureLogitsWarper(temperature)) if top_k is not None and top_k != 0: - warpers.append(TopKLogitsWarper( - top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) if top_p is not None and top_p < 1.0: - warpers.append(TopPLogitsWarper( - top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) if typical_p is not None and typical_p < 1.0: - warpers.append(TypicalLogitsWarper( - mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) return warpers def _get_logits_processor( - self, - repetition_penalty: float, - no_repeat_ngram_size: int, - encoder_no_repeat_ngram_size: int, - input_ids_seq_length: int, - encoder_input_ids: torch.LongTensor, - bad_words_ids: List[List[int]], - min_length: int, - max_length: int, - eos_token_id: int, - forced_bos_token_id: int, - forced_eos_token_id: int, - prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], - num_beams: int, - num_beam_groups: int, - diversity_penalty: float, - remove_invalid_values: bool, - exponential_decay_length_penalty: Tuple, - logits_processor: Optional[LogitsProcessorList], + self, + repetition_penalty: float, + no_repeat_ngram_size: int, + encoder_no_repeat_ngram_size: int, + input_ids_seq_length: int, + encoder_input_ids: torch.LongTensor, + bad_words_ids: List[List[int]], + min_length: int, + max_length: int, + eos_token_id: int, + forced_bos_token_id: int, + forced_eos_token_id: int, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], + num_beams: int, + num_beam_groups: int, + diversity_penalty: float, + remove_invalid_values: bool, + exponential_decay_length_penalty: Tuple, + logits_processor: Optional[LogitsProcessorList], ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] @@ -426,32 +385,23 @@ def _get_logits_processor( # init warp parameters repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - encoder_no_repeat_ngram_size = ( - encoder_no_repeat_ngram_size - if encoder_no_repeat_ngram_size is not None - else self.config.encoder_no_repeat_ngram_size - ) + no_repeat_ngram_size = (no_repeat_ngram_size + if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size) + encoder_no_repeat_ngram_size = (encoder_no_repeat_ngram_size if encoder_no_repeat_ngram_size is not None else + self.config.encoder_no_repeat_ngram_size) bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids min_length = min_length if min_length is not None else self.config.min_length eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty - forced_bos_token_id = ( - forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id - ) - forced_eos_token_id = ( - forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id - ) - remove_invalid_values = ( - remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values - ) - exponential_decay_length_penalty = ( - exponential_decay_length_penalty - if exponential_decay_length_penalty is not None - else self.config.exponential_decay_length_penalty - ) + forced_bos_token_id = (forced_bos_token_id + if forced_bos_token_id is not None else self.config.forced_bos_token_id) + forced_eos_token_id = (forced_eos_token_id + if forced_eos_token_id is not None else self.config.forced_eos_token_id) + remove_invalid_values = (remove_invalid_values + if remove_invalid_values is not None else self.config.remove_invalid_values) + exponential_decay_length_penalty = (exponential_decay_length_penalty + if exponential_decay_length_penalty is not None else + self.config.exponential_decay_length_penalty) # instantiate processors list # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files @@ -459,95 +409,75 @@ def _get_logits_processor( if diversity_penalty is not None and diversity_penalty > 0.0: processors.append( HammingDiversityLogitsProcessor( - diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups - ) - ) + diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups)) if repetition_penalty is not None and repetition_penalty != 1.0: - processors.append(RepetitionPenaltyLogitsProcessor( - penalty=repetition_penalty)) + processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: - processors.append( - NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) + processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: if self.config.is_encoder_decoder: - processors.append(EncoderNoRepeatNGramLogitsProcessor( - encoder_no_repeat_ngram_size, encoder_input_ids)) + processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) else: - raise ValueError( - "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" - ) + raise ValueError("It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture") if bad_words_ids is not None: - processors.append(NoBadWordsLogitsProcessor( - bad_words_ids, eos_token_id)) + processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) if min_length is not None and eos_token_id is not None and min_length > 0: - processors.append(MinLengthLogitsProcessor( - min_length, eos_token_id)) + processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) if prefix_allowed_tokens_fn is not None: - processors.append(PrefixConstrainedLogitsProcessor( - prefix_allowed_tokens_fn, num_beams // num_beam_groups)) + processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) if forced_bos_token_id is not None: - processors.append( - ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) + processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) if forced_eos_token_id is not None: - processors.append(ForcedEOSTokenLogitsProcessor( - max_length, forced_eos_token_id)) + processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) if remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) if exponential_decay_length_penalty is not None: processors.append( - ExponentialDecayLengthPenalty( - exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) - ) - processors = self._merge_criteria_processor_list( - processors, logits_processor) + ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)) + processors = self._merge_criteria_processor_list(processors, logits_processor) return processors - def _get_stopping_criteria( - self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] - ) -> StoppingCriteriaList: + def _get_stopping_criteria(self, max_length: Optional[int], max_time: Optional[float], + stopping_criteria: Optional[StoppingCriteriaList]) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if max_length is not None: criteria.append(MaxLengthCriteria(max_length=max_length)) if max_time is not None: criteria.append(MaxTimeCriteria(max_time=max_time)) - criteria = self._merge_criteria_processor_list( - criteria, stopping_criteria) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria def _merge_criteria_processor_list( - self, - default_list: Union[LogitsProcessorList, StoppingCriteriaList], - custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + self, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], ) -> Union[LogitsProcessorList, StoppingCriteriaList]: if len(custom_list) == 0: return default_list for default in default_list: for custom in custom_list: if type(custom) is type(default): - object_type = "stopping criteria" if isinstance( - custom, StoppingCriteria) else "logits processor" + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" raise ValueError( f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, " f"but it has already been created with the values {default}. {default} has been created by passing the " "corresponding arguments to generate or by the model's config default values. " f"If you just want to change the default values of {object_type} consider passing them as arguments " - f"to `generate` instead of using a custom {object_type}." - ) + f"to `generate` instead of using a custom {object_type}.") default_list.extend(custom_list) return default_list def compute_beam_search_raw_logits( - self, - sequences: torch.Tensor, - scores: Tuple[torch.Tensor], - beam_indices: torch.Tensor, - eos_token_id: int = None, + self, + sequences: torch.Tensor, + scores: Tuple[torch.Tensor], + beam_indices: torch.Tensor, + eos_token_id: int = None, ): """Compute raw logits for beam search""" if not self.config.is_encoder_decoder: - raise NotImplementedError( - "Beam Search raw logits code is implemented only for enoder-decoder only models") + raise NotImplementedError("Beam Search raw logits code is implemented only for enoder-decoder only models") # since sequences can be shorter than scores (probably due to beam search finalization) # we always have to generate raw_logits only for generated sequences @@ -569,57 +499,55 @@ def compute_beam_search_raw_logits( # gen_steps x batch_size x vocab_size beam_indices = beam_indices.unsqueeze(-1).repeat(1, 1, vocab_size) step_wise_logits = scores.gather(dim=1, index=beam_indices) - assert step_wise_logits.shape == torch.Size( - (gen_steps, batch_size, vocab_size)) + assert step_wise_logits.shape == torch.Size((gen_steps, batch_size, vocab_size)) # finally convert to tuples - step_wise_logits = [(step_wise_logits[t], None) - for t in range(gen_steps)] + step_wise_logits = [(step_wise_logits[t], None) for t in range(gen_steps)] return step_wise_logits - @ torch.no_grad() + @torch.no_grad() def generate( - self, - inputs: Optional[torch.Tensor] = None, - max_length: Optional[int] = None, - min_length: Optional[int] = None, - do_sample: Optional[bool] = None, - early_stopping: Optional[bool] = None, - num_beams: Optional[int] = None, - temperature: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - repetition_penalty: Optional[float] = None, - bad_words_ids: Optional[Iterable[int]] = None, - force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, - bos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - no_repeat_ngram_size: Optional[int] = None, - encoder_no_repeat_ngram_size: Optional[int] = None, - num_return_sequences: Optional[int] = None, - max_time: Optional[float] = None, - max_new_tokens: Optional[int] = None, - decoder_start_token_id: Optional[int] = None, - use_cache: Optional[bool] = None, - num_beam_groups: Optional[int] = None, - diversity_penalty: Optional[float] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), - stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), - constraints: Optional[List[Constraint]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, - remove_invalid_values: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, - **model_kwargs, + self, + inputs: Optional[torch.Tensor] = None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + encoder_no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + max_time: Optional[float] = None, + max_new_tokens: Optional[int] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), + constraints: Optional[List[Constraint]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, + **model_kwargs, ): r""" @@ -848,9 +776,8 @@ def generate( early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups do_sample = do_sample if do_sample is not None else self.config.do_sample - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) + num_return_sequences = (num_return_sequences + if num_return_sequences is not None else self.config.num_return_sequences) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id @@ -860,26 +787,22 @@ def generate( if pad_token_id is None and eos_token_id is not None: # special case if pad_token_id is not defined - logger.warning( - f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict_in_generate = (return_dict_in_generate + if return_dict_in_generate is not None else self.config.return_dict_in_generate) # 2. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None # all model-specific keyword inputs are removed from `model_kwargs` - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, bos_token_id, model_kwargs) + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) batch_size = inputs_tensor.shape[0] # 3. Define other model kwargs @@ -887,21 +810,18 @@ def generate( model_kwargs["output_hidden_states"] = output_hidden_states model_kwargs["use_cache"] = use_cache - accepts_attention_mask = "attention_mask" in set( - inspect.signature(self.forward).parameters.keys()) + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, pad_token_id, eos_token_id - ) + inputs_tensor, pad_token_id, eos_token_id) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: # if model is encoder decoder encoder_outputs are created # and added to `model_kwargs` - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, model_kwargs, model_input_name - ) + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs_tensor, model_kwargs, + model_input_name) # 4. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: @@ -941,32 +861,21 @@ def generate( # 6. determine generation mode is_constraint_gen_mode = constraints is not None or force_words_ids is not None - is_greedy_gen_mode = ( - (num_beams == 1) and (num_beam_groups == - 1) and do_sample is False and not is_constraint_gen_mode - ) - is_sample_gen_mode = ( - (num_beams == 1) and (num_beam_groups == - 1) and do_sample is True and not is_constraint_gen_mode - ) - is_beam_gen_mode = ( - (num_beams > 1) and (num_beam_groups == - 1) and do_sample is False and not is_constraint_gen_mode - ) - is_beam_sample_gen_mode = ( - (num_beams > 1) and (num_beam_groups == - 1) and do_sample is True and not is_constraint_gen_mode - ) - is_group_beam_gen_mode = (num_beams > 1) and ( - num_beam_groups > 1) and not is_constraint_gen_mode + is_greedy_gen_mode = ((num_beams == 1) and (num_beam_groups == 1) and do_sample is False + and not is_constraint_gen_mode) + is_sample_gen_mode = ((num_beams == 1) and (num_beam_groups == 1) and do_sample is True + and not is_constraint_gen_mode) + is_beam_gen_mode = ((num_beams > 1) and (num_beam_groups == 1) and do_sample is False + and not is_constraint_gen_mode) + is_beam_sample_gen_mode = ((num_beams > 1) and (num_beam_groups == 1) and do_sample is True + and not is_constraint_gen_mode) + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode if num_beam_groups > num_beams: - raise ValueError( - "`num_beam_groups` has to be smaller or equal to `num_beams`") + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if is_group_beam_gen_mode and do_sample is True: raise ValueError( - "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." - ) + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`.") # 7. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( @@ -992,15 +901,13 @@ def generate( # 8. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( - max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria - ) + max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria) # 9. go into different generation modes if is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( - top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams - ) + top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams) # 11. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -1027,22 +934,21 @@ def generate( else: raise NotImplementedError - def sample( - self, - input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, ): r""" Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and @@ -1152,35 +1058,28 @@ def sample( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict_in_generate = (return_dict_in_generate + if return_dict_in_generate is not None else self.config.return_dict_in_generate) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if ( - return_dict_in_generate and output_hidden_states) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get( - "attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get( - "hidden_states") if output_hidden_states else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = (model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states else None) # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) @@ -1193,8 +1092,7 @@ def sample( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -1202,8 +1100,7 @@ def sample( break # prepare model inputs - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( @@ -1221,29 +1118,22 @@ def sample( next_token_logits = outputs.logits[:, -1, :] # pre-process distribution - next_token_scores = logits_processor( - input_ids, next_token_logits, model_inputs=model_inputs) - next_token_scores = logits_warper( - input_ids, next_token_scores) + next_token_scores = logits_processor(input_ids, next_token_logits, model_inputs=model_inputs) + next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: - scores += ((next_token_logits_raw, next_token_scores),) + scores += ((next_token_logits_raw, next_token_scores), ) if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( - outputs.attentions,) - ) + decoder_attentions += ((outputs.decoder_attentions, ) if self.config.is_encoder_decoder else + (outputs.attentions, )) if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) + cross_attentions += (outputs.cross_attentions, ) if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) + decoder_hidden_states += ((outputs.decoder_hidden_states, ) if self.config.is_encoder_decoder else + (outputs.hidden_states, )) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) @@ -1252,22 +1142,19 @@ def sample( # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + \ pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder) cur_len = cur_len + 1 # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul( - (next_tokens != eos_token_id).long()) + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py index 20b1f7034..b35d33941 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py @@ -3,7 +3,7 @@ class KLController: - def __init__(self, kl_coeff, target_kl = None): + def __init__(self, kl_coeff, target_kl=None): self._kl_coeff = kl_coeff self._target_kl = target_kl diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py index 374ca3a3c..423d5bc34 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py @@ -7,44 +7,41 @@ from parl.utils import logger - class MeteorMetric: def __init__(self): super().__init__() self._metric = load_metric("meteor") def compute( - self, - prompt_texts, - generated_texts, - reference_texts, - meta_infos = None, - model = None, - split_name = None, + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, ): - score = self._metric.compute( - predictions=generated_texts, references=reference_texts - )["meteor"] + score = self._metric.compute(predictions=generated_texts, references=reference_texts)["meteor"] metric_dict = {"lexical/meteor": (None, score)} return metric_dict class RougeMetric: - def __init__(self, use_single_ref = True): + def __init__(self, use_single_ref=True): super().__init__() self._metric = load_metric("rouge") self._use_single_ref = use_single_ref def compute( - self, - prompt_texts, - generated_texts, - reference_texts, - meta_infos = None, - model = None, - split_name = None, + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, ): if self._use_single_ref: # TBD: this is required for CNN/DM dataset, without this we get low scores @@ -53,9 +50,7 @@ def compute( else: ref_texts = reference_texts - metric_results = self._metric.compute( - predictions=generated_texts, references=ref_texts, use_stemmer=True - ) + metric_results = self._metric.compute(predictions=generated_texts, references=ref_texts, use_stemmer=True) score_keys = ["rouge1", "rouge2", "rougeL", "rougeLsum"] metric_dict = {} for rouge_type in score_keys: @@ -73,13 +68,13 @@ def __init__(self, language): self._last_gpu = f"cuda:{torch.cuda.device_count() - 1}" def compute( - self, - prompt_texts, - generated_texts, - reference_texts, - meta_infos = None, - model = None, - split_name = None, + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, ): with torch.no_grad(): metric_results = self._metric.compute( @@ -100,13 +95,13 @@ def __init__(self): self._metric = load_metric("bleu") def compute( - self, - prompt_texts, - generated_texts, - reference_texts, - meta_infos = None, - model = None, - split_name = None, + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, ): tokenized_predictions = [] @@ -119,8 +114,7 @@ def compute( try: metric_results = self._metric.compute( - predictions=tokenized_predictions, references=tokenized_reference_texts - ) + predictions=tokenized_predictions, references=tokenized_reference_texts) bleu_score = metric_results["bleu"] metric_dict = {"lexical/bleu": (None, bleu_score)} return metric_dict @@ -129,18 +123,18 @@ def compute( class DiversityMetrics: - def __init__(self, window_size = 100): + def __init__(self, window_size=100): self._msttr_metric = MSTTR(window_size=window_size) self._n_gram_metric = NGramStats() def compute( - self, - prompt_texts, - generated_texts, - reference_texts, - meta_infos = None, - model = None, - split_name = None, + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, ): predictions = Predictions(data={"filename": "", "values": generated_texts}) diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py index a8d4ca77f..f81bf36b9 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py @@ -1,9 +1,8 @@ from datasets import load_metric + class RougeRewardFunction: - def __init__( - self, rouge_type, use_single_ref = True - ): + def __init__(self, rouge_type, use_single_ref=True): super().__init__() self._metric = load_metric("rouge") self._rouge_type = rouge_type @@ -12,12 +11,12 @@ def __init__( self._use_single_ref = use_single_ref def __call__( - self, - current_observation, - action, - next_observation, - done, - meta_info = None, + self, + current_observation, + action, + next_observation, + done, + meta_info=None, ): if done: # TBD: considers only one reference for now @@ -27,14 +26,10 @@ def __call__( references = [next_observation.target_or_reference_texts] predicted = [next_observation.context_text] - metric_results = self._metric.compute( - predictions=predicted, references=references, use_stemmer=True - ) + metric_results = self._metric.compute(predictions=predicted, references=references, use_stemmer=True) reward = metric_results[self._rouge_type].mid.fmeasure if self._shaping_fn is not None: - aux_score = self._shaping_fn( - current_observation, action, next_observation, done, meta_info - ) + aux_score = self._shaping_fn(current_observation, action, next_observation, done, meta_info) reward = reward + aux_score return reward - return 0 \ No newline at end of file + return 0 diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index fa2d64fab..33ac81dc7 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -10,9 +10,11 @@ def dict_to_tensor(obs, device): return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + def get_one_token_obs(obs, idx, space): return OrderedDict([(k, obs[k][:, idx, :]) for k in space.spaces.keys()]) + def unpack_observations(obs_tensor, n_instructors): """ Unpacks vectorized dict observations into separate dict observations @@ -27,9 +29,7 @@ def unpack_observations(obs_tensor, n_instructors): return unpacked_obs -def add_to_buffer( - rollout_buffer, episode_wise_transitions, rollout_info -): +def add_to_buffer(rollout_buffer, episode_wise_transitions, rollout_info): advantages_computed = False for ep_ix, transitions in enumerate(episode_wise_transitions): ep_length = len(transitions) @@ -40,9 +40,7 @@ def add_to_buffer( total_kl_reward += transition.kl_reward rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) rollout_info["rollout_info/log_prob"].append(transition.log_prob) - rollout_info["rollout_info/ref_log_prob"].append( - transition.ref_log_prob - ) + rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob) rollout_info["rollout_info/values"].append(transition.value.numpy()) if not rollout_buffer.full: @@ -59,15 +57,10 @@ def add_to_buffer( if rollout_buffer.full and not advantages_computed: # we fetch the last value for the last time step # values come from the next transitions's values - next_values = ( - transitions[transition_ix + 1].value - if (transition_ix + 1) < ep_length - else torch.tensor([0.0]) - ) + next_values = (transitions[transition_ix + 1].value if + (transition_ix + 1) < ep_length else torch.tensor([0.0])) - rollout_buffer.compute_returns_and_advantage( - last_values=next_values, dones=transition.done - ) + rollout_buffer.compute_returns_and_advantage(last_values=next_values, dones=transition.done) advantages_computed = True rollout_info["rollout_info/ep_rew"].append(total_reward) @@ -80,13 +73,7 @@ class RolloutUtil: def __init__(self, kl_args): self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) - def collect_rollouts( - self, - agent, - instructor_group, - rollout_buffer, - device - ): + def collect_rollouts(self, agent, instructor_group, rollout_buffer, device): # get tokenizer tokenizer = instructor_group.tokenizer @@ -145,41 +132,34 @@ def collect_rollouts( for key, values in rollout_info.items(): aggregated_rollout_info[key] = np.mean(values).item() aggregated_rollout_info[f"{key}_std"] = np.std(values).item() - aggregated_rollout_info[ - "rollout_info/kl_coeff" - ] = self._kl_controller.kl_coeff + aggregated_rollout_info["rollout_info/kl_coeff"] = self._kl_controller.kl_coeff logger.info(f"Rollout Info: {aggregated_rollout_info}") # adapt the KL coeff - self._kl_controller.step( - torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"]) - ) + self._kl_controller.step(torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"])) return num_timesteps - def _generate_transition_and_add_to_buffer( - self, - gen_sentence=None, - agent=None, - n_instructors=None, - obs_space=None, - rollout_buffer=None, - rollout_info=None, - device=None, - sentence_new_obs=None, - sentence_rewards=None, - sentence_dones=None, - sentence_infos=None, - init_obs=None - ): + def _generate_transition_and_add_to_buffer(self, + gen_sentence=None, + agent=None, + n_instructors=None, + obs_space=None, + rollout_buffer=None, + rollout_info=None, + device=None, + sentence_new_obs=None, + sentence_rewards=None, + sentence_dones=None, + sentence_infos=None, + init_obs=None): current_obs = init_obs review_times = 0 - episode_starts = np.ones((n_instructors,), dtype=bool) + episode_starts = np.ones((n_instructors, ), dtype=bool) # process them one step at a time to collect rollout info episode_wise_transitions = [[] for _ in range(n_instructors)] - ep_terminated = np.zeros((n_instructors,), dtype=bool) - + ep_terminated = np.zeros((n_instructors, ), dtype=bool) for idx, actions_tensor in enumerate(gen_sentence.step_wise_actions): if np.all(ep_terminated): @@ -246,11 +226,9 @@ def _generate_transition_and_add_to_buffer( if dones[instructor_ix]: ep_terminated[instructor_ix] = True - episode_starts = np.zeros((n_instructors,), dtype=bool) + episode_starts = np.zeros((n_instructors, ), dtype=bool) current_obs = new_obs # now we flush all episode wise info to the 1-D buffer - rollout_info = add_to_buffer( - rollout_buffer, episode_wise_transitions, rollout_info - ) + rollout_info = add_to_buffer(rollout_buffer, episode_wise_transitions, rollout_info) return rollout_info, review_times diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index 05f6f4776..22d26c688 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -9,23 +9,23 @@ import parl from rl4lms_utils import ( override_generation_routines, - - GenerationInputs, GenerationOutputs, + GenerationInputs, + GenerationOutputs, ) class Seq2SeqLMModel(parl.Model): def __init__( - self, - observation_space, - action_space, - model_name, - weight_decay = 1e-6, - apply_model_parallel = True, - optimizer_class = torch.optim.AdamW, - generation_kwargs = {}, - prompt_truncation_side = "left", - device = None, + self, + observation_space, + action_space, + model_name, + weight_decay=1e-6, + apply_model_parallel=True, + optimizer_class=torch.optim.AdamW, + generation_kwargs={}, + prompt_truncation_side="left", + device=None, ): super(Seq2SeqLMModel, self).__init__() @@ -43,7 +43,6 @@ def __init__( self._generation_kwargs = generation_kwargs self._prompt_truncation_side = prompt_truncation_side - def _build_model_heads(self, model_name): self._policy_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) self._policy_model.__class__ = override_generation_routines(type(self._policy_model)) @@ -51,9 +50,7 @@ def _build_model_heads(self, model_name): self._value_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) self._ref_model = deepcopy(self._policy_model).eval() - self._value_head = nn.Linear( - self._value_model.config.hidden_size, 1, bias=False - ) + self._value_head = nn.Linear(self._value_model.config.hidden_size, 1, bias=False) # apply model parallel if torch.cuda.is_available(): @@ -66,31 +63,23 @@ def _build_model_heads(self, model_name): self._policy_model = torch.nn.DataParallel(self._policy_model) self._ref_model = torch.nn.DataParallel(self._ref_model) self._value_model = torch.nn.DataParallel(self._value_model) - self._value_head = torch.nn.DataParallel( - self._value_head.to(self.device) - ) + self._value_head = torch.nn.DataParallel(self._value_head.to(self.device)) def forward_policy( - self, - obs, - actions, + self, + obs, + actions, ): # 1. prepare model inputs past_model_kwargs = { "attention_mask": obs["prompt_or_input_attention_mask_pt"], } - inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( - self._policy_model - )._prepare_model_inputs( - obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs - ) + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model(self._policy_model)._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs) # 2. prepare encoder outputs - past_model_kwargs = unwrap_model( - self._policy_model - )._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, past_model_kwargs, model_input_name - ) + past_model_kwargs = unwrap_model(self._policy_model)._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name) # 3. Prepare input_ids for auto-regressive generation input_ids = obs["context_encoded_pt"].int() @@ -99,14 +88,10 @@ def forward_policy( # all set to get into auto-regressive mode # prepare all of the model inputs for the decoder batch_size = input_ids.shape[0] - model_inputs = unwrap_model(self._policy_model).prepare_inputs_for_generation( - input_ids, **past_model_kwargs - ) + model_inputs = unwrap_model(self._policy_model).prepare_inputs_for_generation(input_ids, **past_model_kwargs) # and forward pass to get next token logits - outputs = self._policy_model( - **model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True - ) + outputs = self._policy_model(**model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True) next_token_logits = outputs.logits[:, -1, :] # get log probs @@ -115,14 +100,10 @@ def forward_policy( entropy = dist.entropy() # update the model kwargs for further generation - past_model_kwargs = unwrap_model( - self._policy_model - )._update_model_kwargs_for_generation( + past_model_kwargs = unwrap_model(self._policy_model)._update_model_kwargs_for_generation( outputs, past_model_kwargs, - is_encoder_decoder=unwrap_model( - self._policy_model - ).config.is_encoder_decoder, + is_encoder_decoder=unwrap_model(self._policy_model).config.is_encoder_decoder, ) past_model_kwargs["decoder_attention_mask"] = torch.cat( (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), @@ -132,25 +113,19 @@ def forward_policy( return actions, log_prob, entropy, past_model_kwargs def forward_value( - self, - obs, + self, + obs, ): # 1. prepare model inputs past_model_kwargs = { "attention_mask": obs["prompt_or_input_attention_mask_pt"], } - inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( - self._value_model - )._prepare_model_inputs( - obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs - ) + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model(self._value_model)._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs) # 2. prepare encoder outputs - past_model_kwargs = unwrap_model( - self._value_model - )._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, past_model_kwargs, model_input_name - ) + past_model_kwargs = unwrap_model(self._value_model)._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name) # 3. Prepare input_ids for auto-regressive generation input_ids = obs["context_encoded_pt"].int() @@ -159,31 +134,21 @@ def forward_value( # all set to get into auto-regressive mode # prepare all of the model inputs for the decoder batch_size = input_ids.shape[0] - model_inputs = unwrap_model(self._value_model).prepare_inputs_for_generation( - input_ids, **past_model_kwargs - ) + model_inputs = unwrap_model(self._value_model).prepare_inputs_for_generation(input_ids, **past_model_kwargs) # and forrward pass to get hidden states outputs = self._value_model( - **model_inputs, - output_hidden_states=True, - decoder_attention_mask=decoder_attn_mask, - return_dict=True - ) + **model_inputs, output_hidden_states=True, decoder_attention_mask=decoder_attn_mask, return_dict=True) # get decoder's last hidden state last_tokens_hidden = outputs.decoder_hidden_states[-1][:, -1, :].to(self.device) values = self._value_head.forward(last_tokens_hidden) # update the model kwargs for further generation - past_model_kwargs = unwrap_model( - self._value_model - )._update_model_kwargs_for_generation( + past_model_kwargs = unwrap_model(self._value_model)._update_model_kwargs_for_generation( outputs, past_model_kwargs, - is_encoder_decoder=unwrap_model( - self._value_model - ).config.is_encoder_decoder, + is_encoder_decoder=unwrap_model(self._value_model).config.is_encoder_decoder, ) past_model_kwargs["decoder_attention_mask"] = torch.cat( (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), @@ -191,9 +156,7 @@ def forward_value( ) return values, past_model_kwargs - def evaluate_actions( - self, obs, actions - ): + def evaluate_actions(self, obs, actions): _, log_prob, entropy, _ = self.forward_policy(obs=obs, actions=actions) values, _ = self.forward_value(obs) @@ -207,26 +170,20 @@ def to(self, device): return super().to(device) def get_log_probs_ref_model( - self, - obs, - action, + self, + obs, + action, ): # 1. prepare model inputs past_model_kwargs = { "attention_mask": obs["prompt_or_input_attention_mask_pt"], } - inputs_tensor, model_input_name, past_model_kwargs = unwrap_model( - self._ref_model - )._prepare_model_inputs( - obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs - ) + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model(self._ref_model)._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs) # 2. prepare encoder outputs - past_model_kwargs = unwrap_model( - self._ref_model - )._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, past_model_kwargs, model_input_name - ) + past_model_kwargs = unwrap_model(self._ref_model)._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name) # 3. Prepare input_ids for auto-regressive generation input_ids = obs["context_encoded_pt"].int() @@ -235,14 +192,10 @@ def get_log_probs_ref_model( # all set to get into auto-regressive mode # prepare all of the model inputs for the decoder batch_size = input_ids.shape[0] - model_inputs = unwrap_model(self._ref_model).prepare_inputs_for_generation( - input_ids, **past_model_kwargs - ) + model_inputs = unwrap_model(self._ref_model).prepare_inputs_for_generation(input_ids, **past_model_kwargs) # and forward pass to get next token logits - outputs = self._ref_model( - **model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True - ) + outputs = self._ref_model(**model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True) next_token_logits = outputs.logits[:, -1, :] # get log probs @@ -250,9 +203,7 @@ def get_log_probs_ref_model( log_prob = dist.log_prob(action) # update the model kwargs for further generation - past_model_kwargs = unwrap_model( - self._ref_model - )._update_model_kwargs_for_generation( + past_model_kwargs = unwrap_model(self._ref_model)._update_model_kwargs_for_generation( outputs, past_model_kwargs, is_encoder_decoder=unwrap_model(self._ref_model).config.is_encoder_decoder, @@ -264,30 +215,25 @@ def get_log_probs_ref_model( return log_prob, past_model_kwargs def get_policy_first_device(self): - return ( - self._policy_model.get_encoder().first_device - if self._apply_model_parallel - else self.device - ) + return (self._policy_model.get_encoder().first_device if self._apply_model_parallel else self.device) def get_inputs_for_generation(self, obs): - generation_inputs = GenerationInputs( - obs["prompt_or_input_encoded_pt"], obs["prompt_or_input_attention_mask_pt"] - ) + generation_inputs = GenerationInputs(obs["prompt_or_input_encoded_pt"], + obs["prompt_or_input_attention_mask_pt"]) return generation_inputs def get_language_model(self): return unwrap_model(self._policy_model) def sample( - self, - tokenizer, - texts = None, - max_prompt_length = None, - input_ids = None, - attention_mask = None, - gen_kwargs = None, + self, + tokenizer, + texts=None, + max_prompt_length=None, + input_ids=None, + attention_mask=None, + gen_kwargs=None, ): # if it different from rollout gen kwargs @@ -297,12 +243,7 @@ def sample( # switch to eval self._policy_model.eval() - if ( - input_ids is None - and attention_mask is None - and texts is not None - and max_prompt_length is not None - ): + if (input_ids is None and attention_mask is None and texts is not None and max_prompt_length is not None): # override truncation side for prompt prev_truncation_side = tokenizer.truncation_side tokenizer.truncation_side = self._prompt_truncation_side @@ -320,13 +261,9 @@ def sample( # if min_length argument is set and if policy is not a seq2seq LM (ie. causal LM) # then it has to be adjusted to input_size + min_length - if "min_length" in gen_kwargs.keys() and not self.is_encoder_decoder( - self._policy_model - ): + if "min_length" in gen_kwargs.keys() and not self.is_encoder_decoder(self._policy_model): generation_kwargs_ = deepcopy(gen_kwargs) - generation_kwargs_["min_length"] = ( - input_ids.shape[1] + gen_kwargs["min_length"] - ) + generation_kwargs_["min_length"] = (input_ids.shape[1] + gen_kwargs["min_length"]) else: generation_kwargs_ = gen_kwargs @@ -346,10 +283,7 @@ def sample( gen_tokens = gen_output["sequences"][:, -seq_length:] # to texts - gen_texts = [ - tokenizer.decode(output, skip_special_tokens=True) - for output in gen_tokens.tolist() - ] + gen_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in gen_tokens.tolist()] # extract scores (logits) step_wise_logprobs = [] @@ -362,19 +296,15 @@ def sample( step_wise_logprobs.append(log_probs) step_wise_actions.append(actions_at_step) - gen_output = GenerationOutputs( - step_wise_logprobs, step_wise_actions, gen_tokens, gen_texts - ) + gen_output = GenerationOutputs(step_wise_logprobs, step_wise_actions, gen_tokens, gen_texts) return gen_output - def is_encoder_decoder(self, model): return unwrap_model(model).config.is_encoder_decoder def set_training_mode(self, mode): self.train(mode) - def _get_constructor_parameters(self): return dict( observation_space=self.observation_space, @@ -389,11 +319,10 @@ def save(self, path): """ torch.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) - def _setup_optimizer( - self, - weight_decay, - optimizer_class, + self, + weight_decay, + optimizer_class, ): params = list(self.named_parameters()) @@ -409,6 +338,3 @@ def _setup_optimizer( }, ] self.optimizer = optimizer_class(optimizer_grouped_parameters) - - - diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 33f6ad720..5f9ac6ceb 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -35,8 +35,7 @@ def recursive_dict_update(d, u): def main(config): - device = torch.device("cuda" if torch.cuda. - is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = build_tokenizer(config["tokenizer"]) @@ -46,11 +45,13 @@ def main(config): # datapool samples_by_split = build_datapool(config["datapool"]) - instructor_group = InstructorGroup(instructor_config=config["instructor"], - reward_config=config["reward_fn"], - tokenizer=tokenizer, - tokenizer_config=config["tokenizer"], - datapool_config=config["datapool"],) + instructor_group = InstructorGroup( + instructor_config=config["instructor"], + reward_config=config["reward_fn"], + tokenizer=tokenizer, + tokenizer_config=config["tokenizer"], + datapool_config=config["datapool"], + ) rl4lms_model = Seq2SeqLMModel( observation_space=instructor_group.observation_space, @@ -59,15 +60,15 @@ def main(config): model_name=config["alg"]["model"]["args"]["model_name"], apply_model_parallel=config["alg"]["model"]["args"]["apply_model_parallel"], prompt_truncation_side=config["alg"]["model"]["args"]["prompt_truncation_side"], - generation_kwargs=config["alg"]["model"]["args"]["generation_kwargs"] - ) - rl4lm_alg = RL4LMPPO(model=rl4lms_model, - device=device, - n_steps=config["alg"]["args"]["n_steps"], - batch_size=config["alg"]["args"]["batch_size"], - learning_rate=config["alg"]["args"]["learning_rate"], - n_epochs=config["alg"]["args"]["n_epochs"], - ent_coef=config["alg"]["args"]["ent_coef"]) + generation_kwargs=config["alg"]["model"]["args"]["generation_kwargs"]) + rl4lm_alg = RL4LMPPO( + model=rl4lms_model, + device=device, + n_steps=config["alg"]["args"]["n_steps"], + batch_size=config["alg"]["args"]["batch_size"], + learning_rate=config["alg"]["args"]["learning_rate"], + n_epochs=config["alg"]["args"]["n_epochs"], + ent_coef=config["alg"]["args"]["ent_coef"]) agent = RL4LMsAgent(rl4lm_alg, config["alg"]) rollout_buffer = DictRolloutBuffer( @@ -89,18 +90,15 @@ def main(config): eval_gen_kwargs = config["train_evaluation"]["generation_kwargs"] eval_batch_size = config["train_evaluation"]["eval_batch_size"] examiner = Examiner( - tokenizer=tokenizer, - eval_batch_size=eval_batch_size, - metrics=metrics, - eval_gen_kwargs=eval_gen_kwargs, - samples_by_split=samples_by_split, - max_prompt_length=max_prompt_length - ) + tokenizer=tokenizer, + eval_batch_size=eval_batch_size, + metrics=metrics, + eval_gen_kwargs=eval_gen_kwargs, + samples_by_split=samples_by_split, + max_prompt_length=max_prompt_length) iter_start = 0 - examiner.evaluate(policy=agent.alg.model, - sample_name_list=["val", "test"], - epoch=iter_start) + examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=iter_start) epoch = 0 for epoch in range(iter_start, n_iters): @@ -125,21 +123,15 @@ def main(config): # evaluate on val set in the given intervals if (epoch + 1) % config["train_evaluation"]["eval_every"] == 0: - examiner.evaluate(policy=agent.alg.model, - sample_name_list=["val"], - epoch=epoch) + examiner.evaluate(policy=agent.alg.model, sample_name_list=["val"], epoch=epoch) - examiner.evaluate(policy=agent.alg.model, - sample_name_list=["val", "test"], - epoch=epoch) + examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=epoch) if __name__ == '__main__': parser = ArgumentParser(description="Fine-tune LM to generate controlled text") parser.add_argument("--config_path", type=str, help="path to the config file") - parser.add_argument( - "--project_name", type=str, help="project name", default="rl4lm_exps" - ) + parser.add_argument("--project_name", type=str, help="project name", default="rl4lm_exps") parser.add_argument( "--experiment_name", type=str, @@ -152,9 +144,7 @@ def main(config): help="Base path to store experiment results", default=os.getcwd(), ) - parser.add_argument( - "--entity_name", type=str, help="entity name", default="summarization" - ) + parser.add_argument("--entity_name", type=str, help="entity name", default="summarization") args = parser.parse_args() # load the config file @@ -170,4 +160,3 @@ def main(config): logger.set_level("DEBUG") main(config) - From b66f07eabccfac9f838a62bbe5803fec6021947a Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 13 Mar 2023 19:01:14 +0800 Subject: [PATCH 17/34] change train.py style --- benchmark/torch/RL4LMs/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 5f9ac6ceb..4f30b8eee 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -18,7 +18,6 @@ # rollout from rl4lms_utils import DictRolloutBuffer, RolloutUtil - # agent, algorithm and model from rl4lm_ppo import RL4LMPPO from rl4lms_agent import RL4LMsAgent From 337ac75fa2be8f1cbfa84c81c8e43bc1ac96ba76 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 13 Mar 2023 21:02:16 +0800 Subject: [PATCH 18/34] change style --- benchmark/torch/RL4LMs/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 4f30b8eee..86f18d8da 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -2,6 +2,7 @@ import sys from argparse import ArgumentParser import datetime + import yaml import collections from parl.utils import logger From d0ced44ec49af722099c7aacd9de7a1b2fbb51dc Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 13 Mar 2023 21:11:15 +0800 Subject: [PATCH 19/34] change style --- benchmark/torch/RL4LMs/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 86f18d8da..4f30b8eee 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -2,7 +2,6 @@ import sys from argparse import ArgumentParser import datetime - import yaml import collections from parl.utils import logger From 151fcea2d1a8728d6a9a88550b41b7eb3287da48 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Tue, 14 Mar 2023 10:08:00 +0800 Subject: [PATCH 20/34] change code style(add copyright) --- benchmark/torch/RL4LMs/README.md | 2 +- benchmark/torch/RL4LMs/instructor.py | 14 ++++++++++++++ benchmark/torch/RL4LMs/requirements.txt | 2 +- benchmark/torch/RL4LMs/rl4lm_ppo.py | 16 ++++++++++++++-- benchmark/torch/RL4LMs/rl4lms_agent.py | 14 ++++++++++++++ benchmark/torch/RL4LMs/rl4lms_utils/__init__.py | 14 ++++++++++++++ benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 14 ++++++++++++++ .../RL4LMs/rl4lms_utils/component_build_util.py | 14 ++++++++++++++ benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py | 14 ++++++++++++++ .../torch/RL4LMs/rl4lms_utils/data_wrapper.py | 14 ++++++++++++++ benchmark/torch/RL4LMs/rl4lms_utils/examiner.py | 14 ++++++++++++++ .../rl4lms_utils/huggingface_generation_util.py | 14 ++++++++++++++ .../torch/RL4LMs/rl4lms_utils/kl_controller.py | 15 ++++++++++++++- .../torch/RL4LMs/rl4lms_utils/metric_util.py | 14 ++++++++++++++ .../torch/RL4LMs/rl4lms_utils/reward_util.py | 14 ++++++++++++++ .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 14 ++++++++++++++ benchmark/torch/RL4LMs/seq2seq_model.py | 14 ++++++++++++++ benchmark/torch/RL4LMs/t5_ppo.yml | 1 - benchmark/torch/RL4LMs/train.py | 16 +++++++++++++++- 19 files changed, 227 insertions(+), 7 deletions(-) diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md index 8e76ee439..da133cfcf 100644 --- a/benchmark/torch/RL4LMs/README.md +++ b/benchmark/torch/RL4LMs/README.md @@ -21,4 +21,4 @@ xparl start --port 8811 --cpu_num 10 # start training python train.py --config_path t5_ppo.yml -``` \ No newline at end of file +``` diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index 5f6c2e144..37bcd4ad6 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import OrderedDict import torch from rl4lms_utils import Observation diff --git a/benchmark/torch/RL4LMs/requirements.txt b/benchmark/torch/RL4LMs/requirements.txt index f5daa46d1..e7f120809 100644 --- a/benchmark/torch/RL4LMs/requirements.txt +++ b/benchmark/torch/RL4LMs/requirements.txt @@ -9,4 +9,4 @@ gym==0.21.0 cchardet==2.1.7 nltk==3.7 gem-metrics @ git+https://github.com/GEM-benchmark/GEM-metrics.git@431a8174bd6b3637e8d6118bfad2983e39e99733 -bert-score==0.3.11 \ No newline at end of file +bert-score==0.3.11 diff --git a/benchmark/torch/RL4LMs/rl4lm_ppo.py b/benchmark/torch/RL4LMs/rl4lm_ppo.py index e5133b073..a0b975ca4 100644 --- a/benchmark/torch/RL4LMs/rl4lm_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lm_ppo.py @@ -1,10 +1,22 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import parl import torch from gym import spaces from torch.nn import functional as F -from parl.algorithms.torch import PPO - class RL4LMPPO(parl.Algorithm): def __init__( diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index fb127024c..fff8316eb 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import parl import numpy as np diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py b/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py index e884625f0..8e721573b 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .data_wrapper import RefPolicyOutput, GenerationInputs, GenerationOutputs,\ PolicyType, Sample, Observation, TransitionInfo diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index b96da23a9..e483a8886 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import torch from gym import spaces diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py index 512244cdf..695e0e550 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from transformers import AutoTokenizer from parl.utils import logger from .reward_util import RougeRewardFunction diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py index 1720fe5e2..7212bd84d 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from datasets import load_dataset from .data_wrapper import Sample import random diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py index 745bb6152..7c1003513 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from enum import Enum from typing import Dict, List diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py index c231c18a1..fd33ff592 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from tqdm import tqdm from parl.utils import logger diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py index 0c8a85825..68923e01b 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # coding=utf-8 # Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py index b35d33941..c359c0e0e 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py @@ -1,4 +1,17 @@ -from typing import Optional, Dict, Any +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py index 423d5bc34..684d53429 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import numpy as np from datasets import load_metric diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py index f81bf36b9..163a2afa4 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from datasets import load_metric diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index 33ac81dc7..86e23e3d5 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import numpy as np diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index 22d26c688..6af6e191d 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from torch import nn from transformers import AutoModelForSeq2SeqLM diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml index 6d90889e7..2ffd592e3 100644 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -70,4 +70,3 @@ train_evaluation: temperature: 0.7 min_length: 50 max_new_tokens: 100 - diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 4f30b8eee..2e211c7bc 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import sys from argparse import ArgumentParser @@ -5,7 +19,6 @@ import yaml import collections from parl.utils import logger - import torch import time @@ -18,6 +31,7 @@ # rollout from rl4lms_utils import DictRolloutBuffer, RolloutUtil + # agent, algorithm and model from rl4lm_ppo import RL4LMPPO from rl4lms_agent import RL4LMsAgent From f91d2c957a2aa99ebd28b2a5a7f9857a15ab7410 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Tue, 14 Mar 2023 21:47:16 +0800 Subject: [PATCH 21/34] bring for-batch-rollout loop out of rl4lms_ppo --- benchmark/torch/RL4LMs/instructor.py | 1 - benchmark/torch/RL4LMs/rl4lm_ppo.py | 183 ------------------------- benchmark/torch/RL4LMs/rl4lms_agent.py | 47 +++++-- benchmark/torch/RL4LMs/rl4lms_ppo.py | 183 +++++++++++++++++++++++++ benchmark/torch/RL4LMs/train.py | 9 +- 5 files changed, 222 insertions(+), 201 deletions(-) delete mode 100644 benchmark/torch/RL4LMs/rl4lm_ppo.py create mode 100644 benchmark/torch/RL4LMs/rl4lms_ppo.py diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index 37bcd4ad6..a59c677f3 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -67,7 +67,6 @@ def __init__( self._terminate_on_eos = terminate_on_eos self._context_start_token = context_start_token self._prompt_truncation_side = prompt_truncation_side - super().__init__() # set the observation and action space here self._vocab_size = tokenizer.vocab_size diff --git a/benchmark/torch/RL4LMs/rl4lm_ppo.py b/benchmark/torch/RL4LMs/rl4lm_ppo.py deleted file mode 100644 index a0b975ca4..000000000 --- a/benchmark/torch/RL4LMs/rl4lm_ppo.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import parl -import torch -from gym import spaces -from torch.nn import functional as F - - -class RL4LMPPO(parl.Algorithm): - def __init__( - self, - model, - learning_rate=3e-4, - n_steps=2048, - batch_size=64, - n_epochs=10, - gamma=0.99, - gae_lambda=0.95, - clip_range=0.2, - normalize_advantage=True, - ent_coef=0.0, - vf_coef=0.5, - max_grad_norm=0.5, - target_kl=None, - seed=None, - device="auto", - _init_setup_model=True, - ): - super(RL4LMPPO, self).__init__(model=model) - self.learning_rate = learning_rate - self.n_steps = n_steps - self.batch_size = batch_size - self.n_epochs = n_epochs - self.gamma = gamma - self.gae_lambda = gae_lambda - self.clip_range = clip_range - self.normalize_advantage = normalize_advantage - self.ent_coef = ent_coef - self.vf_coef = vf_coef - self.max_grad_norm = max_grad_norm - self.target_kl = target_kl - self.seed = seed - self.device = device - for param_group in self.model.optimizer.param_groups: - param_group["lr"] = self.learning_rate - - def learn(self, rollout_buffer, log_info): - entropy_losses = log_info["entropy_losses"] - pg_losses = log_info["pg_losses"] - value_losses = log_info["value_losses"] - clip_fractions = log_info["clip_fractions"] - approx_kl_divs = log_info["approx_kl_divs"] - continue_training = True - # Do a complete pass on the rollout buffer - for batch_ix, rollout_data in enumerate(list(rollout_buffer.get(self.batch_size))): - # self.verify_rollout_data(rollout_data) - - actions = rollout_data.actions - if isinstance(self.model.action_space, spaces.Discrete): - # Convert discrete action from float to long - actions = rollout_data.actions.long().flatten() - - values, log_prob, entropy = self.model.evaluate_actions(rollout_data.observations, actions) - values = values.flatten() - # Normalize advantage - advantages = rollout_data.advantages - if self.normalize_advantage: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) - - # ratio between old and new policy, should be one at the first iteration - ratio = torch.exp(log_prob - rollout_data.old_log_prob) - - # clipped surrogate loss - policy_loss_1 = advantages * ratio - policy_loss_2 = advantages * \ - torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range) - policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean() - - # Logging - pg_losses.append(policy_loss.item()) - clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip_range).float()).item() - clip_fractions.append(clip_fraction) - - # No clipping - values_pred = values - - # Value loss using the TD(gae_lambda) target - value_loss = F.mse_loss(rollout_data.returns, values_pred) - value_losses.append(value_loss.item()) - - # Entropy loss favor exploration - if entropy is None: - # Approximate entropy when no analytical form - entropy_loss = -torch.mean(-log_prob) - else: - entropy_loss = -torch.mean(entropy) - - entropy_losses.append(entropy_loss.item()) - - loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss - - # Calculate approximate form of reverse KL Divergence for early stopping - # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 - # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 - # and Schulman blog: http://joschu.net/blog/kl-approx.html - with torch.no_grad(): - log_ratio = log_prob - rollout_data.old_log_prob - approx_kl_div = torch.mean((torch.exp(log_ratio) - 1) - log_ratio).cpu().numpy() - approx_kl_divs.append(approx_kl_div) - - if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: - continue_training = False - break - - # Optimization step - self.model.optimizer.zero_grad() - loss.backward() - # Clip grad norm - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) - self.model.optimizer.step() - - return continue_training, loss - - def predict(self, obs): - pass - - def value(self, obs): - pass - - def forward_value( - self, - obs, - ): - return self.model.forward_value(obs) - - def forward_policy( - self, - obs, - actions, - ): - return self.model.forward_policy( - obs=obs, - actions=actions, - ) - - def get_log_probs_ref_model( - self, - obs, - action, - ): - return self.model.get_log_probs_ref_model(obs, action) - - def sample( - self, - tokenizer, - texts=None, - max_prompt_length=None, - input_ids=None, - attention_mask=None, - gen_kwargs=None, - ): - return self.model.sample( - input_ids=input_ids, - attention_mask=attention_mask, - tokenizer=tokenizer, - texts=texts, - max_prompt_length=max_prompt_length, - gen_kwargs=gen_kwargs) - - def eval_mode(self): - self.model.eval() diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index fff8316eb..6f9bb94d6 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -14,7 +14,7 @@ import parl import numpy as np - +from gym import spaces import torch from parl.utils import logger @@ -38,13 +38,14 @@ class RL4LMsAgent(parl.Agent): def __init__( self, algorithm, - alg_config, + n_epochs, + batch_size=64, norm_reward=False, ): super(RL4LMsAgent, self).__init__(algorithm) self.dataset = None - self.config = alg_config - self.n_epochs = alg_config["args"]["n_epochs"] + self.n_epochs = n_epochs + self.batch_size = batch_size self._norm_reward = norm_reward self._n_updates = 0 @@ -53,24 +54,44 @@ def learn(self, rollout_buffer): pg_losses, value_losses = [], [] clip_fractions = [] approx_kl_divs = [] - log_info = { - "entropy_losses": entropy_losses, - "pg_losses": pg_losses, - "value_losses": value_losses, - "clip_fractions": clip_fractions, - "approx_kl_divs": approx_kl_divs - } loss = torch.tensor(0.0) # train for n_epochs epochs for epoch in range(self.n_epochs): - continue_training, loss = self.alg.learn(rollout_buffer=rollout_buffer, log_info=log_info) + continue_training = True + + for batch_ix, rollout_data in enumerate(list(rollout_buffer.get(self.batch_size))): + batch_action = rollout_data.actions + if isinstance(self.alg.model.action_space, spaces.Discrete): + # Convert discrete action from float to long + batch_action = rollout_data.actions.long().flatten() + batch_obs = rollout_data.observations + batch_adv = rollout_data.advantages + batch_logprob = rollout_data.old_log_prob + batch_return = rollout_data.returns + + continue_training, alg_learn_info = self.alg.learn( + batch_obs=batch_obs, + batch_action=batch_action, + batch_logprob=batch_logprob, + batch_return=batch_return, + batch_adv=batch_adv) + + entropy_losses.append(alg_learn_info["entropy_losses"]) + pg_losses.append(alg_learn_info["pg_losses"]) + value_losses.append(alg_learn_info["value_losses"]) + clip_fractions.append(alg_learn_info["clip_fractions"]) + approx_kl_divs.append(alg_learn_info["approx_kl_divs"]) + if not continue_training: + break + + self._n_updates += 1 # according to stable-baseline3 if not continue_training: print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}") break - self._n_updates += self.n_epochs + # self._n_updates += self.n_epochs # change original RL4LMs code explained_var = explained_variance(rollout_buffer.values.flatten(), rollout_buffer.returns.flatten()) # Logs diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py new file mode 100644 index 000000000..e65ab92bd --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import torch +from torch.nn import functional as F + + +class RL4LMsPPO(parl.Algorithm): + def __init__( + self, + model, + learning_rate=3e-4, + n_steps=2048, + n_epochs=10, + gamma=0.99, + gae_lambda=0.95, + clip_range=0.2, + normalize_advantage=True, + ent_coef=0.0, + vf_coef=0.5, + max_grad_norm=0.5, + target_kl=None, + seed=None, + device="auto", + use_clipped_value_loss=False, + ): + super(RL4LMsPPO, self).__init__(model=model) + self.learning_rate = learning_rate + self.n_steps = n_steps + self.n_epochs = n_epochs + self.gamma = gamma + self.gae_lambda = gae_lambda + self.clip_range = clip_range + self.normalize_advantage = normalize_advantage + self.ent_coef = ent_coef + self.vf_coef = vf_coef + self.max_grad_norm = max_grad_norm + self.target_kl = target_kl + self.seed = seed + self.device = device + self.use_clipped_value_loss = use_clipped_value_loss + for param_group in self.model.optimizer.param_groups: + param_group["lr"] = self.learning_rate + + def learn(self, + batch_obs, + batch_action, + batch_logprob, + batch_return, + batch_adv): + # Do a complete pass on the rollout batch + continue_training = True + learn_info = {"entropy_losses": None, + "pg_losses": None, + "value_losses": None, + "clip_fractions": None, + "approx_kl_divs": None, + "loss":None} + + values, action_log_probs, entropy = self.model.evaluate_actions(batch_obs, batch_action) + values = values.flatten() + + # Normalize advantage + if self.normalize_advantage: + batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = torch.exp(action_log_probs - batch_logprob) + + # clipped surrogate loss + surr1 = ratio * batch_adv + surr2 = torch.clamp(ratio, 1 - self.clip_range, + 1 + self.clip_range) * batch_adv + + policy_loss = -torch.min(surr1, surr2).mean() + + # Logging + learn_info["pg_losses"] = policy_loss.item() + clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip_range).float()).item() + learn_info["clip_fractions"] = clip_fraction + + # No clipping + values_pred = values + + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(batch_return, values_pred) + learn_info["value_losses"] = value_loss.item() + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -torch.mean(-action_log_probs) + else: + entropy_loss = -torch.mean(entropy) + + learn_info["entropy_losses"] = entropy_loss.item() + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with torch.no_grad(): + log_ratio = action_log_probs - batch_logprob + approx_kl_div = torch.mean((torch.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + learn_info["approx_kl_divs"] = approx_kl_div + + learn_info["loss"] = loss + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + return continue_training, learn_info + + # Optimization step + self.model.optimizer.zero_grad() + loss.backward() + # Clip grad norm + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.model.optimizer.step() + + return continue_training, learn_info + + def predict(self, obs): + pass + + def value(self, obs): + pass + + def forward_value( + self, + obs, + ): + return self.model.forward_value(obs) + + def forward_policy( + self, + obs, + actions, + ): + return self.model.forward_policy( + obs=obs, + actions=actions, + ) + + def get_log_probs_ref_model( + self, + obs, + action, + ): + return self.model.get_log_probs_ref_model(obs, action) + + def sample( + self, + tokenizer, + texts=None, + max_prompt_length=None, + input_ids=None, + attention_mask=None, + gen_kwargs=None, + ): + return self.model.sample( + input_ids=input_ids, + attention_mask=attention_mask, + tokenizer=tokenizer, + texts=texts, + max_prompt_length=max_prompt_length, + gen_kwargs=gen_kwargs) + + def eval_mode(self): + self.model.eval() diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 2e211c7bc..6558a5fae 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -33,7 +33,7 @@ from rl4lms_utils import DictRolloutBuffer, RolloutUtil # agent, algorithm and model -from rl4lm_ppo import RL4LMPPO +from rl4lms_ppo import RL4LMsPPO from rl4lms_agent import RL4LMsAgent from seq2seq_model import Seq2SeqLMModel @@ -74,15 +74,16 @@ def main(config): apply_model_parallel=config["alg"]["model"]["args"]["apply_model_parallel"], prompt_truncation_side=config["alg"]["model"]["args"]["prompt_truncation_side"], generation_kwargs=config["alg"]["model"]["args"]["generation_kwargs"]) - rl4lm_alg = RL4LMPPO( + rl4lm_alg = RL4LMsPPO( model=rl4lms_model, device=device, n_steps=config["alg"]["args"]["n_steps"], - batch_size=config["alg"]["args"]["batch_size"], learning_rate=config["alg"]["args"]["learning_rate"], n_epochs=config["alg"]["args"]["n_epochs"], ent_coef=config["alg"]["args"]["ent_coef"]) - agent = RL4LMsAgent(rl4lm_alg, config["alg"]) + agent = RL4LMsAgent(rl4lm_alg, + n_epochs=config["alg"]["args"]["n_epochs"], + batch_size=config["alg"]["args"]["batch_size"],) rollout_buffer = DictRolloutBuffer( buffer_size=agent.alg.n_steps * instructor_group.n_instructors, From dc1d8357bb2feebc721d11e8eddd5612ea3e9c44 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Wed, 15 Mar 2023 11:39:27 +0800 Subject: [PATCH 22/34] change name of policy/value , obs-preprocess and add-to-buffer --- benchmark/torch/RL4LMs/instructor.py | 4 - benchmark/torch/RL4LMs/rl4lms_agent.py | 24 ++-- benchmark/torch/RL4LMs/rl4lms_ppo.py | 21 ++- .../torch/RL4LMs/rl4lms_utils/examiner.py | 2 +- .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 133 ++++++++---------- benchmark/torch/RL4LMs/seq2seq_model.py | 15 +- benchmark/torch/RL4LMs/train.py | 2 +- 7 files changed, 96 insertions(+), 105 deletions(-) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index a59c677f3..b430c0053 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -31,10 +31,6 @@ def _flatten_obs(obs, space, n_instructor=None): return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) -def dict_to_tensor(obs, device): - return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} - - @parl.remote_class(wait=False) class Instructor: def __init__( diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index 6f9bb94d6..194443b59 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -48,6 +48,7 @@ def __init__( self.batch_size = batch_size self._norm_reward = norm_reward self._n_updates = 0 + self.device = self.alg.model.device def learn(self, rollout_buffer): entropy_losses = [] @@ -127,25 +128,28 @@ def learn(self, rollout_buffer): logger.info(ppo_train_info) - def get_inputs_for_generation(self, obs_tensor): + def get_inputs_for_generation(self, dict_obs_tensor): + obs_tensor = self.prepare_obs_input(dict_obs_tensor) return self.alg.model.get_inputs_for_generation(obs_tensor) - def predict(self, *args, **kwargs): - # only use sample - pass + def prepare_obs_input(self, obs): + return {key: torch.as_tensor(_obs).to(self.device) for (key, _obs) in obs.items()} - def forward_value( + def value( self, obs, ): - return self.alg.forward_value(obs) + return self.alg.value(obs) - def forward_policy( + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization + # task) for collecting data and testing, so here sample() only needs to return info + # like log_prob and gen_kwargs without action + def policy( self, obs, actions, ): - return self.alg.forward_policy( + return self.alg.policy( obs=obs, actions=actions, ) @@ -157,7 +161,7 @@ def get_log_probs_ref_model( ): return self.alg.get_log_probs_ref_model(obs, action) - def sample( + def predict( self, tokenizer, texts=None, @@ -166,7 +170,7 @@ def sample( attention_mask=None, gen_kwargs=None, ): - return self.alg.sample( + return self.alg.predict( input_ids=input_ids, attention_mask=attention_mask, tokenizer=tokenizer, diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py index e65ab92bd..ffdb6210e 100644 --- a/benchmark/torch/RL4LMs/rl4lms_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -133,24 +133,21 @@ def learn(self, return continue_training, learn_info - def predict(self, obs): - pass - - def value(self, obs): - pass - - def forward_value( + def value( self, obs, ): - return self.model.forward_value(obs) + return self.model.value(obs) - def forward_policy( + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization + # task) for collecting data and testing, so here policy() only needs to return info + # like log_prob and gen_kwargs without action + def policy( self, obs, actions, ): - return self.model.forward_policy( + return self.model.policy( obs=obs, actions=actions, ) @@ -162,7 +159,7 @@ def get_log_probs_ref_model( ): return self.model.get_log_probs_ref_model(obs, action) - def sample( + def predict( self, tokenizer, texts=None, @@ -171,7 +168,7 @@ def sample( attention_mask=None, gen_kwargs=None, ): - return self.model.sample( + return self.model.predict( input_ids=input_ids, attention_mask=attention_mask, tokenizer=tokenizer, diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py index fd33ff592..657f9505e 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py @@ -120,6 +120,6 @@ def _generate_text( dt_control_token, ): prompt_texts = [dt_control_token + sample.prompt_or_input_text for sample in samples] - generated_texts = policy.sample( + generated_texts = policy.predict( tokenizer, prompt_texts, max_prompt_length, gen_kwargs=self._gen_kwargs).gen_texts return generated_texts diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index 86e23e3d5..4ea34b319 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -21,10 +21,6 @@ from .data_wrapper import TransitionInfo -def dict_to_tensor(obs, device): - return {key: torch.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} - - def get_one_token_obs(obs, idx, space): return OrderedDict([(k, obs[k][:, idx, :]) for k in space.spaces.keys()]) @@ -43,51 +39,11 @@ def unpack_observations(obs_tensor, n_instructors): return unpacked_obs -def add_to_buffer(rollout_buffer, episode_wise_transitions, rollout_info): - advantages_computed = False - for ep_ix, transitions in enumerate(episode_wise_transitions): - ep_length = len(transitions) - total_reward = 0.0 - total_kl_reward = 0.0 - for transition_ix, transition in enumerate(transitions): - total_reward += transition.task_reward - total_kl_reward += transition.kl_reward - rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) - rollout_info["rollout_info/log_prob"].append(transition.log_prob) - rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob) - rollout_info["rollout_info/values"].append(transition.value.numpy()) - - if not rollout_buffer.full: - rollout_buffer.add( - transition.observation, - transition.action, - transition.total_reward, - transition.episode_start, - transition.value, - transition.log_prob, - ) - - # if the buffer is full, compute advantages - if rollout_buffer.full and not advantages_computed: - # we fetch the last value for the last time step - # values come from the next transitions's values - next_values = (transitions[transition_ix + 1].value if - (transition_ix + 1) < ep_length else torch.tensor([0.0])) - - rollout_buffer.compute_returns_and_advantage(last_values=next_values, dones=transition.done) - advantages_computed = True - - rollout_info["rollout_info/ep_rew"].append(total_reward) - rollout_info["rollout_info/ep_lens"].append(ep_length) - rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) - return rollout_info - - class RolloutUtil: def __init__(self, kl_args): self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) - def collect_rollouts(self, agent, instructor_group, rollout_buffer, device): + def collect_rollouts(self, agent, instructor_group, rollout_buffer): # get tokenizer tokenizer = instructor_group.tokenizer @@ -113,9 +69,11 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer, device): current_obs = instructor_group.ask() # generate sentences using the model - obs_tensor = dict_to_tensor(current_obs, device) - generation_inputs = agent.get_inputs_for_generation(obs_tensor) - gen_output = agent.sample( + generation_inputs = agent.get_inputs_for_generation(current_obs) + + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization + # task) for collecting data and testing, so here agent uses predict() rather than sample() + gen_output = agent.predict( input_ids=generation_inputs.inputs, attention_mask=generation_inputs.attention_masks, tokenizer=tokenizer) @@ -125,7 +83,7 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer, device): gen_output=gen_output) # generate batch of rollouts and add to buffer - rollout_info, run_timesteps = self._generate_transition_and_add_to_buffer( + episode_wise_transitions, run_timesteps = self._generate_transition( gen_sentence=gen_output, init_obs=current_obs, agent=agent, @@ -135,12 +93,49 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer, device): sentence_rewards=sentence_rewards, sentence_dones=sentence_dones, sentence_infos=sentence_infos, - rollout_buffer=rollout_buffer, - rollout_info=rollout_info, - device=device, ) num_timesteps += run_timesteps + # now we flush all episode wise info to the 1-D buffer + # log transition and add to buffer + advantages_computed = False + for ep_ix, transitions in enumerate(episode_wise_transitions): + ep_length = len(transitions) + total_reward = 0.0 + total_kl_reward = 0.0 + for transition_ix, transition in enumerate(transitions): + total_reward += transition.task_reward + total_kl_reward += transition.kl_reward + rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) + rollout_info["rollout_info/log_prob"].append(transition.log_prob) + rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob) + rollout_info["rollout_info/values"].append(transition.value.numpy()) + + # add to buffer + if not rollout_buffer.full: + rollout_buffer.add( + transition.observation, + transition.action, + transition.total_reward, + transition.episode_start, + transition.value, + transition.log_prob, + ) + + # if the buffer is full, compute advantages + if rollout_buffer.full and not advantages_computed: + # we fetch the last value for the last time step + # values come from the next transitions's values + next_values = (transitions[transition_ix + 1].value if + (transition_ix + 1) < ep_length else torch.tensor([0.0])) + + rollout_buffer.compute_returns_and_advantage(last_values=next_values, dones=transition.done) + advantages_computed = True + + rollout_info["rollout_info/ep_rew"].append(total_reward) + rollout_info["rollout_info/ep_lens"].append(ep_length) + rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) + # aggregate rollout info aggregated_rollout_info = {} for key, values in rollout_info.items(): @@ -154,19 +149,16 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer, device): self._kl_controller.step(torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"])) return num_timesteps - def _generate_transition_and_add_to_buffer(self, - gen_sentence=None, - agent=None, - n_instructors=None, - obs_space=None, - rollout_buffer=None, - rollout_info=None, - device=None, - sentence_new_obs=None, - sentence_rewards=None, - sentence_dones=None, - sentence_infos=None, - init_obs=None): + def _generate_transition(self, + gen_sentence=None, + agent=None, + n_instructors=None, + obs_space=None, + sentence_new_obs=None, + sentence_rewards=None, + sentence_dones=None, + sentence_infos=None, + init_obs=None): current_obs = init_obs review_times = 0 @@ -181,15 +173,16 @@ def _generate_transition_and_add_to_buffer(self, # evaluate actions with actions from rollout with torch.no_grad(): - obs_tensor = dict_to_tensor(current_obs, device) + # prepare here for forward of value_model, policy_model and ref_model + obs_tensor = agent.prepare_obs_input(current_obs) - _, log_probs, _, _ = agent.forward_policy(obs=obs_tensor, actions=actions_tensor) + log_probs, _, _ = agent.policy(obs=obs_tensor, actions=actions_tensor) # sanity check assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" # get values - values, _ = agent.forward_value(obs_tensor) + values, _ = agent.value(obs_tensor) # get reference log probs ref_log_probs, _ = agent.get_log_probs_ref_model(obs_tensor, actions_tensor) @@ -243,6 +236,4 @@ def _generate_transition_and_add_to_buffer(self, episode_starts = np.zeros((n_instructors, ), dtype=bool) current_obs = new_obs - # now we flush all episode wise info to the 1-D buffer - rollout_info = add_to_buffer(rollout_buffer, episode_wise_transitions, rollout_info) - return rollout_info, review_times + return episode_wise_transitions, review_times diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index 6af6e191d..21235dcc1 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -79,7 +79,10 @@ def _build_model_heads(self, model_name): self._value_model = torch.nn.DataParallel(self._value_model) self._value_head = torch.nn.DataParallel(self._value_head.to(self.device)) - def forward_policy( + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization + # task) for collecting data and testing, so here policy() only needs to return info + # like log_prob and gen_kwargs without action + def policy( self, obs, actions, @@ -124,9 +127,9 @@ def forward_policy( dim=-1, ) - return actions, log_prob, entropy, past_model_kwargs + return log_prob, entropy, past_model_kwargs - def forward_value( + def value( self, obs, ): @@ -172,8 +175,8 @@ def forward_value( def evaluate_actions(self, obs, actions): - _, log_prob, entropy, _ = self.forward_policy(obs=obs, actions=actions) - values, _ = self.forward_value(obs) + log_prob, entropy, _ = self.policy(obs=obs, actions=actions) + values, _ = self.value(obs) return values, log_prob, entropy def to(self, device): @@ -240,7 +243,7 @@ def get_inputs_for_generation(self, obs): def get_language_model(self): return unwrap_model(self._policy_model) - def sample( + def predict( self, tokenizer, texts=None, diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 6558a5fae..2b3e2ddd1 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -124,7 +124,7 @@ def main(config): num_timesteps = 0 while num_timesteps < n_steps_per_iter: - run_timesteps = rollout_util.collect_rollouts(agent, instructor_group, rollout_buffer, device) + run_timesteps = rollout_util.collect_rollouts(agent, instructor_group, rollout_buffer) num_timesteps += run_timesteps agent.learn(rollout_buffer) From a23e8fe414650fc696e0e9bb8843af99cb5b8145 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Wed, 15 Mar 2023 14:57:27 +0800 Subject: [PATCH 23/34] change config structure --- benchmark/torch/RL4LMs/instructor.py | 3 +- benchmark/torch/RL4LMs/rl4lms_ppo.py | 2 - benchmark/torch/RL4LMs/t5_ppo.yml | 79 +++++++++++++++------------- benchmark/torch/RL4LMs/train.py | 42 +++++++-------- 4 files changed, 64 insertions(+), 62 deletions(-) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index b430c0053..7d30d844f 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -189,7 +189,6 @@ class InstructorGroup: def __init__( self, instructor_config=None, - reward_config=None, tokenizer=None, datapool_config=None, tokenizer_config=None, @@ -197,7 +196,7 @@ def __init__( self.n_instructors = instructor_config["n_instructors"] # remote instructors need to use config to initialize due to serialization problem instructor_kwargs = { - "reward_config": reward_config, + "reward_config": instructor_config["reward_fn"], "tokenizer_config": tokenizer_config, "datapool_config": datapool_config } diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py index ffdb6210e..0bf9506ae 100644 --- a/benchmark/torch/RL4LMs/rl4lms_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -23,7 +23,6 @@ def __init__( model, learning_rate=3e-4, n_steps=2048, - n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, @@ -39,7 +38,6 @@ def __init__( super(RL4LMsPPO, self).__init__(model=model) self.learning_rate = learning_rate self.n_steps = n_steps - self.n_epochs = n_epochs self.gamma = gamma self.gae_lambda = gae_lambda self.clip_range = clip_range diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml index 2ffd592e3..9340c6ec4 100644 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -1,4 +1,4 @@ - +# config for RL4LMs summarization tokenizer: model_name: t5-base @@ -6,9 +6,6 @@ tokenizer: truncation_side: left pad_token_as_eos_token: False -reward_fn: - args: - rouge_type: "rouge1" datapool: id: cnn_daily_mail @@ -19,6 +16,9 @@ datapool: instructor: parl_master_address: "localhost:8811" n_instructors: 10 + reward_fn: + args: + rouge_type: "rouge1" args: max_prompt_length: 512 max_episode_length: 100 @@ -26,47 +26,54 @@ instructor: prompt_truncation_side: "right" context_start_token: 0 -alg: +kl_div: + coeff: 0.001 + target_kl: 0.2 + +agent: args: - n_steps: 512 batch_size: 32 - learning_rate: 0.000002 n_epochs: 5 - ent_coef: 0.0 - kl_div: - coeff: 0.001 - target_kl: 0.2 - model: - id: seq2seq_lm_actor_critic_model + alg: args: - model_name: t5-base - apply_model_parallel: True - prompt_truncation_side: "right" - generation_kwargs: - do_sample: True - top_k: 50 - min_length: 50 - max_new_tokens: 100 - -train_evaluation: - eval_batch_size: 100 - n_iters: 100 - eval_every: 10 - save_every: 1 + n_steps: 512 + learning_rate: 0.000002 + ent_coef: 0.0 + model: + args: + model_name: t5-base + apply_model_parallel: True + prompt_truncation_side: "right" + generation_kwargs: + do_sample: True + top_k: 50 + min_length: 50 + max_new_tokens: 100 + +examiner: + args: + max_prompt_length: 512 + eval_batch_size: 100 + generation_kwargs: + do_sample: True + top_k: 0 + temperature: 0.7 + min_length: 50 + max_new_tokens: 100 metrics: - id: meteor - args: {} + args: { } - id: rouge - id: bleu - args: {} + args: { } - id: bert_score args: language: en - id: diversity - args: {} - generation_kwargs: - do_sample: True - top_k: 0 - temperature: 0.7 - min_length: 50 - max_new_tokens: 100 + args: { } + + +train_evaluation: + n_iters: 100 + eval_every: 10 + diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 2b3e2ddd1..b00c78d6b 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -52,38 +52,36 @@ def main(config): tokenizer = build_tokenizer(config["tokenizer"]) - # metrics - metrics = build_metrics(config["train_evaluation"]["metrics"]) - # datapool samples_by_split = build_datapool(config["datapool"]) instructor_group = InstructorGroup( instructor_config=config["instructor"], - reward_config=config["reward_fn"], tokenizer=tokenizer, tokenizer_config=config["tokenizer"], datapool_config=config["datapool"], ) + model_config = config["agent"]["alg"]["model"] rl4lms_model = Seq2SeqLMModel( observation_space=instructor_group.observation_space, action_space=instructor_group.action_space, device=device, - model_name=config["alg"]["model"]["args"]["model_name"], - apply_model_parallel=config["alg"]["model"]["args"]["apply_model_parallel"], - prompt_truncation_side=config["alg"]["model"]["args"]["prompt_truncation_side"], - generation_kwargs=config["alg"]["model"]["args"]["generation_kwargs"]) + model_name=model_config["args"]["model_name"], + apply_model_parallel=model_config["args"]["apply_model_parallel"], + prompt_truncation_side=model_config["args"]["prompt_truncation_side"], + generation_kwargs=model_config["args"]["generation_kwargs"]) + alg_config = config["agent"]["alg"] rl4lm_alg = RL4LMsPPO( model=rl4lms_model, device=device, - n_steps=config["alg"]["args"]["n_steps"], - learning_rate=config["alg"]["args"]["learning_rate"], - n_epochs=config["alg"]["args"]["n_epochs"], - ent_coef=config["alg"]["args"]["ent_coef"]) + n_steps=alg_config["args"]["n_steps"], + learning_rate=alg_config["args"]["learning_rate"], + ent_coef=alg_config["args"]["ent_coef"]) + agent_config = config["agent"] agent = RL4LMsAgent(rl4lm_alg, - n_epochs=config["alg"]["args"]["n_epochs"], - batch_size=config["alg"]["args"]["batch_size"],) + n_epochs=agent_config["args"]["n_epochs"], + batch_size=agent_config["args"]["batch_size"],) rollout_buffer = DictRolloutBuffer( buffer_size=agent.alg.n_steps * instructor_group.n_instructors, @@ -93,23 +91,23 @@ def main(config): gamma=agent.alg.gamma, gae_lambda=agent.alg.gae_lambda, ) - rollout_util = RolloutUtil(config["alg"]["kl_div"]) + rollout_util = RolloutUtil(config["kl_div"]) n_iters = int(config["train_evaluation"]["n_iters"]) n_steps_per_iter = instructor_group.n_instructors * agent.alg.n_steps - max_prompt_length = config["instructor"]["args"]["max_prompt_length"] - # gen kwargs for evaluation - eval_gen_kwargs = config["train_evaluation"]["generation_kwargs"] - eval_batch_size = config["train_evaluation"]["eval_batch_size"] + examiner_config = config["examiner"] + # metrics + metrics = build_metrics(examiner_config["metrics"]) examiner = Examiner( tokenizer=tokenizer, - eval_batch_size=eval_batch_size, + eval_batch_size=examiner_config["args"]["eval_batch_size"], + max_prompt_length=examiner_config["args"]["max_prompt_length"], + eval_gen_kwargs=examiner_config["args"]["generation_kwargs"], metrics=metrics, - eval_gen_kwargs=eval_gen_kwargs, samples_by_split=samples_by_split, - max_prompt_length=max_prompt_length) + ) iter_start = 0 examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=iter_start) From c2be52f734432ecbd3cca55fe7c297bf0b65d394 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Wed, 15 Mar 2023 16:38:58 +0800 Subject: [PATCH 24/34] change ppo code style according to parl ppo --- benchmark/torch/RL4LMs/rl4lms_agent.py | 10 +- benchmark/torch/RL4LMs/rl4lms_ppo.py | 119 ++++++++++-------- benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 4 +- benchmark/torch/RL4LMs/t5_ppo.yml | 9 +- benchmark/torch/RL4LMs/train.py | 13 +- 5 files changed, 88 insertions(+), 67 deletions(-) diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index 194443b59..510a7e2d7 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -71,12 +71,14 @@ def learn(self, rollout_buffer): batch_adv = rollout_data.advantages batch_logprob = rollout_data.old_log_prob batch_return = rollout_data.returns + batch_value = rollout_data.old_values continue_training, alg_learn_info = self.alg.learn( batch_obs=batch_obs, batch_action=batch_action, - batch_logprob=batch_logprob, + batch_value=batch_value, batch_return=batch_return, + batch_logprob=batch_logprob, batch_adv=batch_adv) entropy_losses.append(alg_learn_info["entropy_losses"]) @@ -115,7 +117,7 @@ def learn(self, rollout_buffer): # self._n_updates, exclude="tensorboard") # self.logger.record("train/clip_range", clip_range) train_info["train/n_updates"] = self._n_updates - train_info["train/clip_range"] = self.alg.clip_range + train_info["train/clip_param"] = self.alg.clip_param logger.info(train_info) @@ -141,8 +143,8 @@ def value( ): return self.alg.value(obs) - # note: RL4LMs uses the same way (language model always does sample() to generate in summarization - # task) for collecting data and testing, so here sample() only needs to return info + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization task) for + # collecting data and testing, so here use policy() instead of sample() and only need to return info # like log_prob and gen_kwargs without action def policy( self, diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py index 0bf9506ae..5b47ac5a3 100644 --- a/benchmark/torch/RL4LMs/rl4lms_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -14,98 +14,113 @@ import parl import torch -from torch.nn import functional as F +from parl.utils.utils import check_model_method class RL4LMsPPO(parl.Algorithm): def __init__( self, model, - learning_rate=3e-4, - n_steps=2048, - gamma=0.99, - gae_lambda=0.95, - clip_range=0.2, - normalize_advantage=True, - ent_coef=0.0, - vf_coef=0.5, + clip_param=0.2, + value_loss_coef=0.5, + entropy_coef=0.0, + initial_lr=3e-4, max_grad_norm=0.5, + use_clipped_value_loss=False, + norm_adv=True, target_kl=None, seed=None, - device="auto", - use_clipped_value_loss=False, ): + # check model method + check_model_method(model, 'value', self.__class__.__name__) + check_model_method(model, 'policy', self.__class__.__name__) + + assert isinstance(clip_param, float) + assert isinstance(value_loss_coef, float) + assert isinstance(entropy_coef, float) + assert isinstance(initial_lr, float) + assert isinstance(max_grad_norm, float) + assert isinstance(use_clipped_value_loss, bool) + assert isinstance(norm_adv, bool) + super(RL4LMsPPO, self).__init__(model=model) - self.learning_rate = learning_rate - self.n_steps = n_steps - self.gamma = gamma - self.gae_lambda = gae_lambda - self.clip_range = clip_range - self.normalize_advantage = normalize_advantage - self.ent_coef = ent_coef - self.vf_coef = vf_coef + self.initial_lr = initial_lr + self.clip_param = clip_param + self.norm_adv = norm_adv + self.entropy_coef = entropy_coef + self.value_loss_coef = value_loss_coef self.max_grad_norm = max_grad_norm self.target_kl = target_kl self.seed = seed - self.device = device self.use_clipped_value_loss = use_clipped_value_loss + for param_group in self.model.optimizer.param_groups: - param_group["lr"] = self.learning_rate + param_group["lr"] = self.initial_lr + self.optimizer = self.model.optimizer def learn(self, batch_obs, batch_action, - batch_logprob, + batch_value, batch_return, - batch_adv): + batch_logprob, + batch_adv, + lr=None): # Do a complete pass on the rollout batch continue_training = True learn_info = {"entropy_losses": None, - "pg_losses": None, - "value_losses": None, - "clip_fractions": None, - "approx_kl_divs": None, + "pg_losses": None, + "value_losses": None, + "clip_fractions": None, + "approx_kl_divs": None, "loss":None} - values, action_log_probs, entropy = self.model.evaluate_actions(batch_obs, batch_action) + values, _ = self.model.value(batch_obs) + action_log_probs, entropy, _ = self.model.policy(batch_obs, batch_action) values = values.flatten() + entropy_loss = torch.mean(entropy) + learn_info["entropy_losses"] = entropy_loss.item() # Normalize advantage - if self.normalize_advantage: - batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8) + if self.norm_adv: + batch_adv = (batch_adv - batch_adv.mean()) / ( + batch_adv.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = torch.exp(action_log_probs - batch_logprob) # clipped surrogate loss surr1 = ratio * batch_adv - surr2 = torch.clamp(ratio, 1 - self.clip_range, - 1 + self.clip_range) * batch_adv + surr2 = torch.clamp(ratio, 1.0 - self.clip_param, + 1.0 + self.clip_param) * batch_adv - policy_loss = -torch.min(surr1, surr2).mean() + action_loss = -torch.min(surr1, surr2).mean() # Logging - learn_info["pg_losses"] = policy_loss.item() - clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip_range).float()).item() + learn_info["pg_losses"] = action_loss.item() + clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip_param).float()).item() learn_info["clip_fractions"] = clip_fraction - # No clipping - values_pred = values + # clipping + # values_pred = values + if self.use_clipped_value_loss: + value_pred_clipped = batch_value + torch.clamp( + values - batch_value, + -self.clip_param, + self.clip_param, + ) + value_losses = (values - batch_return).pow(2) + value_losses_clipped = (value_pred_clipped - batch_return).pow(2) + value_loss = 0.5 * torch.max(value_losses, + value_losses_clipped).mean() + else: + value_loss = 0.5 * (batch_return - values).pow(2).mean() # Value loss using the TD(gae_lambda) target - value_loss = F.mse_loss(batch_return, values_pred) + # value_loss = F.mse_loss(batch_return, values_pred) learn_info["value_losses"] = value_loss.item() - # Entropy loss favor exploration - if entropy is None: - # Approximate entropy when no analytical form - entropy_loss = -torch.mean(-action_log_probs) - else: - entropy_loss = -torch.mean(entropy) - - learn_info["entropy_losses"] = entropy_loss.item() - - loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + loss = value_loss * self.value_loss_coef + action_loss - self.entropy_coef * entropy_loss # Calculate approximate form of reverse KL Divergence for early stopping # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 @@ -122,12 +137,16 @@ def learn(self, continue_training = False return continue_training, learn_info + if lr: + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + # Optimization step - self.model.optimizer.zero_grad() + self.optimizer.zero_grad() loss.backward() # Clip grad norm torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) - self.model.optimizer.step() + self.optimizer.step() return continue_training, learn_info diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index e483a8886..a93a1332c 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -79,7 +79,7 @@ def __init__( observation_space, action_space, device="cpu", - gae_lambda=1, + gae_lambda=0.95, gamma=0.99, ): self.buffer_size = buffer_size @@ -200,7 +200,7 @@ def swap_and_flatten(self, arr): """ Swap and then flatten axes 0 (buffer_size) and 1 (n_instructors) to convert shape from [n_steps, n_instructors, ...] (when ... is the shape of the features) - to [n_steps * n_instructors, ...] (which maintain the order) + to [n_steps_per_episode * n_instructors, ...] (which maintain the order) :param arr: :return: diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml index 9340c6ec4..41aea3140 100644 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ b/benchmark/torch/RL4LMs/t5_ppo.yml @@ -30,15 +30,18 @@ kl_div: coeff: 0.001 target_kl: 0.2 +rollout_buffer: + args: + n_steps_per_episode: 512 # buffer length = n_steps_per_episode * n_instructors + agent: args: batch_size: 32 n_epochs: 5 alg: args: - n_steps: 512 - learning_rate: 0.000002 - ent_coef: 0.0 + initial_lr: 0.000002 + entropy_coef: 0.0 model: args: model_name: t5-base diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index b00c78d6b..1c8688b0c 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -74,27 +74,24 @@ def main(config): alg_config = config["agent"]["alg"] rl4lm_alg = RL4LMsPPO( model=rl4lms_model, - device=device, - n_steps=alg_config["args"]["n_steps"], - learning_rate=alg_config["args"]["learning_rate"], - ent_coef=alg_config["args"]["ent_coef"]) + initial_lr=alg_config["args"]["initial_lr"], + entropy_coef=alg_config["args"]["entropy_coef"]) agent_config = config["agent"] agent = RL4LMsAgent(rl4lm_alg, n_epochs=agent_config["args"]["n_epochs"], batch_size=agent_config["args"]["batch_size"],) + buffer_config = config["rollout_buffer"] rollout_buffer = DictRolloutBuffer( - buffer_size=agent.alg.n_steps * instructor_group.n_instructors, + buffer_size= buffer_config["args"]["n_steps_per_episode"] * instructor_group.n_instructors, observation_space=instructor_group.observation_space, action_space=instructor_group.action_space, device=device, - gamma=agent.alg.gamma, - gae_lambda=agent.alg.gae_lambda, ) rollout_util = RolloutUtil(config["kl_div"]) n_iters = int(config["train_evaluation"]["n_iters"]) - n_steps_per_iter = instructor_group.n_instructors * agent.alg.n_steps + n_steps_per_iter = instructor_group.n_instructors * buffer_config["args"]["n_steps_per_episode"] # gen kwargs for evaluation examiner_config = config["examiner"] From b34ea18a28adf652464c07bab3fe1d4611dcf25a Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Wed, 15 Mar 2023 16:47:53 +0800 Subject: [PATCH 25/34] yapf code style --- benchmark/torch/RL4LMs/rl4lms_agent.py | 14 ++++---- benchmark/torch/RL4LMs/rl4lms_ppo.py | 32 +++++++------------ .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 18 +++++------ benchmark/torch/RL4LMs/train.py | 14 ++++---- 4 files changed, 36 insertions(+), 42 deletions(-) diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index 510a7e2d7..5ed047da0 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -74,12 +74,12 @@ def learn(self, rollout_buffer): batch_value = rollout_data.old_values continue_training, alg_learn_info = self.alg.learn( - batch_obs=batch_obs, - batch_action=batch_action, - batch_value=batch_value, - batch_return=batch_return, - batch_logprob=batch_logprob, - batch_adv=batch_adv) + batch_obs=batch_obs, + batch_action=batch_action, + batch_value=batch_value, + batch_return=batch_return, + batch_logprob=batch_logprob, + batch_adv=batch_adv) entropy_losses.append(alg_learn_info["entropy_losses"]) pg_losses.append(alg_learn_info["pg_losses"]) @@ -89,7 +89,7 @@ def learn(self, rollout_buffer): if not continue_training: break - self._n_updates += 1 # according to stable-baseline3 + self._n_updates += 1 # according to stable-baseline3 if not continue_training: print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}") break diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py index 5b47ac5a3..41060aee4 100644 --- a/benchmark/torch/RL4LMs/rl4lms_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -58,22 +58,17 @@ def __init__( param_group["lr"] = self.initial_lr self.optimizer = self.model.optimizer - def learn(self, - batch_obs, - batch_action, - batch_value, - batch_return, - batch_logprob, - batch_adv, - lr=None): + def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logprob, batch_adv, lr=None): # Do a complete pass on the rollout batch continue_training = True - learn_info = {"entropy_losses": None, - "pg_losses": None, - "value_losses": None, - "clip_fractions": None, - "approx_kl_divs": None, - "loss":None} + learn_info = { + "entropy_losses": None, + "pg_losses": None, + "value_losses": None, + "clip_fractions": None, + "approx_kl_divs": None, + "loss": None + } values, _ = self.model.value(batch_obs) action_log_probs, entropy, _ = self.model.policy(batch_obs, batch_action) @@ -83,16 +78,14 @@ def learn(self, # Normalize advantage if self.norm_adv: - batch_adv = (batch_adv - batch_adv.mean()) / ( - batch_adv.std() + 1e-8) + batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = torch.exp(action_log_probs - batch_logprob) # clipped surrogate loss surr1 = ratio * batch_adv - surr2 = torch.clamp(ratio, 1.0 - self.clip_param, - 1.0 + self.clip_param) * batch_adv + surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch_adv action_loss = -torch.min(surr1, surr2).mean() @@ -111,8 +104,7 @@ def learn(self, ) value_losses = (values - batch_return).pow(2) value_losses_clipped = (value_pred_clipped - batch_return).pow(2) - value_loss = 0.5 * torch.max(value_losses, - value_losses_clipped).mean() + value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean() else: value_loss = 0.5 * (batch_return - values).pow(2).mean() diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index 4ea34b319..ad3fd7760 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -150,15 +150,15 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer): return num_timesteps def _generate_transition(self, - gen_sentence=None, - agent=None, - n_instructors=None, - obs_space=None, - sentence_new_obs=None, - sentence_rewards=None, - sentence_dones=None, - sentence_infos=None, - init_obs=None): + gen_sentence=None, + agent=None, + n_instructors=None, + obs_space=None, + sentence_new_obs=None, + sentence_rewards=None, + sentence_dones=None, + sentence_infos=None, + init_obs=None): current_obs = init_obs review_times = 0 diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 1c8688b0c..8a9acbaf0 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -77,13 +77,15 @@ def main(config): initial_lr=alg_config["args"]["initial_lr"], entropy_coef=alg_config["args"]["entropy_coef"]) agent_config = config["agent"] - agent = RL4LMsAgent(rl4lm_alg, - n_epochs=agent_config["args"]["n_epochs"], - batch_size=agent_config["args"]["batch_size"],) + agent = RL4LMsAgent( + rl4lm_alg, + n_epochs=agent_config["args"]["n_epochs"], + batch_size=agent_config["args"]["batch_size"], + ) buffer_config = config["rollout_buffer"] rollout_buffer = DictRolloutBuffer( - buffer_size= buffer_config["args"]["n_steps_per_episode"] * instructor_group.n_instructors, + buffer_size=buffer_config["args"]["n_steps_per_episode"] * instructor_group.n_instructors, observation_space=instructor_group.observation_space, action_space=instructor_group.action_space, device=device, @@ -140,12 +142,12 @@ def main(config): if __name__ == '__main__': parser = ArgumentParser(description="Fine-tune LM to generate controlled text") parser.add_argument("--config_path", type=str, help="path to the config file") - parser.add_argument("--project_name", type=str, help="project name", default="rl4lm_exps") + parser.add_argument("--project_name", type=str, help="project name", default="rl4lms_exps") parser.add_argument( "--experiment_name", type=str, help="experiment name", - default="rl4lm_experiment", + default="rl4lms_experiment", ) parser.add_argument( "--base_path_to_store_results", From 02c89566f55964905a7c9fb45823fe7d8b907f85 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Thu, 16 Mar 2023 21:34:54 +0800 Subject: [PATCH 26/34] change code for PARL-RL4LMs summarization version 0.1 --- benchmark/torch/RL4LMs/README.md | 2 +- benchmark/torch/RL4LMs/requirements.txt | 3 +- benchmark/torch/RL4LMs/rl4lms_agent.py | 17 +-- benchmark/torch/RL4LMs/rl4lms_ppo.py | 17 +-- .../huggingface_generation_util.py | 31 +----- .../torch/RL4LMs/rl4lms_utils/metric_util.py | 12 +-- benchmark/torch/RL4LMs/seq2seq_model.py | 17 +-- benchmark/torch/RL4LMs/t5_ppo.yml | 82 -------------- benchmark/torch/RL4LMs/t5_ppo_config.py | 101 ++++++++++++++++++ benchmark/torch/RL4LMs/train.py | 48 ++------- 10 files changed, 127 insertions(+), 203 deletions(-) delete mode 100644 benchmark/torch/RL4LMs/t5_ppo.yml create mode 100644 benchmark/torch/RL4LMs/t5_ppo_config.py diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md index da133cfcf..e635996bd 100644 --- a/benchmark/torch/RL4LMs/README.md +++ b/benchmark/torch/RL4LMs/README.md @@ -20,5 +20,5 @@ xparl start --port 8811 --cpu_num 10 # start training -python train.py --config_path t5_ppo.yml +python train.py ``` diff --git a/benchmark/torch/RL4LMs/requirements.txt b/benchmark/torch/RL4LMs/requirements.txt index e7f120809..95ebbe5b0 100644 --- a/benchmark/torch/RL4LMs/requirements.txt +++ b/benchmark/torch/RL4LMs/requirements.txt @@ -1,6 +1,5 @@ -parl==2.1.1 +parl>=2.1.1 datasets==2.10.1 -PyYAML==6.0 torch==1.11.0 torchvision==0.12.0 transformers==4.18.0 diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index 5ed047da0..dc95e3810 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -137,30 +137,19 @@ def get_inputs_for_generation(self, dict_obs_tensor): def prepare_obs_input(self, obs): return {key: torch.as_tensor(_obs).to(self.device) for (key, _obs) in obs.items()} - def value( - self, - obs, - ): + def value(self, obs): return self.alg.value(obs) # note: RL4LMs uses the same way (language model always does sample() to generate in summarization task) for # collecting data and testing, so here use policy() instead of sample() and only need to return info # like log_prob and gen_kwargs without action - def policy( - self, - obs, - actions, - ): + def policy(self, obs, actions): return self.alg.policy( obs=obs, actions=actions, ) - def get_log_probs_ref_model( - self, - obs, - action, - ): + def get_log_probs_ref_model(self, obs, action): return self.alg.get_log_probs_ref_model(obs, action) def predict( diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py index 41060aee4..85d916110 100644 --- a/benchmark/torch/RL4LMs/rl4lms_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -142,30 +142,19 @@ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logpro return continue_training, learn_info - def value( - self, - obs, - ): + def value(self, obs): return self.model.value(obs) # note: RL4LMs uses the same way (language model always does sample() to generate in summarization # task) for collecting data and testing, so here policy() only needs to return info # like log_prob and gen_kwargs without action - def policy( - self, - obs, - actions, - ): + def policy(self, obs, actions): return self.model.policy( obs=obs, actions=actions, ) - def get_log_probs_ref_model( - self, - obs, - action, - ): + def get_log_probs_ref_model(self, obs, action): return self.model.get_log_probs_ref_model(obs, action) def predict( diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py index 68923e01b..4060b126a 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py @@ -1,32 +1,7 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Third party code # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# coding=utf-8 -# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# The following code are copied or modified from: +# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py from transformers.generation_utils import GenerationMixin import inspect diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py index 684d53429..fc155e051 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py @@ -21,7 +21,7 @@ from parl.utils import logger -class MeteorMetric: +class MeteorMetric(object): def __init__(self): super().__init__() self._metric = load_metric("meteor") @@ -42,7 +42,7 @@ def compute( return metric_dict -class RougeMetric: +class RougeMetric(object): def __init__(self, use_single_ref=True): super().__init__() self._metric = load_metric("rouge") @@ -73,7 +73,7 @@ def compute( return metric_dict -class BERTScoreMetric: +class BERTScoreMetric(object): def __init__(self, language): super().__init__() self._metric = load_metric("bertscore") @@ -103,7 +103,7 @@ def compute( return metric_dict -class BLEUMetric: +class BLEUMetric(object): def __init__(self): super().__init__() self._metric = load_metric("bleu") @@ -136,7 +136,7 @@ def compute( return {"lexical/bleu": (None, "n/a")} -class DiversityMetrics: +class DiversityMetrics(object): def __init__(self, window_size=100): self._msttr_metric = MSTTR(window_size=window_size) self._n_gram_metric = NGramStats() @@ -164,7 +164,7 @@ def compute( return diversity_metrics -class MetricRegistry: +class MetricRegistry(object): _registry = { "meteor": MeteorMetric, "rouge": RougeMetric, diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index 21235dcc1..ffcc48bd6 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -82,11 +82,7 @@ def _build_model_heads(self, model_name): # note: RL4LMs uses the same way (language model always does sample() to generate in summarization # task) for collecting data and testing, so here policy() only needs to return info # like log_prob and gen_kwargs without action - def policy( - self, - obs, - actions, - ): + def policy(self, obs, actions): # 1. prepare model inputs past_model_kwargs = { "attention_mask": obs["prompt_or_input_attention_mask_pt"], @@ -129,10 +125,7 @@ def policy( return log_prob, entropy, past_model_kwargs - def value( - self, - obs, - ): + def value(self, obs): # 1. prepare model inputs past_model_kwargs = { "attention_mask": obs["prompt_or_input_attention_mask_pt"], @@ -186,11 +179,7 @@ def to(self, device): else: return super().to(device) - def get_log_probs_ref_model( - self, - obs, - action, - ): + def get_log_probs_ref_model(self, obs, action): # 1. prepare model inputs past_model_kwargs = { "attention_mask": obs["prompt_or_input_attention_mask_pt"], diff --git a/benchmark/torch/RL4LMs/t5_ppo.yml b/benchmark/torch/RL4LMs/t5_ppo.yml deleted file mode 100644 index 41aea3140..000000000 --- a/benchmark/torch/RL4LMs/t5_ppo.yml +++ /dev/null @@ -1,82 +0,0 @@ -# config for RL4LMs summarization - -tokenizer: - model_name: t5-base - padding_side: left - truncation_side: left - pad_token_as_eos_token: False - - -datapool: - id: cnn_daily_mail - args: - prompt_prefix: "Summarize: " - - -instructor: - parl_master_address: "localhost:8811" - n_instructors: 10 - reward_fn: - args: - rouge_type: "rouge1" - args: - max_prompt_length: 512 - max_episode_length: 100 - terminate_on_eos: True - prompt_truncation_side: "right" - context_start_token: 0 - -kl_div: - coeff: 0.001 - target_kl: 0.2 - -rollout_buffer: - args: - n_steps_per_episode: 512 # buffer length = n_steps_per_episode * n_instructors - -agent: - args: - batch_size: 32 - n_epochs: 5 - alg: - args: - initial_lr: 0.000002 - entropy_coef: 0.0 - model: - args: - model_name: t5-base - apply_model_parallel: True - prompt_truncation_side: "right" - generation_kwargs: - do_sample: True - top_k: 50 - min_length: 50 - max_new_tokens: 100 - -examiner: - args: - max_prompt_length: 512 - eval_batch_size: 100 - generation_kwargs: - do_sample: True - top_k: 0 - temperature: 0.7 - min_length: 50 - max_new_tokens: 100 - metrics: - - id: meteor - args: { } - - id: rouge - - id: bleu - args: { } - - id: bert_score - args: - language: en - - id: diversity - args: { } - - -train_evaluation: - n_iters: 100 - eval_every: 10 - diff --git a/benchmark/torch/RL4LMs/t5_ppo_config.py b/benchmark/torch/RL4LMs/t5_ppo_config.py new file mode 100644 index 000000000..507743ae3 --- /dev/null +++ b/benchmark/torch/RL4LMs/t5_ppo_config.py @@ -0,0 +1,101 @@ + +config = { + 'tokenizer': { + 'model_name': 't5-base', + 'padding_side': 'left', + 'truncation_side': 'left', + 'pad_token_as_eos_token': False + }, + 'datapool': { + 'id': 'cnn_daily_mail', + 'args': { + 'prompt_prefix': 'Summarize: ' + } + }, + 'instructor': { + 'parl_master_address': 'localhost:8811', + 'n_instructors': 100, + 'reward_fn': { + 'args': { + 'rouge_type': 'rouge1' + } + }, + 'args': { + 'max_prompt_length': 512, + 'max_episode_length': 100, + 'terminate_on_eos': True, + 'prompt_truncation_side': 'right', + 'context_start_token': 0 + } + }, + 'kl_div': { + 'coeff': 0.001, + 'target_kl': 0.2 + }, + 'rollout_buffer': { + 'args': { + 'n_steps_per_episode': 512 # buffer length = n_steps_per_episode * n_instructors + } + }, + 'agent': { + 'args': { + 'batch_size': 32, + 'n_epochs': 5 + }, + 'alg': { + 'args': { + 'initial_lr': 0.000002, + 'entropy_coef': 0.0 + }, + 'model': { + 'args': { + 'model_name': 't5-base', + 'apply_model_parallel': True, + 'prompt_truncation_side': 'right', + 'generation_kwargs': { + 'do_sample': True, + 'top_k': 50, + 'min_length': 50, + 'max_new_tokens': 100 + } + } + } + } + }, + 'examiner': { + 'args': { + 'max_prompt_length': 512, + 'eval_batch_size': 100, + 'generation_kwargs': { + 'do_sample': True, + 'top_k': 0, + 'temperature': 0.7, + 'min_length': 50, + 'max_new_tokens': 100 + } + }, + 'metrics': [ + { + 'id': 'meteor', + 'args': {} + }, { + 'id': 'rouge' + }, { + 'id': 'bleu', + 'args': {} + }, { + 'id': 'bert_score', + 'args': { + 'language': 'en' + } + }, { + 'id': 'diversity', + 'args': {} + } + ] + }, + 'train_evaluation': { + 'n_iters': 100, + 'eval_every': 10 + } +} diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 8a9acbaf0..e6e4dc4f5 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import sys -from argparse import ArgumentParser -import datetime -import yaml -import collections +from t5_ppo_config import config from parl.utils import logger import torch import time @@ -38,15 +34,6 @@ from seq2seq_model import Seq2SeqLMModel -def recursive_dict_update(d, u): - for k, v in u.items(): - if isinstance(v, collections.Mapping): - d[k] = recursive_dict_update(d.get(k, {}), v) - else: - d[k] = v - return d - - def main(config): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -111,7 +98,6 @@ def main(config): iter_start = 0 examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=iter_start) - epoch = 0 for epoch in range(iter_start, n_iters): print("========== BEGIN ==========") print(f"outer epoch: {epoch} / {n_iters - 1}") @@ -136,38 +122,16 @@ def main(config): if (epoch + 1) % config["train_evaluation"]["eval_every"] == 0: examiner.evaluate(policy=agent.alg.model, sample_name_list=["val"], epoch=epoch) - examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=epoch) + # during training, we evaluate on VALIDATION set, and finally we evaluate on TEST set + examiner.evaluate(policy=agent.alg.model, sample_name_list=["test"], epoch=epoch) if __name__ == '__main__': - parser = ArgumentParser(description="Fine-tune LM to generate controlled text") - parser.add_argument("--config_path", type=str, help="path to the config file") - parser.add_argument("--project_name", type=str, help="project name", default="rl4lms_exps") - parser.add_argument( - "--experiment_name", - type=str, - help="experiment name", - default="rl4lms_experiment", - ) - parser.add_argument( - "--base_path_to_store_results", - type=str, - help="Base path to store experiment results", - default=os.getcwd(), - ) - parser.add_argument("--entity_name", type=str, help="entity name", default="summarization") - args = parser.parse_args() - - # load the config file - with open(args.config_path, "r") as fp: - config = yaml.safe_load(fp) + logger.auto_set_dir() - recursive_dict_update(config, vars(args)) - log_dir = f"./{args.project_name}/{args.experiment_name}/{args.entity_name}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" - logger.set_dir(log_dir) - config["logging_dir"] = log_dir + config["logging_dir"] = logger.get_dir() config["sys_arg"] = sys.argv + logger.info(config) - logger.set_level("DEBUG") main(config) From 760cc9dcc844a4bb5cdf7200e09311f959727ec4 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Thu, 16 Mar 2023 21:45:05 +0800 Subject: [PATCH 27/34] change code style of PARL-RL4LMs summarization version 0.1 --- benchmark/torch/RL4LMs/t5_ppo_config.py | 45 +++++++++++++++---------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/benchmark/torch/RL4LMs/t5_ppo_config.py b/benchmark/torch/RL4LMs/t5_ppo_config.py index 507743ae3..c1e5def99 100644 --- a/benchmark/torch/RL4LMs/t5_ppo_config.py +++ b/benchmark/torch/RL4LMs/t5_ppo_config.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. config = { 'tokenizer': { @@ -74,25 +87,23 @@ 'max_new_tokens': 100 } }, - 'metrics': [ - { - 'id': 'meteor', - 'args': {} - }, { - 'id': 'rouge' - }, { - 'id': 'bleu', - 'args': {} - }, { - 'id': 'bert_score', - 'args': { + 'metrics': [{ + 'id': 'meteor', + 'args': {} + }, { + 'id': 'rouge' + }, { + 'id': 'bleu', + 'args': {} + }, { + 'id': 'bert_score', + 'args': { 'language': 'en' - } - }, { - 'id': 'diversity', - 'args': {} } - ] + }, { + 'id': 'diversity', + 'args': {} + }] }, 'train_evaluation': { 'n_iters': 100, From 1770e45be17e4476370c6e9768aad8d1776b1960 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Fri, 17 Mar 2023 11:29:32 +0800 Subject: [PATCH 28/34] change unreasonable name to n_steps_per_instructor in config --- benchmark/torch/RL4LMs/t5_ppo_config.py | 4 ++-- benchmark/torch/RL4LMs/train.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmark/torch/RL4LMs/t5_ppo_config.py b/benchmark/torch/RL4LMs/t5_ppo_config.py index c1e5def99..407fe84f1 100644 --- a/benchmark/torch/RL4LMs/t5_ppo_config.py +++ b/benchmark/torch/RL4LMs/t5_ppo_config.py @@ -27,7 +27,7 @@ }, 'instructor': { 'parl_master_address': 'localhost:8811', - 'n_instructors': 100, + 'n_instructors': 10, 'reward_fn': { 'args': { 'rouge_type': 'rouge1' @@ -47,7 +47,7 @@ }, 'rollout_buffer': { 'args': { - 'n_steps_per_episode': 512 # buffer length = n_steps_per_episode * n_instructors + 'n_steps_per_instructor': 512 # buffer length = n_steps_per_instructor * n_instructors } }, 'agent': { diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index e6e4dc4f5..b00d57d4b 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -72,7 +72,7 @@ def main(config): buffer_config = config["rollout_buffer"] rollout_buffer = DictRolloutBuffer( - buffer_size=buffer_config["args"]["n_steps_per_episode"] * instructor_group.n_instructors, + buffer_size=buffer_config["args"]["n_steps_per_instructor"] * instructor_group.n_instructors, observation_space=instructor_group.observation_space, action_space=instructor_group.action_space, device=device, @@ -80,7 +80,7 @@ def main(config): rollout_util = RolloutUtil(config["kl_div"]) n_iters = int(config["train_evaluation"]["n_iters"]) - n_steps_per_iter = instructor_group.n_instructors * buffer_config["args"]["n_steps_per_episode"] + n_steps_per_iter = instructor_group.n_instructors * buffer_config["args"]["n_steps_per_instructor"] # gen kwargs for evaluation examiner_config = config["examiner"] From b9c3e5c8e9036bd380af2836c02db1893878924e Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 20 Mar 2023 10:08:29 +0800 Subject: [PATCH 29/34] add object for all classes, adjust add-to-buffer structure --- benchmark/torch/RL4LMs/instructor.py | 4 +- benchmark/torch/RL4LMs/rl4lms_agent.py | 24 ++--- benchmark/torch/RL4LMs/rl4lms_ppo.py | 14 +-- benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 98 +++++++++++-------- .../torch/RL4LMs/rl4lms_utils/data_pool.py | 2 +- .../torch/RL4LMs/rl4lms_utils/data_wrapper.py | 12 +-- .../torch/RL4LMs/rl4lms_utils/examiner.py | 2 +- .../huggingface_generation_util.py | 2 +- .../RL4LMs/rl4lms_utils/kl_controller.py | 2 +- .../torch/RL4LMs/rl4lms_utils/reward_util.py | 2 +- .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 48 +-------- benchmark/torch/RL4LMs/seq2seq_model.py | 7 +- 12 files changed, 99 insertions(+), 118 deletions(-) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index 7d30d844f..3dc313cab 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -32,7 +32,7 @@ def _flatten_obs(obs, space, n_instructor=None): @parl.remote_class(wait=False) -class Instructor: +class Instructor(object): def __init__( self, reward_config=None, @@ -185,7 +185,7 @@ def get_obs_and_action_space(self): return (self.observation_space, self.action_space) -class InstructorGroup: +class InstructorGroup(object): def __init__( self, instructor_config=None, diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py index dc95e3810..91aa72b6d 100644 --- a/benchmark/torch/RL4LMs/rl4lms_agent.py +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -73,7 +73,7 @@ def learn(self, rollout_buffer): batch_return = rollout_data.returns batch_value = rollout_data.old_values - continue_training, alg_learn_info = self.alg.learn( + alg_learn_info = self.alg.learn( batch_obs=batch_obs, batch_action=batch_action, batch_value=batch_value, @@ -81,6 +81,8 @@ def learn(self, rollout_buffer): batch_logprob=batch_logprob, batch_adv=batch_adv) + continue_training = alg_learn_info["continue_training"] + entropy_losses.append(alg_learn_info["entropy_losses"]) pg_losses.append(alg_learn_info["pg_losses"]) value_losses.append(alg_learn_info["value_losses"]) @@ -89,12 +91,13 @@ def learn(self, rollout_buffer): if not continue_training: break - self._n_updates += 1 # according to stable-baseline3 + self._n_updates += 1 # fix the calculation of self._n_updates if not continue_training: print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}") break - # self._n_updates += self.n_epochs # change original RL4LMs code + # RL4LMs' method may lead to inaccurate calculation of self._n_updates when continue_training is false + # self._n_updates += self.n_epochs explained_var = explained_variance(rollout_buffer.values.flatten(), rollout_buffer.returns.flatten()) # Logs @@ -130,10 +133,6 @@ def learn(self, rollout_buffer): logger.info(ppo_train_info) - def get_inputs_for_generation(self, dict_obs_tensor): - obs_tensor = self.prepare_obs_input(dict_obs_tensor) - return self.alg.model.get_inputs_for_generation(obs_tensor) - def prepare_obs_input(self, obs): return {key: torch.as_tensor(_obs).to(self.device) for (key, _obs) in obs.items()} @@ -149,18 +148,21 @@ def policy(self, obs, actions): actions=actions, ) - def get_log_probs_ref_model(self, obs, action): - return self.alg.get_log_probs_ref_model(obs, action) + def ref_policy(self, obs, action): + return self.alg.ref_policy(obs, action) def predict( self, tokenizer, + dict_obs_tensor=None, texts=None, max_prompt_length=None, - input_ids=None, - attention_mask=None, gen_kwargs=None, ): + obs_tensor = self.prepare_obs_input(dict_obs_tensor) + generation_inputs = self.alg.model.build_inputs(obs_tensor) + input_ids = generation_inputs.inputs + attention_mask = generation_inputs.attention_masks return self.alg.predict( input_ids=input_ids, attention_mask=attention_mask, diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py index 85d916110..fb683bf63 100644 --- a/benchmark/torch/RL4LMs/rl4lms_ppo.py +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -67,7 +67,8 @@ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logpro "value_losses": None, "clip_fractions": None, "approx_kl_divs": None, - "loss": None + "loss": None, + "continue_training": None } values, _ = self.model.value(batch_obs) @@ -127,7 +128,8 @@ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logpro if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: continue_training = False - return continue_training, learn_info + learn_info["continue_training"] = continue_training + return learn_info if lr: for param_group in self.optimizer.param_groups: @@ -139,8 +141,8 @@ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logpro # Clip grad norm torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() - - return continue_training, learn_info + learn_info["continue_training"] = continue_training + return learn_info def value(self, obs): return self.model.value(obs) @@ -154,8 +156,8 @@ def policy(self, obs, actions): actions=actions, ) - def get_log_probs_ref_model(self, obs, action): - return self.model.get_log_probs_ref_model(obs, action) + def ref_policy(self, obs, action): + return self.model.ref_policy(obs, action) def predict( self, diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index a93a1332c..0677a9407 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -49,7 +49,7 @@ def get_obs_shape(observation_space, ): raise NotImplementedError(f"{observation_space} observation space is not supported") -class DictRolloutBuffer: +class DictRolloutBuffer(object): """ Dict Rollout buffer used in on-policy algorithms like A2C/PPO. Extends the RolloutBuffer to use dictionary observations @@ -118,46 +118,62 @@ def reset(self): self.pos = 0 self.full = False - def add( - self, - obs, - action, - reward, - episode_start, - value, - log_prob, - ): - """ - :param obs: Observation - :param action: Action - :param reward: - :param episode_start: Start of episode signal. - :param value: estimated value of the current state - following the current policy. - :param log_prob: log probability of the action - following the current policy. - """ - - if len(log_prob.shape) == 0: - # Reshape 0-d tensor to avoid error - log_prob = log_prob.reshape(-1, 1) - - for key in self.observations.keys(): - obs_ = np.array(obs[key]).copy() - # Reshape needed when using multiple instructors with discrete observations - # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) - if isinstance(self.observation_space.spaces[key], spaces.Discrete): - obs_ = obs_.reshape((1, ) + self.obs_shape[key]) - self.observations[key][self.pos] = obs_ - - self.actions[self.pos] = np.array(action).copy() - self.rewards[self.pos] = np.array(reward).copy() - self.episode_starts[self.pos] = np.array(episode_start).copy() - self.values[self.pos] = value.clone().cpu().numpy().flatten() - self.log_probs[self.pos] = log_prob.clone().cpu().numpy() - self.pos += 1 - if self.pos == self.buffer_size: - self.full = True + def add(self, episode_wise_transitions, rollout_info): + advantages_computed = False + for ep_ix, transitions in enumerate(episode_wise_transitions): + ep_length = len(transitions) + total_reward = 0.0 + total_kl_reward = 0.0 + for transition_ix, transition in enumerate(transitions): + total_reward += transition.task_reward + total_kl_reward += transition.kl_reward + rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) + rollout_info["rollout_info/log_prob"].append(transition.log_prob) + rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob) + rollout_info["rollout_info/values"].append(transition.value.numpy()) + + # add to buffer + if not self.full: + obs = transition.observation + action = transition.action + reward = transition.total_reward + episode_start = transition.episode_start + value = transition.value + log_prob = transition.log_prob + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]).copy() + # Reshape needed when using multiple instructors with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((1,) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + # if the buffer is full, compute advantages + if self.full and not advantages_computed: + # we fetch the last value for the last time step + # values come from the next transitions's values + next_values = (transitions[transition_ix + 1].value if + (transition_ix + 1) < ep_length else torch.tensor([0.0])) + + self.compute_returns_and_advantage(last_values=next_values, dones=transition.done) + advantages_computed = True + + rollout_info["rollout_info/ep_rew"].append(total_reward) + rollout_info["rollout_info/ep_lens"].append(ep_length) + rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) def compute_returns_and_advantage(self, last_values, dones): """ diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py index 7212bd84d..9ae1e0f9b 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py @@ -19,7 +19,7 @@ from nltk.tokenize import word_tokenize -class CNNDailyMail: +class CNNDailyMail(object): def __init__(self, samples): self._samples = samples diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py index 7c1003513..bd92d56fb 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py @@ -27,7 +27,7 @@ @dataclass -class TransitionInfo: +class TransitionInfo(object): observation: TensorDict action: np.ndarray task_reward: np.ndarray @@ -52,7 +52,7 @@ class DictRolloutBufferSamples(NamedTuple): @dataclass(init=True) -class Sample: +class Sample(object): id: str prompt_or_input_text: str references: List[str] @@ -65,7 +65,7 @@ class PolicyType(Enum): @dataclass -class RefPolicyOutput: +class RefPolicyOutput(object): """ Dataclass for the output of the method policy.get_ref_log_probs() """ @@ -77,7 +77,7 @@ class RefPolicyOutput: @dataclass -class GenerationInputs: +class GenerationInputs(object): # prompt inputs inputs: torch.tensor # prompt attention masks @@ -85,7 +85,7 @@ class GenerationInputs: @dataclass -class GenerationOutputs: +class GenerationOutputs(object): # log probs at each time step step_wise_logprobs: List[List[torch.tensor]] # actions at each time step @@ -99,7 +99,7 @@ class GenerationOutputs: @dataclass -class Observation: +class Observation(object): # encoded input prompt_or_input_encoded_pt: torch.tensor # attention mask for the input diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py index 657f9505e..698a03c53 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py @@ -17,7 +17,7 @@ # class for results evaluation -class Examiner: +class Examiner(object): def __init__(self, tokenizer, eval_batch_size, metrics, eval_gen_kwargs, samples_by_split, max_prompt_length): self._tokenizer = tokenizer self._batch_size = eval_batch_size diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py index 4060b126a..c72f8c82c 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py @@ -87,7 +87,7 @@ class SampleEncoderDecoderOutput(ModelOutput): decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None -class GenerationMixinWithRawScores: +class GenerationMixinWithRawScores(object): """ A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py index c359c0e0e..81b3f4f22 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py @@ -15,7 +15,7 @@ import torch -class KLController: +class KLController(object): def __init__(self, kl_coeff, target_kl=None): self._kl_coeff = kl_coeff self._target_kl = target_kl diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py index 163a2afa4..c7c847783 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py @@ -15,7 +15,7 @@ from datasets import load_metric -class RougeRewardFunction: +class RougeRewardFunction(object): def __init__(self, rouge_type, use_single_ref=True): super().__init__() self._metric = load_metric("rouge") diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index ad3fd7760..cedde9030 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -39,7 +39,7 @@ def unpack_observations(obs_tensor, n_instructors): return unpacked_obs -class RolloutUtil: +class RolloutUtil(object): def __init__(self, kl_args): self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) @@ -68,14 +68,10 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer): # start parallel episodes current_obs = instructor_group.ask() - # generate sentences using the model - generation_inputs = agent.get_inputs_for_generation(current_obs) - # note: RL4LMs uses the same way (language model always does sample() to generate in summarization # task) for collecting data and testing, so here agent uses predict() rather than sample() gen_output = agent.predict( - input_ids=generation_inputs.inputs, - attention_mask=generation_inputs.attention_masks, + dict_obs_tensor=current_obs, tokenizer=tokenizer) # get episode state, reward, dones, infos from instructors @@ -98,43 +94,7 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer): # now we flush all episode wise info to the 1-D buffer # log transition and add to buffer - advantages_computed = False - for ep_ix, transitions in enumerate(episode_wise_transitions): - ep_length = len(transitions) - total_reward = 0.0 - total_kl_reward = 0.0 - for transition_ix, transition in enumerate(transitions): - total_reward += transition.task_reward - total_kl_reward += transition.kl_reward - rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) - rollout_info["rollout_info/log_prob"].append(transition.log_prob) - rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob) - rollout_info["rollout_info/values"].append(transition.value.numpy()) - - # add to buffer - if not rollout_buffer.full: - rollout_buffer.add( - transition.observation, - transition.action, - transition.total_reward, - transition.episode_start, - transition.value, - transition.log_prob, - ) - - # if the buffer is full, compute advantages - if rollout_buffer.full and not advantages_computed: - # we fetch the last value for the last time step - # values come from the next transitions's values - next_values = (transitions[transition_ix + 1].value if - (transition_ix + 1) < ep_length else torch.tensor([0.0])) - - rollout_buffer.compute_returns_and_advantage(last_values=next_values, dones=transition.done) - advantages_computed = True - - rollout_info["rollout_info/ep_rew"].append(total_reward) - rollout_info["rollout_info/ep_lens"].append(ep_length) - rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) + rollout_buffer.add(episode_wise_transitions, rollout_info) # aggregate rollout info aggregated_rollout_info = {} @@ -185,7 +145,7 @@ def _generate_transition(self, values, _ = agent.value(obs_tensor) # get reference log probs - ref_log_probs, _ = agent.get_log_probs_ref_model(obs_tensor, actions_tensor) + ref_log_probs, _, _ = agent.ref_policy(obs_tensor, actions_tensor) # sanity check assert torch.all(torch.isfinite(ref_log_probs)), "Infinite values in log probs" diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py index ffcc48bd6..3ef73d1d2 100644 --- a/benchmark/torch/RL4LMs/seq2seq_model.py +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -179,7 +179,7 @@ def to(self, device): else: return super().to(device) - def get_log_probs_ref_model(self, obs, action): + def ref_policy(self, obs, action): # 1. prepare model inputs past_model_kwargs = { "attention_mask": obs["prompt_or_input_attention_mask_pt"], @@ -207,6 +207,7 @@ def get_log_probs_ref_model(self, obs, action): # get log probs dist = Categorical(logits=next_token_logits) log_prob = dist.log_prob(action) + entropy = dist.entropy() # update the model kwargs for further generation past_model_kwargs = unwrap_model(self._ref_model)._update_model_kwargs_for_generation( @@ -218,12 +219,12 @@ def get_log_probs_ref_model(self, obs, action): (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), dim=-1, ) - return log_prob, past_model_kwargs + return log_prob, entropy, past_model_kwargs def get_policy_first_device(self): return (self._policy_model.get_encoder().first_device if self._apply_model_parallel else self.device) - def get_inputs_for_generation(self, obs): + def build_inputs(self, obs): generation_inputs = GenerationInputs(obs["prompt_or_input_encoded_pt"], obs["prompt_or_input_attention_mask_pt"]) From 59e02fa724a2f1fde1630fdc9ccec58d9b104c4c Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 20 Mar 2023 11:41:19 +0800 Subject: [PATCH 30/34] change t5_ppo_config and README --- benchmark/torch/RL4LMs/README.md | 31 +++++--- benchmark/torch/RL4LMs/instructor.py | 30 ++++++-- benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 2 +- .../rl4lms_utils/component_build_util.py | 10 ++- .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 2 +- benchmark/torch/RL4LMs/t5_ppo_config.py | 75 ++++++++----------- benchmark/torch/RL4LMs/train.py | 32 ++++---- 7 files changed, 98 insertions(+), 84 deletions(-) diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md index e635996bd..722cc3e18 100644 --- a/benchmark/torch/RL4LMs/README.md +++ b/benchmark/torch/RL4LMs/README.md @@ -1,20 +1,28 @@ -## Reproduce (Reconfiguration) Summarization in RL4LMs using PARL +## Reproduce Summarization-RLHF in RL4LMs using PARL > Paper: [Is Reinforcement Learning (Not) for Natural Language Processing: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization](https://arxiv.org/abs/2210.01241) -> -> Official code: [RL4LMs](https://github.com/allenai/RL4LMs) -> -> Other code referenced: [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) +### Background + +- Summarization task in NLP: Summarization is the task of producing a shorter version + of one document that preserves most of the input's meaning. +- RLHF: The abbreviation of Reinforcement Learning with Human Feedback, which uses human knowledge to train RL algorithms. + More information is available in the Hugging Face blog [Illustrating Reinforcement Learning from Human Feedback (RLHF)](https://huggingface.co/blog/rlhf) ### Main contribution -- Change from **\{ trainer: \{ ppo: \{ env, rollout_buffer, policy/model \} \} \}** to - **\{trainer: \{env, rollout_buffer, agent: \{ ppo: \{ model \} \} \} \}** according to PARL architecture. -- Use Parl parallel Training +- Build new Summarization-RLHF framework using PARL +- Use PARL parallel training + +### How to use -### Running command +#### Install dependencies +```bash +pip install -r requirements.txt +``` + +#### Start training ```bash # start xparl xparl start --port 8811 --cpu_num 10 @@ -22,3 +30,8 @@ xparl start --port 8811 --cpu_num 10 # start training python train.py ``` + +### Code Reference + +- Official code: [RL4LMs](https://github.com/allenai/RL4LMs) +- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index 3dc313cab..5c4c83db0 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import time from collections import OrderedDict import torch from rl4lms_utils import Observation @@ -43,6 +43,7 @@ def __init__( terminate_on_eos=False, context_start_token=None, prompt_truncation_side="left", + waiting_time_idx=0, ): """ Instructor who gives reward @@ -53,6 +54,7 @@ def __init__( context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") """ + time.sleep(waiting_time_idx * 90) # too many Instructors may cause problems if they load datasets at the same time tokenizer = build_tokenizer(tokenizer_config) samples = build_datapool(datapool_config, remote_train=True)["train"] reward_function = build_reward_fn(reward_config) @@ -195,12 +197,14 @@ def __init__( ): self.n_instructors = instructor_config["n_instructors"] # remote instructors need to use config to initialize due to serialization problem - instructor_kwargs = { - "reward_config": instructor_config["reward_fn"], - "tokenizer_config": tokenizer_config, - "datapool_config": datapool_config - } - instructor_kwargs = {**instructor_kwargs, **instructor_config.get("args", {})} + instructor_kwargs = {"reward_config": instructor_config["reward_fn"], + "tokenizer_config": tokenizer_config, + "datapool_config": datapool_config, + "max_prompt_length": instructor_config["max_prompt_length"], + "max_episode_length": instructor_config["max_episode_length"], + "terminate_on_eos": instructor_config["terminate_on_eos"], + "prompt_truncation_side": instructor_config["prompt_truncation_side"], + "context_start_token": instructor_config["context_start_token"]} self.tokenizer = tokenizer self._remote_instructors = self._create_instructors(instructor_kwargs, instructor_config["parl_master_address"]) @@ -258,4 +262,14 @@ def _instructors_feedback_sentence(self, all_sentences): def _create_instructors(self, instructor_kwargs, parl_port=None): parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"]) - return [Instructor(**instructor_kwargs) for _ in range(self.n_instructors)] + return [Instructor( + reward_config=instructor_kwargs["reward_config"], + tokenizer_config=instructor_kwargs["tokenizer_config"], + datapool_config=instructor_kwargs["datapool_config"], + max_episode_length=instructor_kwargs["max_episode_length"], + max_prompt_length=instructor_kwargs["max_prompt_length"], + terminate_on_eos=instructor_kwargs["terminate_on_eos"], + context_start_token=instructor_kwargs["context_start_token"], + prompt_truncation_side=instructor_kwargs["prompt_truncation_side"], + waiting_time_idx=idx, + ) for idx in range(self.n_instructors)] diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index 0677a9407..83cdf8d04 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -118,7 +118,7 @@ def reset(self): self.pos = 0 self.full = False - def add(self, episode_wise_transitions, rollout_info): + def add_transitions(self, episode_wise_transitions, rollout_info): advantages_computed = False for ep_ix, transitions in enumerate(episode_wise_transitions): ep_length = len(transitions) diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py index 695e0e550..a6bdb5d18 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py @@ -35,7 +35,7 @@ def build_tokenizer(tokenizer_config): def build_reward_fn(reward_config): logger.info(f"loading reward function: rouge") - reward_fn = RougeRewardFunction(**reward_config.get("args", {})) + reward_fn = RougeRewardFunction(rouge_type=reward_config["rouge_type"]) return reward_fn @@ -48,10 +48,12 @@ def build_metrics(metric_configs): def build_datapool(datapool_config, remote_train=False): def _get_datapool_by_split(split): - kwargs = datapool_config.get("args", {}) - kwargs["split"] = split + kwargs = {"prompt_prefix": datapool_config["prompt_prefix"], "split": split} logger.info(f"loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") - dp_split = CNNDailyMail.prepare(**kwargs) + dp_split = CNNDailyMail.prepare( + split=kwargs["split"], + prompt_prefix=kwargs["prompt_prefix"] + ) logger.info(f"finish loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") return dp_split diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index cedde9030..fa5e50edf 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -94,7 +94,7 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer): # now we flush all episode wise info to the 1-D buffer # log transition and add to buffer - rollout_buffer.add(episode_wise_transitions, rollout_info) + rollout_buffer.add_transitions(episode_wise_transitions, rollout_info) # aggregate rollout info aggregated_rollout_info = {} diff --git a/benchmark/torch/RL4LMs/t5_ppo_config.py b/benchmark/torch/RL4LMs/t5_ppo_config.py index 407fe84f1..0caf8cb8c 100644 --- a/benchmark/torch/RL4LMs/t5_ppo_config.py +++ b/benchmark/torch/RL4LMs/t5_ppo_config.py @@ -21,72 +21,57 @@ }, 'datapool': { 'id': 'cnn_daily_mail', - 'args': { - 'prompt_prefix': 'Summarize: ' - } + 'prompt_prefix': 'Summarize: ' }, 'instructor': { 'parl_master_address': 'localhost:8811', 'n_instructors': 10, 'reward_fn': { - 'args': { - 'rouge_type': 'rouge1' - } + 'rouge_type': 'rouge1' }, - 'args': { - 'max_prompt_length': 512, - 'max_episode_length': 100, - 'terminate_on_eos': True, - 'prompt_truncation_side': 'right', - 'context_start_token': 0 - } + 'max_prompt_length': 512, + 'max_episode_length': 100, + 'terminate_on_eos': True, + 'prompt_truncation_side': 'right', + 'context_start_token': 0 }, 'kl_div': { 'coeff': 0.001, 'target_kl': 0.2 }, 'rollout_buffer': { - 'args': { - 'n_steps_per_instructor': 512 # buffer length = n_steps_per_instructor * n_instructors - } + 'n_steps_per_instructor': 512 # buffer length = n_steps_per_instructor * n_instructors }, 'agent': { - 'args': { - 'batch_size': 32, - 'n_epochs': 5 - }, + 'batch_size': 32, + 'n_epochs': 5, 'alg': { - 'args': { - 'initial_lr': 0.000002, - 'entropy_coef': 0.0 - }, - 'model': { - 'args': { - 'model_name': 't5-base', - 'apply_model_parallel': True, - 'prompt_truncation_side': 'right', - 'generation_kwargs': { - 'do_sample': True, - 'top_k': 50, - 'min_length': 50, - 'max_new_tokens': 100 - } - } - } - } - }, - 'examiner': { - 'args': { - 'max_prompt_length': 512, - 'eval_batch_size': 100, + 'initial_lr': 0.000002, + 'entropy_coef': 0.0 + }, + 'model': { + 'model_name': 't5-base', + 'apply_model_parallel': True, + 'prompt_truncation_side': 'right', 'generation_kwargs': { 'do_sample': True, - 'top_k': 0, - 'temperature': 0.7, + 'top_k': 50, 'min_length': 50, 'max_new_tokens': 100 } + } + }, + 'examiner': { + 'max_prompt_length': 512, + 'eval_batch_size': 100, + 'generation_kwargs': { + 'do_sample': True, + 'top_k': 0, + 'temperature': 0.7, + 'min_length': 50, + 'max_new_tokens': 100 }, + # metric list, each (id, args) is one metric 'metrics': [{ 'id': 'meteor', 'args': {} diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index b00d57d4b..71dd6512c 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -49,30 +49,30 @@ def main(config): datapool_config=config["datapool"], ) - model_config = config["agent"]["alg"]["model"] + agent_config = config["agent"] + model_config = agent_config["model"] rl4lms_model = Seq2SeqLMModel( observation_space=instructor_group.observation_space, action_space=instructor_group.action_space, device=device, - model_name=model_config["args"]["model_name"], - apply_model_parallel=model_config["args"]["apply_model_parallel"], - prompt_truncation_side=model_config["args"]["prompt_truncation_side"], - generation_kwargs=model_config["args"]["generation_kwargs"]) - alg_config = config["agent"]["alg"] + model_name=model_config["model_name"], + apply_model_parallel=model_config["apply_model_parallel"], + prompt_truncation_side=model_config["prompt_truncation_side"], + generation_kwargs=model_config["generation_kwargs"]) + alg_config = agent_config["alg"] rl4lm_alg = RL4LMsPPO( model=rl4lms_model, - initial_lr=alg_config["args"]["initial_lr"], - entropy_coef=alg_config["args"]["entropy_coef"]) - agent_config = config["agent"] + initial_lr=alg_config["initial_lr"], + entropy_coef=alg_config["entropy_coef"]) agent = RL4LMsAgent( rl4lm_alg, - n_epochs=agent_config["args"]["n_epochs"], - batch_size=agent_config["args"]["batch_size"], + n_epochs=agent_config["n_epochs"], + batch_size=agent_config["batch_size"], ) buffer_config = config["rollout_buffer"] rollout_buffer = DictRolloutBuffer( - buffer_size=buffer_config["args"]["n_steps_per_instructor"] * instructor_group.n_instructors, + buffer_size=buffer_config["n_steps_per_instructor"] * instructor_group.n_instructors, observation_space=instructor_group.observation_space, action_space=instructor_group.action_space, device=device, @@ -80,7 +80,7 @@ def main(config): rollout_util = RolloutUtil(config["kl_div"]) n_iters = int(config["train_evaluation"]["n_iters"]) - n_steps_per_iter = instructor_group.n_instructors * buffer_config["args"]["n_steps_per_instructor"] + n_steps_per_iter = instructor_group.n_instructors * buffer_config["n_steps_per_instructor"] # gen kwargs for evaluation examiner_config = config["examiner"] @@ -88,9 +88,9 @@ def main(config): metrics = build_metrics(examiner_config["metrics"]) examiner = Examiner( tokenizer=tokenizer, - eval_batch_size=examiner_config["args"]["eval_batch_size"], - max_prompt_length=examiner_config["args"]["max_prompt_length"], - eval_gen_kwargs=examiner_config["args"]["generation_kwargs"], + eval_batch_size=examiner_config["eval_batch_size"], + max_prompt_length=examiner_config["max_prompt_length"], + eval_gen_kwargs=examiner_config["generation_kwargs"], metrics=metrics, samples_by_split=samples_by_split, ) From 4cd67e233ba88b0a43f048207be3db9639b02ab3 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 20 Mar 2023 13:56:56 +0800 Subject: [PATCH 31/34] yapf code style --- benchmark/torch/RL4LMs/instructor.py | 45 ++++++++++--------- benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 2 +- .../rl4lms_utils/component_build_util.py | 5 +-- .../torch/RL4LMs/rl4lms_utils/rollout_util.py | 4 +- benchmark/torch/RL4LMs/t5_ppo_config.py | 6 ++- benchmark/torch/RL4LMs/train.py | 4 +- 6 files changed, 33 insertions(+), 33 deletions(-) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index 5c4c83db0..3913750db 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -54,7 +54,8 @@ def __init__( context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") """ - time.sleep(waiting_time_idx * 90) # too many Instructors may cause problems if they load datasets at the same time + time.sleep( + waiting_time_idx * 90) # too many Instructors may cause problems if they load datasets at the same time tokenizer = build_tokenizer(tokenizer_config) samples = build_datapool(datapool_config, remote_train=True)["train"] reward_function = build_reward_fn(reward_config) @@ -197,14 +198,16 @@ def __init__( ): self.n_instructors = instructor_config["n_instructors"] # remote instructors need to use config to initialize due to serialization problem - instructor_kwargs = {"reward_config": instructor_config["reward_fn"], - "tokenizer_config": tokenizer_config, - "datapool_config": datapool_config, - "max_prompt_length": instructor_config["max_prompt_length"], - "max_episode_length": instructor_config["max_episode_length"], - "terminate_on_eos": instructor_config["terminate_on_eos"], - "prompt_truncation_side": instructor_config["prompt_truncation_side"], - "context_start_token": instructor_config["context_start_token"]} + instructor_kwargs = { + "reward_config": instructor_config["reward_fn"], + "tokenizer_config": tokenizer_config, + "datapool_config": datapool_config, + "max_prompt_length": instructor_config["max_prompt_length"], + "max_episode_length": instructor_config["max_episode_length"], + "terminate_on_eos": instructor_config["terminate_on_eos"], + "prompt_truncation_side": instructor_config["prompt_truncation_side"], + "context_start_token": instructor_config["context_start_token"] + } self.tokenizer = tokenizer self._remote_instructors = self._create_instructors(instructor_kwargs, instructor_config["parl_master_address"]) @@ -262,14 +265,16 @@ def _instructors_feedback_sentence(self, all_sentences): def _create_instructors(self, instructor_kwargs, parl_port=None): parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"]) - return [Instructor( - reward_config=instructor_kwargs["reward_config"], - tokenizer_config=instructor_kwargs["tokenizer_config"], - datapool_config=instructor_kwargs["datapool_config"], - max_episode_length=instructor_kwargs["max_episode_length"], - max_prompt_length=instructor_kwargs["max_prompt_length"], - terminate_on_eos=instructor_kwargs["terminate_on_eos"], - context_start_token=instructor_kwargs["context_start_token"], - prompt_truncation_side=instructor_kwargs["prompt_truncation_side"], - waiting_time_idx=idx, - ) for idx in range(self.n_instructors)] + return [ + Instructor( + reward_config=instructor_kwargs["reward_config"], + tokenizer_config=instructor_kwargs["tokenizer_config"], + datapool_config=instructor_kwargs["datapool_config"], + max_episode_length=instructor_kwargs["max_episode_length"], + max_prompt_length=instructor_kwargs["max_prompt_length"], + terminate_on_eos=instructor_kwargs["terminate_on_eos"], + context_start_token=instructor_kwargs["context_start_token"], + prompt_truncation_side=instructor_kwargs["prompt_truncation_side"], + waiting_time_idx=idx, + ) for idx in range(self.n_instructors) + ] diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index 83cdf8d04..a90fbc654 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -149,7 +149,7 @@ def add_transitions(self, episode_wise_transitions, rollout_info): # Reshape needed when using multiple instructors with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space.spaces[key], spaces.Discrete): - obs_ = obs_.reshape((1,) + self.obs_shape[key]) + obs_ = obs_.reshape((1, ) + self.obs_shape[key]) self.observations[key][self.pos] = obs_ self.actions[self.pos] = np.array(action).copy() diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py index a6bdb5d18..1dd9c81b2 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py @@ -50,10 +50,7 @@ def build_datapool(datapool_config, remote_train=False): def _get_datapool_by_split(split): kwargs = {"prompt_prefix": datapool_config["prompt_prefix"], "split": split} logger.info(f"loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") - dp_split = CNNDailyMail.prepare( - split=kwargs["split"], - prompt_prefix=kwargs["prompt_prefix"] - ) + dp_split = CNNDailyMail.prepare(split=kwargs["split"], prompt_prefix=kwargs["prompt_prefix"]) logger.info(f"finish loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") return dp_split diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py index fa5e50edf..0ac5ba8fe 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -70,9 +70,7 @@ def collect_rollouts(self, agent, instructor_group, rollout_buffer): # note: RL4LMs uses the same way (language model always does sample() to generate in summarization # task) for collecting data and testing, so here agent uses predict() rather than sample() - gen_output = agent.predict( - dict_obs_tensor=current_obs, - tokenizer=tokenizer) + gen_output = agent.predict(dict_obs_tensor=current_obs, tokenizer=tokenizer) # get episode state, reward, dones, infos from instructors sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = instructor_group.feedback_sentense( diff --git a/benchmark/torch/RL4LMs/t5_ppo_config.py b/benchmark/torch/RL4LMs/t5_ppo_config.py index 0caf8cb8c..e1a058a4e 100644 --- a/benchmark/torch/RL4LMs/t5_ppo_config.py +++ b/benchmark/torch/RL4LMs/t5_ppo_config.py @@ -62,8 +62,10 @@ } }, 'examiner': { - 'max_prompt_length': 512, - 'eval_batch_size': 100, + 'max_prompt_length': + 512, + 'eval_batch_size': + 100, 'generation_kwargs': { 'do_sample': True, 'top_k': 0, diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 71dd6512c..5bdeb804a 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -61,9 +61,7 @@ def main(config): generation_kwargs=model_config["generation_kwargs"]) alg_config = agent_config["alg"] rl4lm_alg = RL4LMsPPO( - model=rl4lms_model, - initial_lr=alg_config["initial_lr"], - entropy_coef=alg_config["entropy_coef"]) + model=rl4lms_model, initial_lr=alg_config["initial_lr"], entropy_coef=alg_config["entropy_coef"]) agent = RL4LMsAgent( rl4lm_alg, n_epochs=agent_config["n_epochs"], From da9422633b1b69964d3d08ef95fc07a6d23635a3 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 3 Apr 2023 10:52:02 +0800 Subject: [PATCH 32/34] change buffer add(), add save/load --- benchmark/torch/RL4LMs/instructor.py | 1 + benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 74 ++++++++++++------- benchmark/torch/RL4LMs/t5_ppo_config.py | 7 +- benchmark/torch/RL4LMs/train.py | 18 ++++- 4 files changed, 70 insertions(+), 30 deletions(-) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py index 3913750db..6f5cb84bb 100644 --- a/benchmark/torch/RL4LMs/instructor.py +++ b/benchmark/torch/RL4LMs/instructor.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import time from collections import OrderedDict import torch diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index a90fbc654..0137f466c 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -118,6 +118,46 @@ def reset(self): self.pos = 0 self.full = False + def add(self, + obs, + action, + reward, + episode_start, + value, + log_prob, + ): + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. + """ + + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]).copy() + # Reshape needed when using multiple instructors with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((1, ) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + def add_transitions(self, episode_wise_transitions, rollout_info): advantages_computed = False for ep_ix, transitions in enumerate(episode_wise_transitions): @@ -134,32 +174,14 @@ def add_transitions(self, episode_wise_transitions, rollout_info): # add to buffer if not self.full: - obs = transition.observation - action = transition.action - reward = transition.total_reward - episode_start = transition.episode_start - value = transition.value - log_prob = transition.log_prob - if len(log_prob.shape) == 0: - # Reshape 0-d tensor to avoid error - log_prob = log_prob.reshape(-1, 1) - - for key in self.observations.keys(): - obs_ = np.array(obs[key]).copy() - # Reshape needed when using multiple instructors with discrete observations - # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) - if isinstance(self.observation_space.spaces[key], spaces.Discrete): - obs_ = obs_.reshape((1, ) + self.obs_shape[key]) - self.observations[key][self.pos] = obs_ - - self.actions[self.pos] = np.array(action).copy() - self.rewards[self.pos] = np.array(reward).copy() - self.episode_starts[self.pos] = np.array(episode_start).copy() - self.values[self.pos] = value.clone().cpu().numpy().flatten() - self.log_probs[self.pos] = log_prob.clone().cpu().numpy() - self.pos += 1 - if self.pos == self.buffer_size: - self.full = True + self.add( + transition.observation, + transition.action, + transition.total_reward, + transition.episode_start, + transition.value, + transition.log_prob, + ) # if the buffer is full, compute advantages if self.full and not advantages_computed: diff --git a/benchmark/torch/RL4LMs/t5_ppo_config.py b/benchmark/torch/RL4LMs/t5_ppo_config.py index e1a058a4e..6314bccbc 100644 --- a/benchmark/torch/RL4LMs/t5_ppo_config.py +++ b/benchmark/torch/RL4LMs/t5_ppo_config.py @@ -93,7 +93,12 @@ }] }, 'train_evaluation': { + 'load_model': False, + 'save_model': True, 'n_iters': 100, - 'eval_every': 10 + 'eval_every': 10, + 'save_every': 10, + 'checkpoint_path': "./checkpoint/checkpoint_0.pth", + 'output_dir': "./checkpoint" } } diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 5bdeb804a..99ef46d36 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -17,6 +17,7 @@ from parl.utils import logger import torch import time +import os # instructor and reward function from instructor import InstructorGroup @@ -77,7 +78,8 @@ def main(config): ) rollout_util = RolloutUtil(config["kl_div"]) - n_iters = int(config["train_evaluation"]["n_iters"]) + train_evaluation_config = config["train_evaluation"] + n_iters = int(train_evaluation_config["n_iters"]) n_steps_per_iter = instructor_group.n_instructors * buffer_config["n_steps_per_instructor"] # gen kwargs for evaluation @@ -93,8 +95,11 @@ def main(config): samples_by_split=samples_by_split, ) + if train_evaluation_config["load_model"]: + logger.info(f"loading model from {train_evaluation_config['checkpoint_path']}") + rl4lms_model.load_state_dict(torch.load(train_evaluation_config["checkpoint_path"])["state_dict"]) iter_start = 0 - examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=iter_start) + # examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=iter_start) for epoch in range(iter_start, n_iters): print("========== BEGIN ==========") @@ -116,8 +121,15 @@ def main(config): f" {1.0 * (outer_end_time - outer_start_time) * (n_iters - epoch - 1) / 60 / 60} hour(s)") print("========== END ==========") + # save model + if train_evaluation_config['save_model'] and (epoch + 1) % train_evaluation_config["save_every"] == 0: + output_dir = train_evaluation_config['output_dir'] + if not os.path.exists(output_dir): + os.mkdir(output_dir) + rl4lms_model.save(f"{output_dir}/checkpoint_{epoch}.pth") + # evaluate on val set in the given intervals - if (epoch + 1) % config["train_evaluation"]["eval_every"] == 0: + if (epoch + 1) % train_evaluation_config["eval_every"] == 0: examiner.evaluate(policy=agent.alg.model, sample_name_list=["val"], epoch=epoch) # during training, we evaluate on VALIDATION set, and finally we evaluate on TEST set From 68ec090a338a5d7e13a67e2d374e3b22831bdd4c Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 3 Apr 2023 10:55:14 +0800 Subject: [PATCH 33/34] yapf code style --- benchmark/torch/RL4LMs/rl4lms_utils/buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py index 0137f466c..f66d7ad7e 100644 --- a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -118,7 +118,8 @@ def reset(self): self.pos = 0 self.full = False - def add(self, + def add( + self, obs, action, reward, From 21e99e83ca00353a7b8b26e5f7d110cb7db83732 Mon Sep 17 00:00:00 2001 From: dwyzzy <805864608@qq.com> Date: Mon, 3 Apr 2023 11:23:43 +0800 Subject: [PATCH 34/34] evaluate at beginning --- benchmark/torch/RL4LMs/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py index 99ef46d36..7acf7b02d 100644 --- a/benchmark/torch/RL4LMs/train.py +++ b/benchmark/torch/RL4LMs/train.py @@ -99,7 +99,7 @@ def main(config): logger.info(f"loading model from {train_evaluation_config['checkpoint_path']}") rl4lms_model.load_state_dict(torch.load(train_evaluation_config["checkpoint_path"])["state_dict"]) iter_start = 0 - # examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=iter_start) + examiner.evaluate(policy=agent.alg.modell, sample_name_list=["val", "test"], epoch=iter_start) for epoch in range(iter_start, n_iters): print("========== BEGIN ==========")