diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py
index 5277d7010..864b5abf8 100644
--- a/trlx/data/default_configs.py
+++ b/trlx/data/default_configs.py
@@ -49,11 +49,13 @@ def default_ppo_config():
             ref_mean=None,
             ref_std=None,
             cliprange_reward=10,
+            num_topk_samples=1,
             gen_kwargs=dict(
                 max_new_tokens=40,
                 top_k=0,
                 top_p=1.0,
                 do_sample=True,
+                num_return_sequences=1,
             ),
         ),
     )
diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index 51d54cf36..de098ac22 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -112,6 +112,9 @@ class PPOConfig(MethodConfig):
 
     :param gen_experience_kwargs: if this is not None, then the experience is generated using this
     :type gen_experience_kwargs: Dict[str, Any]
+
+    :param num_topk_samples: number of best-of-N samples to keep from the `num_return_sequences` generated per prompt
+    :type num_topk_samples: int
     """
 
     ppo_epochs: int
@@ -131,6 +134,7 @@ class PPOConfig(MethodConfig):
     cliprange_reward: float
     gen_kwargs: dict
     gen_experience_kwargs: Optional[dict] = None
+    num_topk_samples: int = 1
     num_value_layers_unfrozen: int = 0
 
     def get_advantages_and_returns(
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 9dd1f99a3..c20bab21a 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -4,6 +4,7 @@
 import sys
 from abc import abstractmethod
 from contextlib import contextmanager
+from copy import copy
 from time import time
 from typing import Dict, List, Optional, Tuple
 
@@ -13,6 +14,7 @@
 from ray.air import session
 from rich.console import Console
 from rich.table import Table
+from torch.nn.utils.rnn import pad_sequence
 from transformers import AutoTokenizer
 
 import trlx.utils.logging as logging
@@ -220,21 +222,17 @@ def decode(
             str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True)
             str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True)
             # Trim outputs up to `self.stop_sequences` if any are present
-            trimmed = False
             if self.stop_sequences:
                 for stop in self.stop_sequences:
                     stop_ix = str_output.find(stop)
                     if stop_ix >= 0:
                         str_output = str_output[:stop_ix].rstrip()
-                        trimmed = True
 
             # Recover the last <eos> if it was present in the original sample
             # or add one if it was trimmed with `self.stop_sequences`.
             # When a generation ended due to `max_new_tokens` exhaustion,
             # only then <pad> or <eos> token would not be present in the original sample at the end
-            if append_eos_token and (
-                trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id
-            ):
+            if append_eos_token:
                 str_output += self.tokenizer.eos_token
 
             str_prompts.append(str_prompt)
@@ -249,33 +247,51 @@ def decode(
 
         return str_samples, str_prompts, str_outputs
 
-    def generate(self, input_ids, attention_mask=None, **kwargs):
+    def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs):
         """Wraps hf's `generate` adding some specific method's defaults"""
+        # Divide prompts into chunks and generate samples
         input_ids = input_ids.to(self.accelerator.device)
         if attention_mask is not None:
             attention_mask = attention_mask.to(self.accelerator.device)
-        if self.generate_experience_kwargs is not None:
-            kwargs = dict(self.generate_experience_kwargs, **kwargs)
-        else:
-            kwargs = dict(self.generate_kwargs, **kwargs)
-
-        with torch.no_grad():
-            return self.accelerator.unwrap_model(self.model).generate(
-                input_ids=input_ids, attention_mask=attention_mask, **kwargs
+        generate_kwargs = copy(self.generate_kwargs)
+        generate_kwargs.update(kwargs)
+
+        # Update max_new_tokens to respect max_seq_length
+        prompt_length = input_ids.shape[1]
+        if generate_kwargs.get("max_new_tokens") is not None:
+            generate_kwargs["max_new_tokens"] = min(
+                max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"]
             )
+        else:
+            generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0)
 
-    def generate_eval(self, input_ids, attention_mask=None, **kwargs):
-        """Wraps hf's `generate` adding some specific method's defaults"""
-        input_ids = input_ids.to(self.accelerator.device)
+        # Repeat prompts and attention_masks when returning multiple sequences per prompt
+        if generate_kwargs.get("num_return_sequences") is None:
+            generate_kwargs["num_return_sequences"] = 1
+
+        num_return_sequences = generate_kwargs.pop("num_return_sequences")  # Pop to hide from model.generate call
+        input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0)
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.accelerator.device)
+            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
 
-        kwargs = dict(self.generate_kwargs, **kwargs)
+        if chunk_size is None:
+            chunk_size = input_ids.shape[0]
 
-        with torch.no_grad():
-            return self.accelerator.unwrap_model(self.model).generate(
-                input_ids=input_ids, attention_mask=attention_mask, **kwargs
-            )
+        # Chunk input_ids and attention_mask
+        input_ids = input_ids.split(chunk_size, dim=0)
+        if attention_mask is not None:
+            attention_mask = attention_mask.split(chunk_size, dim=0)
+        all_samples = []
+        for chunk_idx in range(len(input_ids)):
+            with torch.no_grad():
+                samples = self.accelerator.unwrap_model(self.model).generate(
+                    input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **generate_kwargs
+                )
+            all_samples += [sample for sample in samples]
+        # Pad all_samples into one tensor
+        all_samples = pad_sequence(all_samples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+        return all_samples
 
     def save_pretrained(self, directory: Optional[str] = None, **kwargs):
         """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for
@@ -373,11 +389,20 @@ def evaluate(self):  # noqa: C901
         for i_prompt, prompts in enumerate(self.eval_dataloader):
             metadata = {k: v for k, v in prompts.items() if k != "input_ids" and k != "attention_mask"}
             if self.generate_sweep_kwarg:
-                samples = self.generate_eval(
+                samples = self.generate(
                     prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value}
                 )
             else:
-                samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"])
+                chunk_size = self.config.method.chunk_size if hasattr(self.config.method, "chunk_size") else None
+                samples = self.generate(prompts["input_ids"], prompts["attention_mask"], chunk_size=chunk_size)
+
+            # Repeat prompts and metadata `num_return_sequences` times
+            num_return_sequences = 1
+            if self.generate_kwargs.get("num_return_sequences") is not None:
+                num_return_sequences = self.generate_kwargs["num_return_sequences"]
+            prompts["input_ids"] = prompts["input_ids"].repeat_interleave(num_return_sequences, dim=0)
+            prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(num_return_sequences, dim=0)
+            metadata = {k: self.repeat_interleave(v, num_return_sequences) for k, v in metadata.items()}
 
             # TODO(reciprocated): this should be moved into `decode`
             # but that needs to be synced with indexing in `make_experience`
@@ -447,7 +472,13 @@ def evaluate(self):  # noqa: C901
         if self.metric_fn:
             logger.info("Computing metrics")
             metric_time = time()
-            metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata)
+            metrics = self.metric_fn(
+                samples=str_samples,
+                prompts=str_prompts,
+                outputs=str_outputs,
+                tokenizer=self.tokenizer,
+                **metadata,
+            )
             stats["time/metric"] = time() - metric_time
 
             mean_metrics = {
@@ -648,6 +679,15 @@ def learn(self):  # noqa: C901
             self.post_epoch_callback()
         tbar.close()
 
+    @staticmethod
+    def repeat_interleave(l, n):
+        if type(l) is torch.Tensor:
+            l = l.repeat_interleave(n, dim=0)
+        elif type(l) is list:
+            l = [[s] * n for s in l]
+            l = [item for sublist in l for item in sublist]
+        return l
+
     @abstractmethod
     def create_train_dataloader(self):
         """Returns a new dataloader for training."""
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 27ed4b5aa..97308d86f 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -94,10 +94,13 @@ def __init__(self, config: TRLConfig, **kwargs):
         )
         self.generate_kwargs = {**generate_kwargs, **config.method.gen_kwargs}
 
+        if self.generate_kwargs.get("num_return_sequences") is None:
+            self.generate_kwargs["num_return_sequences"] = 1
+
         if config.method.gen_experience_kwargs is not None:
             self.generate_experience_kwargs = {**generate_kwargs, **config.method.gen_experience_kwargs}
         else:
-            self.generate_experience_kwargs = None
+            self.generate_experience_kwargs = {**self.generate_kwargs}
 
         # Setup stats tracker
         self.running_moments = RunningMoments()
@@ -241,7 +244,7 @@ def prepare_learning(self):
 
     def add_prompt_pipeline(self, pipeline: PromptPipeline):
         """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage"""
-        prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True)
+        prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=False)
         prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader)
         self.prompt_iterator = infinite_dataloader(prompt_dataloader)
 
@@ -272,6 +275,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
         ppo_rl_elements = []
         accumulated_stats = []
 
+        num_return_sequences = self.generate_experience_kwargs["num_return_sequences"]
+
+        # Require that chunk_size * num_topk_samples divides num_rollouts
+        assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_topk_samples) == 0
+        assert self.config.method.num_topk_samples <= num_return_sequences
+
         while len(ppo_rl_elements) < num_rollouts:
             stats = {}
             # Get next batch in prompt dataset
@@ -280,10 +289,15 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             rollout_generate_time = time()
 
             # Generate samples from the language model (similar to using HuggingFace `generate` method)
-            samples = self.generate(batch["input_ids"], batch["attention_mask"])
+            samples = self.generate(
+                batch["input_ids"],
+                batch["attention_mask"],
+                chunk_size=self.config.method.chunk_size,
+                **self.generate_experience_kwargs,
+            )
             stats["time/rollout_generate"] = time() - rollout_generate_time
 
-            prompt_tensors = batch.input_ids
+            prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0)
             device = samples.device
             prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
 
@@ -296,7 +310,13 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             gathered_samples = self.accelerator.gather(padded_samples)
             gathered_prompts = self.accelerator.gather(padded_prompts)
             gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
-            metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"})
+            metadata = gather_dict(
+                {
+                    k: self.repeat_interleave(v, num_return_sequences)
+                    for k, v in batch.items()
+                    if k != "input_ids" and k != "attention_mask"
+                }
+            )
 
             if self.accelerator.is_main_process:
                 all_str_samples, all_str_prompts, all_str_outputs = self.decode(
@@ -336,7 +356,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 torch.distributed.scatter(scores, all_scores)
             else:
                 scores = all_scores[0].clone().detach()
+
+            # Best-of-N Sampling.
             scores_mask = scores != -np.inf
+            train_indices = self.get_topk_indices(
+                input_tensor=scores_mask * scores,
+                window_size=num_return_sequences,
+                k=self.config.method.num_topk_samples,
+                device=device,
+            )
+            scores = scores[train_indices]
+            scores_mask = scores_mask[train_indices]
+            samples = samples[train_indices]
+            prompt_tensors = prompt_tensors[train_indices]
 
             str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
 
@@ -410,38 +442,67 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                             return_dict=True,
                         ).logits
             else:
+                values_chunks = []
+                logits_chunks = []
+                ref_logits_chunks = []
+                log_probs_chunks = []
+                ref_logprobs_chunks = []
                 all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)
                 attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device)
                 position_ids = attention_mask.long().cumsum(-1) - 1
                 position_ids.masked_fill_(attention_mask == 0, 1)
-                with torch.no_grad():
-                    logits, *_, values = self.model(
-                        all_tokens, attention_mask=attention_mask, position_ids=position_ids
-                    )
-                    # TODO(dahoas): When hydra model works need to also support generation on hydra head
-                    if hasattr(self.model, "frozen_head") or self.model.peft_type:
-                        ref_logits = self.model.forward_hydra(
-                            all_tokens,
-                            attention_mask=attention_mask,
-                            position_ids=position_ids,
-                            return_dict=True,
-                        ).logits
+                all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.chunk_size, dim=0)
+                attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.chunk_size, dim=0)
+                position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.chunk_size, dim=0)
+                for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(
+                    all_tokens_chunks, attention_mask_chunks, position_ids_chunks
+                ):
+                    all_tokens_chunk = all_tokens_chunk.to(device)
+                    attention_mask_chunk = attention_mask_chunk.to(device)
+                    position_ids_chunk = position_ids_chunk.to(device)
+                    with torch.no_grad():
+                        logits, *_, values = self.model(
+                            all_tokens_chunk,
+                            attention_mask=attention_mask_chunk,
+                            position_ids=position_ids_chunk,
+                        )
+                        # TODO(dahoas): When hydra model works need to also support generation on hydra head
+                        if hasattr(self.model, "frozen_head"):
+                            ref_logits = self.model.forward_hydra(
+                                all_tokens_chunk,
+                                attention_mask=attention_mask_chunk,
+                                position_ids=position_ids_chunk,
+                                return_dict=True,
+                            ).logits
+                        elif hasattr(self, "ref_model"):
+                            ref_logits = self.ref_model(
+                                all_tokens_chunk,
+                                attention_mask=attention_mask_chunk,
+                                position_ids=position_ids_chunk,
+                                return_dict=True,
+                            ).logits
+                            ref_logits = ref_logits.to(device)
+                        else:
+                            ref_logits = logits.clone().detach()
+                    if self.config.model.model_arch_type == "seq2seq":
+                        logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
+                        ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
                     else:
-                        ref_logits = self.ref_model(
-                            all_tokens,
-                            attention_mask=attention_mask,
-                            position_ids=position_ids,
-                            return_dict=True,
-                        ).logits
-                        ref_logits = ref_logits.to(device)
-
-            if self.config.model.model_arch_type == "seq2seq":
-                logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
-            else:
-                # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
-                logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
+                        # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
+                        logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens_chunk[:, 1:])
+                        ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens_chunk[:, 1:])
+
+                    values_chunks.append(values.cpu())
+                    logits_chunks.append(logits.cpu())
+                    ref_logits_chunks.append(ref_logits.cpu())
+                    log_probs_chunks.append(logprobs.cpu())
+                    ref_logprobs_chunks.append(ref_logprobs.cpu())
+
+                values = torch.cat(values_chunks, dim=0)
+                logits = torch.cat(logits_chunks, dim=0)
+                ref_logits = torch.cat(ref_logits_chunks, dim=0)
+                logprobs = torch.cat(log_probs_chunks, dim=0)
+                ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0)
 
             n_samples: int = samples.shape[0]
 
@@ -450,8 +511,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 attention_mask = sample_outputs != self.tokenizer.pad_token_id
                 start = 0
             else:
+                # NOTE: -1 because kl[prompt_tensors.shape[1]] is the KL of the second token in the response
                 start = prompt_tensors.shape[1] - 1
 
+            attention_mask = attention_mask.cpu()
             log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
             kl = log_ratio.exp() - 1 - log_ratio
             mean_kl_per_token = kl.mean()
@@ -467,6 +530,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             # from the end of the prompt up to the <eos> token, while also including the latter
             # (these are taken from the student model and not the reference model)
             ends = start + attention_mask[:, start:].sum(1) + 1
+            # NOTE: values[i] is the value of the state after response token i
             all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
             all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
 
@@ -476,6 +540,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             rollout_count = 0
 
             for sample_idx in range(n_samples):
+                # To compute the per-token reward, first add in the KL penalties over the trajectory
+                # NOTE: kl_penalty[i] is the KL penalty at token i+1 of the output (w/o EOS)
                 rewards = kl_penalty[sample_idx]
                 # Then add in rewards
                 if scores.shape[1] == 1:
@@ -502,7 +568,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 rollout_count += 1
 
             if torch.distributed.is_initialized():
-                torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG)
+                torch.distributed.all_reduce(mean_kl.to(self.accelerator.device), torch.distributed.ReduceOp.AVG)
 
             stats["time/rollout_time"] = clock.tick()
             stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item()
@@ -549,3 +615,18 @@ def save_pretrained(self, directory: Optional[str] = None, **kwargs):
 
         if self.accelerator.is_main_process:
             self.tokenizer.save_pretrained(directory)
+
+    @staticmethod
+    def get_topk_indices(input_tensor, window_size: int, k: int, device):
+        """Computes the indices of the top-k values of `input_tensor` within consecutive windows of size `window_size`"""
+        # Sum the scores along dim 1
+        input_tensor = input_tensor.sum(1).unsqueeze(1)
+        # Use unfold to create the sliding windows
+        unfolded = input_tensor.unfold(0, window_size, window_size)
+        # Find the topk values and indices along the unfolded dimension
+        _, indices = torch.topk(unfolded, k, dim=2)
+        # Adjust indices to be relative to the original tensor
+        indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(
+            device
+        ).unsqueeze(1)
+        return indices.reshape(-1)
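Not part of the patch: a minimal usage sketch of the best-of-N knobs introduced above, assuming trlx's existing `trlx.train` entry point and the `default_ppo_config` helper touched in this diff. The concrete numbers and the toy `reward_fn` are illustrative only; they are chosen so that the new assertions in `make_experience` (num_rollouts divisible by chunk_size * num_topk_samples, and num_topk_samples <= num_return_sequences) hold.

# Illustrative sketch only (not part of the patch); `trlx.train` and
# `default_ppo_config` are the existing trlx entry points used in its examples.
import trlx
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()

# Sample 8 candidate continuations per prompt during rollouts...
config.method.gen_kwargs["num_return_sequences"] = 8
# ...and keep only the 4 highest-scoring ones for PPO updates (best-of-N).
config.method.num_topk_samples = 4

# make_experience asserts num_rollouts % (chunk_size * num_topk_samples) == 0
# and num_topk_samples <= num_return_sequences, so pick compatible sizes.
config.method.num_rollouts = 128
config.method.chunk_size = 32

# Toy reward: longer outputs score higher (stand-in for a real reward model).
def reward_fn(samples, prompts, outputs, **kwargs):
    return [float(len(output)) for output in outputs]

trlx.train(
    reward_fn=reward_fn,
    prompts=["Tell me a joke:", "Write a haiku about the sea:"] * 32,
    config=config,
)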