# --- src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py (part 1) ---
import logging
from typing import List, Literal, Optional, Union

import numpy as np
import torch
from pydantic import PositiveInt

from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.config import DataConfig
from cerebras.modelzoo.config.types import ValidatedPath
from cerebras.modelzoo.data.common.input_utils import is_distributed

logger = logging.getLogger(__name__)


class NpyRLDataset(torch.utils.data.Dataset):
    """Map-style dataset serving GRPO rollouts stored in ``rollouts.npz``.

    The archive is written by the rollout generator with positional keys
    ('first' .. 'seventh'); see the mapping in ``__init__``.  Two modes are
    detected from the archive contents:

    * training mode -- 'fifth' (old log probs) is present; samples carry
      advantages, loss masks and reference/old log probs;
    * old-log-prob mode -- 'fifth' is absent; samples carry only the fields
      needed to (re)compute the behaviour-policy log probs.
    """

    def __init__(self, data_dir: str):
        """
        Args:
            data_dir: Directory containing ``rollouts.npz``.

        Raises:
            RuntimeError: if the archive cannot be read or is missing a
                required key.
        """
        super().__init__()
        rollouts_path = data_dir + "/rollouts.npz"
        logger.info("Loading rollouts from %s", rollouts_path)

        # NOTE(review): MSL and prompt_len are hard-coded to match the
        # rollout generator -- confirm they stay in sync with it.
        self.MSL = 131072
        self.prompt_len = 256

        try:
            with open(rollouts_path, 'rb') as f:
                samples = np.load(f)

                # The .npz is built with positional keys by the rollout
                # writer.  TODO: replace with named keys for robustness.
                self.advantages = samples['first'] if 'first' in samples else None
                self.attention_mask = samples['second'].astype(np.int32) if 'second' in samples else None
                self.input_ids = samples['third'].astype(np.int32) if 'third' in samples else None
                self.responses = samples['fourth'].astype(np.int32) if 'fourth' in samples else None
                self.old_log_probs = samples['fifth'] if 'fifth' in samples else None
                self.position_ids = samples['sixth'].astype(np.int32) if 'sixth' in samples else None
                self.ref_log_probs = samples['seventh'] if 'seventh' in samples else None

            # Presence of old log probs selects training mode; otherwise
            # we are only (re)computing the old log probs.
            self._has_log_prob = self.old_log_probs is not None

            # Keys common to both modes must be present.
            for name in ("attention_mask", "input_ids", "responses", "position_ids"):
                if getattr(self, name) is None:
                    raise KeyError(f"Missing required key for '{name}' in {rollouts_path}")

            if self._has_log_prob:
                self.dataset_size = len(self.advantages)
            else:
                self.dataset_size = len(self.input_ids)

            # Loss is only computed over the response window following the
            # (fixed-length) prompt region.
            self.loss_mask = np.zeros((self.dataset_size, self.MSL), dtype=np.float32)
            self.loss_mask[:, self.prompt_len:self.prompt_len + len(self.responses[0])] = 1.0

            self.input_ids, self.attention_mask, self.prompts_len = self.shift_ones_left(
                self.input_ids, self.attention_mask
            )
            self.position_ids = self.pad_length(self.position_ids)

            if self._has_log_prob:
                self.responses = self.pad_in_bw(self.responses)
                self.advantages = self.pad_in_bw(self.advantages)
                self.old_log_probs = self.pad_in_bw(self.old_log_probs)
                self.ref_log_probs = self.pad_in_bw(self.ref_log_probs)

        except Exception as e:
            raise RuntimeError(f"Failed to read : {rollouts_path}") from e

    def shift_ones_left(self, input_ids: np.ndarray, attention_mask: np.ndarray):
        """Left-compact each row so valid (mask == 1) tokens come first.

        Args:
            input_ids: [B, MSL] token ids, left-padded.
            attention_mask: [B, MSL] 0/1 mask marking valid tokens.

        Returns:
            Tuple of (shifted ids, left-aligned mask, per-row prompt length
            measured as ``prompt_len - index_of_first_valid_token``).
        """
        batch, _ = input_ids.shape
        new_input_ids = np.empty_like(input_ids)
        new_attention_mask = np.zeros_like(attention_mask)
        prompt_lens = np.zeros(batch, dtype=np.int32)

        for row in range(batch):
            mask = attention_mask[row]
            tokens = input_ids[row]

            active = mask == 1
            first_one_idx = np.where(active)[0][0]
            prompt_lens[row] = self.prompt_len - first_one_idx

            valid_tokens = tokens[active]
            valid_len = len(valid_tokens)

            # Valid tokens first, the (padding) remainder after them.
            new_input_ids[row] = np.concatenate([valid_tokens, tokens[~active]])
            new_attention_mask[row, :valid_len] = 1

        return new_input_ids, new_attention_mask, prompt_lens
pad_input_right(self, input_ids, pad_token_id, attention_mask): + batch_size, seq_len = input_ids.shape + new_input_ids = np.full_like(input_ids, pad_token_id) + new_attention_mask = np.zeros_like(attention_mask) + + for i in range(batch_size): + valid_len = int(attention_mask[i].sum()) + + new_input_ids[i, :valid_len] = input_ids[i, -valid_len:] if valid_len > 0 else [] + new_attention_mask[i, :valid_len] = 1 if valid_len > 0 else 0 + + pad_len = self.MSL - 4096 + last_tokens = new_input_ids[:, -1:] + + pad = np.repeat(last_tokens, pad_len, axis=1) + a = np.concatenate([new_input_ids, pad], axis=1) + + last_tokens = new_attention_mask[:, -1:] + pad = np.repeat(last_tokens, pad_len, axis=1) + b = np.concatenate([new_attention_mask, pad], axis=1) + + return a, b + + + def pad_inputs(self, input_ids, pad_token_id, attention_mask): + batch_size, seq_len = input_ids.shape + + padded_inputs = np.full((batch_size, self.MSL), pad_token_id, dtype=input_ids.dtype) + + padded_inputs[:, :seq_len] = input_ids + + mask = attention_mask.astype(bool) + + padded_inputs[:, :seq_len][~mask] = pad_token_id + return padded_inputs + + def pad_length(self, batch): + return np.array([np.pad(seq, (0, self.MSL - len(seq)), mode='constant', constant_values=0) for seq in batch]) + + def pad_in_bw(self, batch): + batch_size = batch.shape[0] + insert_len = batch.shape[1] + out = np.zeros((batch_size, self.MSL), dtype=batch.dtype) + out[:, self.prompt_len:self.prompt_len+insert_len] = batch + return out + + def __len__(self): + return self.dataset_size + + def __getitem__(self, idx): + data = {} + + logging.info(f"Rahul old log probs = {self._has_log_prob}") + if self._has_log_prob: + data = { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "advantages": self.advantages[idx], + "loss_mask": self.loss_mask[idx], + "ref_log_probs": self.ref_log_probs[idx], + "old_log_probs": self.old_log_probs[idx], + "position_ids": self.position_ids[idx], + } + else: + data = 
{ + "responses": self.responses[idx], + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "prompts_len": self.prompts_len[idx], + #"position_ids": self.position_ids[idx], + } + #if idx < 10: + # for k,v in data.items(): + # logging.info(f"{k}: shape{v.shape}, first 20:{v[:20]}") + + return data + +class GptNpyRLDataProcessorConfig(DataConfig): + data_processor: Literal["GptNpyRLDataProcessor"] + + num_workers: int = 0 + """ The number of PyTorch processes used in the dataloader. """ + + prefetch_factor: Optional[int] = 10 + """ The number of batches to prefetch in the dataloader. """ + + persistent_workers: bool = True + + batch_size: PositiveInt = ... + + data_dir: Union[ValidatedPath, List[ValidatedPath]] = ... + "Path to the data files to use." + + sampler: Optional[torch.utils.data.sampler.Sampler] = None + + +class GptNpyRLDataProcessor: + """ + A map style dataset for GPT style models. + + Supports data saved on disk in either of the following formats: + - `(num_tokens,)`, i.e. a set of documents tokenized and concatenated. + We refer to this as the 'corpus' format in what follows. + - `(num_sequences, 3, sequence_length)`, i.e. data that has already + been preprocessed into sequences. We refer to this as the + 'sample' format in what follows. + + Args: + config: The config used to configure the data processor. 
+ """ + + def __init__(self, config: GptNpyRLDataProcessorConfig): + if isinstance(config, dict): + config = GptNpyRLDataProcessorConfig(**config) + + self.config = config + + self.dataset = NpyRLDataset(config.data_dir) + self.batch_size = get_streaming_batch_size(config.batch_size) + self.sampler = config.sampler + + if is_distributed(): + assert self.sampler is None, "Cannot use sampler in config with DDP" + self.sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + shuffle=False, + seed=1, + ) + + def create_dataloader(self): + return torch.utils.data.DataLoader( + self.dataset, + batch_size=self.batch_size, + collate_fn=None, + batch_sampler=self.sampler, + num_workers=self.config.num_workers, + prefetch_factor=self.config.prefetch_factor, + persistent_workers=self.config.persistent_workers, + ) diff --git a/src/cerebras/modelzoo/losses/GRPOLoss.py b/src/cerebras/modelzoo/losses/GRPOLoss.py new file mode 100644 index 00000000..1116414a --- /dev/null +++ b/src/cerebras/modelzoo/losses/GRPOLoss.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +from typing import Optional + +class GRPOLoss(nn.Module): + def __init__( + self, + clip_ratio: Optional[float] = None, + clip_ratio_low : float = 0.2, + clip_ratio_high : float = 0.28, + use_kl_loss : bool = True, + kl_loss_coef : float = 0.0 + ): + """ + GRPO Loss with optional KL regularization. + Args: + eps_clip (float): Clipping epsilon. + kl_loss_coeff (float): Coefficient for KL loss term. + """ + super().__init__() + self.epsilon = 1e-8 + self.use_kl_loss = use_kl_loss + self.kl_loss_coef = kl_loss_coef + self.clip_ratio_low = clip_ratio_low + self.clip_ratio_high = clip_ratio_high + self.clip_ratio = clip_ratio + + def forward( + self, + old_log_probs, + curr_log_probs, + advantages, + loss_mask, + ref_log_probs=None, + ): + """ + Args: + log_probs (Tensor): New log probabilities. + sub_old_log_probs (Tensor): Log probs from old policy. + advantages (Tensor): Advantage estimates. 
+ loss_mask (Tensor): Binary mask for valid entries. + sub_ref_log_probs (Tensor, optional): Reference policy log probs for KL regularization. + Returns: + policy_loss (Tensor): Scalar GRPO loss. + """ + + cliprange=self.clip_ratio + cliprange_low=self.clip_ratio_low if self.clip_ratio_low is not None else cliprange + cliprange_high=self.clip_ratio_high if self.clip_ratio_high is not None else cliprange + sampling_ratio = torch.exp(curr_log_probs - old_log_probs) + + # Clipping for stability + clipped_ratio = torch.clamp( + sampling_ratio, 1.0 - cliprange_low, 1.0 + cliprange_high + ) + + unclipped_loss = sampling_ratio * advantages + clipped_loss = clipped_ratio * advantages + + loss = -torch.min(unclipped_loss, clipped_loss) + + # Apply mask + policy_loss = (loss_mask * loss).sum() / ( + loss_mask.sum() + self.epsilon + ) + + # Optional KL divergence loss + if self.use_kl_loss is True and self.kl_loss_coef > 0.0 and ref_log_probs is not None: + kl = ref_log_probs - curr_log_probs + ratio = torch.exp(kl) + kld = ratio - kl - 1.0 + kl_loss = torch.clamp(kld, min=-10.0, max=10.0) + masked_kl_loss = (loss_mask * kl_loss).mean() + policy_loss += self.kl_loss_coef * masked_kl_loss + + return policy_loss diff --git a/src/cerebras/modelzoo/models/nlp/rl/__init__.py b/src/cerebras/modelzoo/models/nlp/rl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cerebras/modelzoo/models/nlp/rl/model.py b/src/cerebras/modelzoo/models/nlp/rl/model.py new file mode 100644 index 00000000..cf4eaf76 --- /dev/null +++ b/src/cerebras/modelzoo/models/nlp/rl/model.py @@ -0,0 +1,112 @@ +from copy import deepcopy +from typing import Literal, Optional, Union +from pydantic import Field + +import torch +import torch.nn.functional as F + +import cerebras.pytorch as cstorch +from cerebras.modelzoo.losses.GRPOLoss import GRPOLoss +from cerebras.modelzoo.models.nlp.llama.model import LlamaModelConfig +from cerebras.modelzoo.models.nlp.gpt2.model import GPT2ModelConfig +from 
# --- src/cerebras/modelzoo/models/nlp/rl/model.py ---
import logging
from copy import deepcopy
from typing import Literal, Optional, Union

import torch
import torch.nn.functional as F
from pydantic import Field
from typing_extensions import Annotated

import cerebras.pytorch as cstorch
from cerebras.modelzoo.config import ModelConfig
from cerebras.modelzoo.losses.GRPOLoss import GRPOLoss
from cerebras.modelzoo.models.nlp.gpt2.model import GPT2ModelConfig
from cerebras.modelzoo.models.nlp.llama.model import LlamaModelConfig


def logprobs_from_logits(logits: torch.Tensor, labels) -> torch.Tensor:
    """Gather the per-token log-probs of `labels` from `logits`.

    Implementation taken from verL; uses a one-hot reduction instead of
    `torch.gather` to fix compile issues on our stack.
    """
    logp = F.log_softmax(logits, dim=-1)  # [B, T, vocab]
    one_hot = cstorch.nn.functional.one_hot(
        labels.to(torch.int64), num_classes=logp.size(-1)
    ).to(logp.dtype)  # [B, T, vocab]
    return (logp * one_hot).sum(dim=-1)  # [B, T]


class RLModelConfig(ModelConfig):
    name: Literal["RL"]

    # Wrapped policy network config, discriminated on its `name` field.
    policy_model: Annotated[
        Union[LlamaModelConfig, GPT2ModelConfig],
        Field(discriminator="name"),
    ] = ...

    use_kl_loss: Optional[bool] = True
    kl_loss_coef: Optional[float] = 0.005
    clip_ratio: Optional[float] = None
    clip_ratio_low: Optional[float] = 0.2
    clip_ratio_high: Optional[float] = 0.28


class RLModel(torch.nn.Module):
    """GRPO trainer wrapper around a causal-LM policy model.

    `forward` has two modes, keyed on whether the batch already contains
    "old_log_probs":
      * absent  -> behaviour-policy log probs over the response window are
        computed and returned (old-log-prob pass);
      * present -> the GRPO training loss is computed and returned.
    """

    def __init__(self, config: RLModelConfig):
        super().__init__()
        self.policy_model = config.policy_model()
        self.loss_fn = GRPOLoss(
            config.clip_ratio,
            config.clip_ratio_low,
            config.clip_ratio_high,
            config.use_kl_loss,
            config.kl_loss_coef,
        )

    def forward(self, data):
        """Dispatch to the old-log-prob pass or the training pass."""
        if "old_log_probs" not in data:
            return self._compute_old_log_probs(data)
        return self._training_loss(data)

    def _compute_old_log_probs(self, data):
        """Old-log-prob pass: score `data["responses"]` under the policy."""
        _, logits = self.policy_model(
            data={
                "input_ids": data["input_ids"],
                "attention_mask": data["attention_mask"],
                "labels": data["input_ids"],
                # NOTE(review): position_ids intentionally omitted here but
                # passed in the training pass -- confirm this asymmetry.
            },
            output_logits=True,
        )

        # TODO: veRL divides logits by the sampling temperature
        # (data.meta_info); we do not yet.
        response_length = data["responses"].size(-1)
        batch = logits.shape[0]

        # Row indices [B, 1] for advanced indexing.
        batch_idx = torch.arange(
            batch, dtype=torch.int, device=logits.device
        ).unsqueeze(1)
        # Per-row token offsets; logits at position t score token t+1,
        # hence the -1 shift from the prompt length.
        token_idx = (
            torch.arange(
                response_length, dtype=torch.int, device=logits.device
            ).unsqueeze(0)
            + (data["prompts_len"].unsqueeze(1) - 1)
        )  # [B, response_length]
        selected_logits = logits[batch_idx, token_idx, :]  # [B, R, vocab]
        old_log_probs = logprobs_from_logits(selected_logits, data["responses"])
        return {
            "old_log_probs": old_log_probs,
            "input_ids": data["input_ids"],
            "attention_mask": data["attention_mask"],
            "responses": data["responses"],
            "prompts_len": data["prompts_len"],
        }

    def _training_loss(self, data):
        """Training pass: GRPO loss from current vs old/reference log probs."""
        _, logits = self.policy_model(
            data={
                "input_ids": data["input_ids"],
                "attention_mask": data["attention_mask"],
                "labels": data["input_ids"],
                "position_ids": data["position_ids"],
            },
            output_logits=True,
        )
        # Same log_softmax + one-hot reduction the old-log-prob pass uses.
        curr_log_probs = logprobs_from_logits(logits, data["input_ids"])
        return self.loss_fn(
            data["old_log_probs"],
            curr_log_probs,
            data["advantages"],
            data["loss_mask"],
            data["ref_log_probs"],
        )


# --- src/cerebras/modelzoo/models/nlp/rl/run.py ---
# isort: off
# MZ: import sys
# MZ: import os
# MZ: sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../.."))
# isort: on

if __name__ == '__main__':
    import warnings
    from cerebras.modelzoo.common.run_utils import run

    warnings.warn(
        "Running models using run.py is deprecated. Please switch to using the ModelZoo CLI. "
        "See https://training-docs.cerebras.ai/model-zoo/cli-overview for more details."
    )

    run()

# --- src/cerebras/modelzoo/registry/registry.yaml (hunk, reproduced) ---
#   (added to the gpt2 data-processor list)
#   - cerebras.modelzoo.data.nlp.gpt.GptNpyRLDataProcessor.GptNpyRLDataProcessor
#
# - name: RL
#   path: cerebras.modelzoo.models.nlp.rl.model.RLModel
#   data_processor_paths: &rl_data_processor_paths
#     - cerebras.modelzoo.data.nlp.gpt.GptNpyRLDataProcessor.GptNpyRLDataProcessor
#
# - name: gpt3
#   path: cerebras.modelzoo.models.nlp.gpt3.model.GPT3Model
data_processor_paths: *gpt2_data_processor_paths diff --git a/src/cerebras/modelzoo/trainer/callbacks/__init__.py b/src/cerebras/modelzoo/trainer/callbacks/__init__.py index 00466009..0392f3b0 100644 --- a/src/cerebras/modelzoo/trainer/callbacks/__init__.py +++ b/src/cerebras/modelzoo/trainer/callbacks/__init__.py @@ -81,6 +81,7 @@ from .selective_grad import SelectiveGrad from .model_stats import CountParams from .notification import EmailNotification, SlackNotification +from .rl_log_probs import SaveOldLogProbs # isort: on @@ -145,4 +146,5 @@ "SlackNotification", "RunSchedule", "NaNChecker", + "SaveOldLogProbs", ] diff --git a/src/cerebras/modelzoo/trainer/callbacks/model.py b/src/cerebras/modelzoo/trainer/callbacks/model.py index 5aabebec..8da9c4a5 100644 --- a/src/cerebras/modelzoo/trainer/callbacks/model.py +++ b/src/cerebras/modelzoo/trainer/callbacks/model.py @@ -69,7 +69,10 @@ def on_before_backward(self, trainer, model, outputs): ) def on_save_checkpoint(self, trainer, state_dict): - state_dict["model"] = trainer.model.state_dict() + if hasattr(trainer.model, "policy_model"): + state_dict["model"] = trainer.model.policy_model.model.state_dict() + else: + state_dict["model"] = trainer.model.state_dict() def on_load_checkpoint(self, trainer, state_dict): if "model" not in state_dict: @@ -93,9 +96,15 @@ def on_load_checkpoint(self, trainer, state_dict): # This should be the case that is used for all checkpoints saved # post rel-2.0.0 else: - trainer.model.load_state_dict( - state_dict["model"], - strict=not trainer.checkpoint.disable_strict_checkpoint_loading, + if hasattr(trainer.model, "policy_model"): + trainer.model.policy_model.model.load_state_dict( + state_dict["model"], + strict=not trainer.checkpoint.disable_strict_checkpoint_loading + ) + else: + trainer.model.load_state_dict( + state_dict["model"], + strict=not trainer.checkpoint.disable_strict_checkpoint_loading, ) trainer.logger.info( diff --git 
a/src/cerebras/modelzoo/trainer/callbacks/optimizer.py b/src/cerebras/modelzoo/trainer/callbacks/optimizer.py index cb0afd63..c098c78f 100644 --- a/src/cerebras/modelzoo/trainer/callbacks/optimizer.py +++ b/src/cerebras/modelzoo/trainer/callbacks/optimizer.py @@ -52,7 +52,10 @@ def setup(self, trainer): elif isinstance(self.optimizer, cstorch.optim.Optimizer): trainer.optimizer = self.optimizer else: - trainer.optimizer = self.optimizer(trainer.model) + if hasattr(trainer.model, "policy_model"): + trainer.optimizer = self.optimizer(trainer.model.policy_model.model) + else: + trainer.optimizer = self.optimizer(trainer.model) def on_fit_start(self, trainer, train_dataloader, val_dataloader, loop): if trainer.optimizer is None: diff --git a/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py b/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py new file mode 100644 index 00000000..32800938 --- /dev/null +++ b/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py @@ -0,0 +1,63 @@ +import os +import numpy as np +import cerebras.pytorch as cstorch +from cerebras.modelzoo.trainer.callbacks import Callback +import torch.nn.functional as F +import torch + +def logprobs_from_logits(logits: torch.Tensor, labels) -> torch.Tensor: + """ + Implementation taken from verL; modified to fix compile issues on our stack. + """ + logp = F.log_softmax(logits, dim=-1) # batch_Size x 3840 X vocab_size + one_hot = cstorch.nn.functional.one_hot( + labels.to(torch.int64), num_classes=logp.size(-1) + ).to(logp.dtype) # batch_size x 3840 x vocab_size + return (logp * one_hot).sum(dim=-1) + + +class SaveOldLogProbs(Callback): + """ + Saves model outputs to out_dir as old_log_probs when in calc-old mode. + This saves the log-probs, per batch. 
+ """ + + def __init__(self, prefix: str = "oldlp"): + self.out_dir = "/n0/lab/sota-rl-inference/eval_rollouts" + self.prefix = prefix + self._count = 0 + + def on_fit_start(self, trainer, train_dataloader, val_dataloader, loop): + os.makedirs(self.out_dir, exist_ok=True) + self._ready = True + + def on_after_forward(self, trainer, model, outputs, batch): + self.post_process(outputs) + + @cstorch.step_closure + def post_process(self, outputs): + # Save tensor as old_log_probs shard + #arr = outputs['output'].numpy() + path = os.path.join(self.out_dir, f"{self.prefix}_{self._count:07d}.npz") + os.makedirs(os.path.dirname(path), exist_ok=True) + + #B, MSL, V = outputs['logits'].shape + #response_length = 3840 + + #batch_idx = torch.arange(B, dtype=torch.int).unsqueeze(1) # [B, 1] + + # For each batch, compute the token indices to extract + '''token_idx = ( + torch.arange(response_length, dtype=torch.int).unsqueeze(0) # [1, 3840] + + (outputs["prompts_len"].unsqueeze(1) - 1) # shift start position per batch + ) ''' # shape [B, 3840] + + # Gather logits using advanced indexing + #selected_logits = outputs['logits'][batch_idx, token_idx, :] # [B, 3840, vocab_size] + + #logits = logits[:, prompt_len:(prompt_len+response_length), :] # [batch_size, response_length, vocab_size] + #old_log_probs = logprobs_from_logits(selected_logits, outputs["responses"]) + #return {"old_log_probs": old_log_probs, "input_ids": data["input_ids"], "attention_mask" : data["attention_mask"], "responses" : data["responses"], "prompts_len" : data["prompts_len"]} + + np.savez(path, old_log_probs=outputs['old_log_probs'].numpy(), inputs=outputs['input_ids'].numpy(), mask=outputs['attention_mask'].numpy(), responses=outputs['responses'], promptlen=outputs['prompts_len']) + self._count += 1