From 5a86a15de69c60ccd95fc64a1bf92433abf1f4d6 Mon Sep 17 00:00:00 2001 From: rahul shrivastava Date: Tue, 26 Aug 2025 11:34:35 +0000 Subject: [PATCH 1/6] Introduce NPYData processor to read npz dumped from verl Signed-off-by: Rahul Shrivastava --- .../data/nlp/gpt/GptNpyRLDataProcessor.py | 135 ++++++++++++++++++ src/cerebras/modelzoo/losses/GRPOLoss.py | 62 ++++++++ .../modelzoo/models/nlp/rl/__init__.py | 0 src/cerebras/modelzoo/models/nlp/rl/model.py | 72 ++++++++++ src/cerebras/modelzoo/models/nlp/rl/run.py | 16 +++ src/cerebras/modelzoo/registry/registry.yaml | 6 + 6 files changed, 291 insertions(+) create mode 100644 src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py create mode 100644 src/cerebras/modelzoo/losses/GRPOLoss.py create mode 100644 src/cerebras/modelzoo/models/nlp/rl/__init__.py create mode 100644 src/cerebras/modelzoo/models/nlp/rl/model.py create mode 100644 src/cerebras/modelzoo/models/nlp/rl/run.py diff --git a/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py b/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py new file mode 100644 index 00000000..b591fc7d --- /dev/null +++ b/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py @@ -0,0 +1,135 @@ +from typing import List, Literal, Optional, Union + +import numpy as np +import torch +from pydantic import PositiveInt + +from cerebras.modelzoo.common.input_utils import get_streaming_batch_size +from cerebras.modelzoo.config import DataConfig +from cerebras.modelzoo.config.types import ValidatedPath +from cerebras.modelzoo.data.common.input_utils import is_distributed + + +class NpyRLDataset(torch.utils.data.Dataset): + def __init__( + self, + data_dir: str, + ): + super().__init__() + rollouts_path = data_dir + "/rollouts.npz" + self.MSL = 131072 + self.prompt_len = 256 + try: + with open(rollouts_path, 'rb') as f: + samples = np.load(f) + self.advantages = samples['first'] + self.attention_mask = samples['second'].astype(np.int32) + self.input_ids = samples['third'].astype(np.int32) + self.responses = samples['fourth'].astype(np.int32) + self.old_log_probs = samples['fifth'] + self.position_ids = samples['sixth'].astype(np.int32) + self.ref_log_probs = samples['seventh'] + self.dataset_size = len(self.advantages) + #self.prompt = samples['eighth'] + + self.loss_mask = np.zeros((self.dataset_size, self.MSL), dtype=np.float32) + self.loss_mask[:, self.prompt_len:self.prompt_len + len(self.responses[0])] = 1.0 + self.input_ids = self.pad_length(self.input_ids) + self.attention_mask = self.pad_length(self.attention_mask) + self.position_ids = self.pad_length(self.position_ids) + + self.advantages = self.pad_in_bw(self.advantages) + self.responses = self.pad_in_bw(self.responses) + self.old_log_probs = self.pad_in_bw(self.old_log_probs) + self.ref_log_probs = self.pad_in_bw(self.ref_log_probs) + + except Exception as e: + raise RuntimeError(f"Failed to read : {rollouts_path}") from e + + def pad_length(self, batch): + return np.array([np.pad(seq, (0, self.MSL - len(seq)), mode='constant', constant_values=0) for seq in batch]) + + def pad_in_bw(self, batch): + batch_size = batch.shape[0] + insert_len = batch.shape[1] + out = np.zeros((batch_size, self.MSL), dtype=batch.dtype) + out[:, self.prompt_len:self.prompt_len+insert_len] = batch + return out + + def __len__(self): + return self.dataset_size + + def __getitem__(self, idx): + data = { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "advantages": self.advantages[idx], + "loss_mask": self.loss_mask[idx], + "ref_log_probs": self.ref_log_probs[idx], + "old_log_probs": self.old_log_probs[idx], + "position_ids": self.position_ids[idx], + } + return data + +class GptNpyRLDataProcessorConfig(DataConfig): + data_processor: Literal["GptNpyRLDataProcessor"] + + num_workers: int = 0 + """ The number of PyTorch processes used in the dataloader. """ + + prefetch_factor: Optional[int] = 10 + """ The number of batches to prefetch in the dataloader. """ + + persistent_workers: bool = True + + batch_size: PositiveInt = ... + + data_dir: Union[ValidatedPath, List[ValidatedPath]] = ... + "Path to the data files to use." + + sampler: Optional[torch.utils.data.sampler.Sampler] = None + + +class GptNpyRLDataProcessor: + """ + A map style dataset for GPT style models. + + Supports data saved on disk in either of the following formats: + - `(num_tokens,)`, i.e. a set of documents tokenized and concatenated. + We refer to this as the 'corpus' format in what follows. + - `(num_sequences, 3, sequence_length)`, i.e. data that has already + been preprocessed into sequences. We refer to this as the + 'sample' format in what follows. + + Args: + config: The config used to configure the data processor. + """ + + def __init__(self, config: GptNpyRLDataProcessorConfig): + if isinstance(config, dict): + config = GptNpyRLDataProcessorConfig(**config) + + self.config = config + + self.dataset = NpyRLDataset(config.data_dir) + self.batch_size = get_streaming_batch_size(config.batch_size) + self.sampler = config.sampler + + if is_distributed(): + assert self.sampler is None, "Cannot use sampler in config with DDP" + self.sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + shuffle=False, + seed=1, + ) + + def create_dataloader(self): + return torch.utils.data.DataLoader( + self.dataset, + batch_size=self.batch_size, + collate_fn=None, + batch_sampler=self.sampler, + num_workers=self.config.num_workers, + prefetch_factor=self.config.prefetch_factor, + persistent_workers=self.config.persistent_workers, + ) diff --git a/src/cerebras/modelzoo/losses/GRPOLoss.py b/src/cerebras/modelzoo/losses/GRPOLoss.py new file mode 100644 index 00000000..6d366d33 --- /dev/null +++ b/src/cerebras/modelzoo/losses/GRPOLoss.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + + +class GRPOLoss(nn.Module): + def __init__(self, eps_clip: float = 0.2, kl_loss_coeff: float = 0.1): + """ + GRPO Loss with optional KL regularization. + Args: + eps_clip (float): Clipping epsilon. + kl_loss_coeff (float): Coefficient for KL loss term. + """ + super().__init__() + self.eps_clip = eps_clip + self.kl_loss_coeff = kl_loss_coeff + self.epsilon = 1e-6 + + def forward( + self, + old_log_probs, + curr_log_probs, + advantages, + loss_mask, + ref_log_probs=None, + ): + """ + Args: + log_probs (Tensor): New log probabilities. + sub_old_log_probs (Tensor): Log probs from old policy. + advantages (Tensor): Advantage estimates. + loss_mask (Tensor): Binary mask for valid entries. + sub_ref_log_probs (Tensor, optional): Reference policy log probs for KL regularization. + Returns: + policy_loss (Tensor): Scalar GRPO loss. + """ + sampling_ratio = torch.exp(curr_log_probs - old_log_probs) + + # Clipping for stability + clipped_ratio = torch.clamp( + sampling_ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip + ) + + unclipped_loss = sampling_ratio * advantages + clipped_loss = clipped_ratio * advantages + + loss = -torch.min(unclipped_loss, clipped_loss) + + # Apply mask + policy_loss = (loss_mask * loss).sum() / ( + loss_mask.sum() + self.epsilon + ) + + # Optional KL divergence loss + if self.kl_loss_coeff > 0.0 and ref_log_probs is not None: + kl = ref_log_probs - curr_log_probs + ratio = torch.exp(kl) + kld = ratio - kl - 1.0 + kl_loss = torch.clamp(kld, min=-10.0, max=10.0) + masked_kl_loss = (loss_mask * kl_loss).mean() + policy_loss += self.kl_loss_coeff * masked_kl_loss + + return policy_loss diff --git a/src/cerebras/modelzoo/models/nlp/rl/__init__.py b/src/cerebras/modelzoo/models/nlp/rl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cerebras/modelzoo/models/nlp/rl/model.py b/src/cerebras/modelzoo/models/nlp/rl/model.py new file mode 100644 index 00000000..7c689493 --- /dev/null +++ b/src/cerebras/modelzoo/models/nlp/rl/model.py @@ -0,0 +1,72 @@ +from copy import deepcopy +from typing import Literal + +import torch + +import cerebras.pytorch as cstorch +from cerebras.modelzoo.losses.GRPOLoss import GRPOLoss +from cerebras.modelzoo.models.nlp.llama.model import LlamaModel, LlamaModelConfig + + +class RLModelConfig(LlamaModelConfig): + name: Literal["RL"] + + +class RLModel(torch.nn.Module): + def __init__(self, config: RLModelConfig): + super().__init__() + + self.policy_model = LlamaModel(config) + + #self.ref_model = deepcopy(self.policy_model) + #self.ref_model.eval() + + self.loss_fn = GRPOLoss() + + def forward(self, data): + '''with torch.no_grad(): + _, old_log_probs = self.ref_model( + data={ + "input_ids": data["input_ids"], + "attention_mask": data["attention_mask"], + "labels": data["input_ids"], + }, + output_logits=True, + ) + old_log_probs = torch.log_softmax(old_log_probs, dim=-1) + old_log_probs = torch.gather( + input=old_log_probs, + dim=-1, + index=data["input_ids"].to(torch.int64).unsqueeze(-1), + ).squeeze(-1)''' + + _, curr_log_probs = self.policy_model( + data={ + "input_ids": data["input_ids"], + "attention_mask": data["attention_mask"], + "labels": data["input_ids"], + "position_ids" : data["position_ids"] + }, + output_logits=True, + ) + + curr_log_probs = torch.log_softmax(curr_log_probs, dim=-1) + # curr_log_probs = torch.gather( + # input=curr_log_probs, + # dim=-1, + # index=data["input_ids"].to(torch.int64).unsqueeze(-1), + # ).squeeze(-1) + one_hot = cstorch.nn.functional.one_hot( + data["input_ids"].to(torch.int64), + num_classes=curr_log_probs.size(-1), + ) # .to(curr_log.probs.dtype) # (8,128,50257) + # Multiply and sum over vocab dimension + curr_log_probs = torch.sum(curr_log_probs * one_hot, dim=-1) # (8,128) + loss = self.loss_fn( + data["old_log_probs"], + curr_log_probs, + data["advantages"], + data["loss_mask"], + data["ref_log_probs"], + ) + return loss diff --git a/src/cerebras/modelzoo/models/nlp/rl/run.py b/src/cerebras/modelzoo/models/nlp/rl/run.py new file mode 100644 index 00000000..866a05d5 --- /dev/null +++ b/src/cerebras/modelzoo/models/nlp/rl/run.py @@ -0,0 +1,16 @@ +# isort: off +# MZ: import sys +# MZ: import os +# MZ: sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../..")) +# isort: on + +if __name__ == '__main__': + import warnings + from cerebras.modelzoo.common.run_utils import run + + warnings.warn( + "Running models using run.py is deprecated. Please switch to using the ModelZoo CLI. " + "See https://training-docs.cerebras.ai/model-zoo/cli-overview for more details." + ) + + run() diff --git a/src/cerebras/modelzoo/registry/registry.yaml b/src/cerebras/modelzoo/registry/registry.yaml index 0d13046e..9628d2ee 100644 --- a/src/cerebras/modelzoo/registry/registry.yaml +++ b/src/cerebras/modelzoo/registry/registry.yaml @@ -68,9 +68,15 @@ models: - cerebras.modelzoo.data.nlp.gpt.DummyIterableDataProcessor.DummyIterableDataProcessor - cerebras.modelzoo.data.nlp.gpt.GptHDF5DataProcessor.GptHDF5DataProcessor - cerebras.modelzoo.data.nlp.gpt.GptHDF5MapDataProcessor.GptHDF5MapDataProcessor + - cerebras.modelzoo.data.nlp.gpt.GptNpyRLDataProcessor.GptNpyRLDataProcessor - cerebras.modelzoo.data.nlp.gpt.HuggingFaceDataProcessorEli5.HuggingFaceDataProcessorEli5 - cerebras.modelzoo.data.nlp.gpt.HuggingFaceIterableDataProcessorEli5.HuggingFaceIterableDataProcessorEli5 +- name: RL + path: cerebras.modelzoo.models.nlp.rl.model.RLModel + data_processor_paths: &rl_data_processor_paths + - cerebras.modelzoo.data.nlp.gpt.GptNpyRLDataProcessor.GptNpyRLDataProcessor + - name: gpt3 path: cerebras.modelzoo.models.nlp.gpt3.model.GPT3Model data_processor_paths: *gpt2_data_processor_paths From 49c596fd5b183cd8f8c9f8eaf57c2891551c89cb Mon Sep 17 00:00:00 2001 From: Rahul Shrivastava Date: Wed, 8 Oct 2025 11:31:31 +0000 Subject: [PATCH 2/6] Checkpoint changes Signed-off-by: Rahul Shrivastava --- .../modelzoo/trainer/callbacks/model.py | 17 +++++++++++++---- .../modelzoo/trainer/callbacks/optimizer.py | 5 ++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/cerebras/modelzoo/trainer/callbacks/model.py b/src/cerebras/modelzoo/trainer/callbacks/model.py index 5aabebec..8da9c4a5 100644 --- a/src/cerebras/modelzoo/trainer/callbacks/model.py +++ b/src/cerebras/modelzoo/trainer/callbacks/model.py @@ -69,7 +69,10 @@ def on_before_backward(self, trainer, model, outputs): ) def on_save_checkpoint(self, trainer, state_dict): - state_dict["model"] = trainer.model.state_dict() + if hasattr(trainer.model, "policy_model"): + state_dict["model"] = trainer.model.policy_model.model.state_dict() + else: + state_dict["model"] = trainer.model.state_dict() def on_load_checkpoint(self, trainer, state_dict): if "model" not in state_dict: @@ -93,9 +96,15 @@ def on_load_checkpoint(self, trainer, state_dict): # This should be the case that is used for all checkpoints saved # post rel-2.0.0 else: - trainer.model.load_state_dict( - state_dict["model"], - strict=not trainer.checkpoint.disable_strict_checkpoint_loading, + if hasattr(trainer.model, "policy_model"): + trainer.model.policy_model.model.load_state_dict( + state_dict["model"], + strict=not trainer.checkpoint.disable_strict_checkpoint_loading + ) + else: + trainer.model.load_state_dict( + state_dict["model"], + strict=not trainer.checkpoint.disable_strict_checkpoint_loading, ) trainer.logger.info( diff --git a/src/cerebras/modelzoo/trainer/callbacks/optimizer.py b/src/cerebras/modelzoo/trainer/callbacks/optimizer.py index cb0afd63..c098c78f 100644 --- a/src/cerebras/modelzoo/trainer/callbacks/optimizer.py +++ b/src/cerebras/modelzoo/trainer/callbacks/optimizer.py @@ -52,7 +52,10 @@ def setup(self, trainer): elif isinstance(self.optimizer, cstorch.optim.Optimizer): trainer.optimizer = self.optimizer else: - trainer.optimizer = self.optimizer(trainer.model) + if hasattr(trainer.model, "policy_model"): + trainer.optimizer = self.optimizer(trainer.model.policy_model.model) + else: + trainer.optimizer = self.optimizer(trainer.model) def on_fit_start(self, trainer, train_dataloader, val_dataloader, loop): if trainer.optimizer is None: From 8bbf4b344e03e68e209dc2ecdc3556fc3aa4d5da Mon Sep 17 00:00:00 2001 From: rahul shrivastava Date: Mon, 13 Oct 2025 04:36:10 -0700 Subject: [PATCH 3/6] Add config params for loss parameters. Add policy model subclass as a config param. Signed-off-by: rahul shrivastava --- src/cerebras/modelzoo/losses/GRPOLoss.py | 27 +++++++++----- src/cerebras/modelzoo/models/nlp/rl/model.py | 38 ++++++++------------ 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/src/cerebras/modelzoo/losses/GRPOLoss.py b/src/cerebras/modelzoo/losses/GRPOLoss.py index 6d366d33..f42af387 100644 --- a/src/cerebras/modelzoo/losses/GRPOLoss.py +++ b/src/cerebras/modelzoo/losses/GRPOLoss.py @@ -1,9 +1,14 @@ import torch import torch.nn as nn - class GRPOLoss(nn.Module): - def __init__(self, eps_clip: float = 0.2, kl_loss_coeff: float = 0.1): + def __init__( + self, + clip_ratio: Optional[float] = None, + clip_ratio_low : float = 0.2, + clip_ratio_high : float = 0.28, + use_kl_loss : bool = True, + kl_loss_coef : float = 0.0) """ GRPO Loss with optional KL regularization. Args: @@ -11,9 +16,11 @@ def __init__(self, eps_clip: float = 0.2, kl_loss_coeff: float = 0.1): kl_loss_coeff (float): Coefficient for KL loss term. """ super().__init__() - self.eps_clip = eps_clip - self.kl_loss_coeff = kl_loss_coeff - self.epsilon = 1e-6 + self.epsilon = 1e-8 + self.use_kl_loss = use_kl_loss + self.kl_loss_coef = kl_loss_coef + self.clip_ratio_low = clip_ratio_low + self.clip_ratio_high = clip_ratio_high def forward( self, @@ -33,11 +40,15 @@ def forward( Returns: policy_loss (Tensor): Scalar GRPO loss. """ + + cliprange=self.clip_ratio + cliprange_low=config.clip_ratio_low if config.clip_ratio_low is not None else cliprange + cliprange_high=config.clip_ratio_high if config.clip_ratio_high is not None else cliprange sampling_ratio = torch.exp(curr_log_probs - old_log_probs) # Clipping for stability clipped_ratio = torch.clamp( - sampling_ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip + sampling_ratio, 1.0 - cliprange_low, 1.0 + cliprange_high) ) unclipped_loss = sampling_ratio * advantages @@ -51,12 +62,12 @@ def forward( ) # Optional KL divergence loss - if self.kl_loss_coeff > 0.0 and ref_log_probs is not None: + if self.use_kl_loss is True and self.kl_loss_coef > 0.0 and ref_log_probs is not None: kl = ref_log_probs - curr_log_probs ratio = torch.exp(kl) kld = ratio - kl - 1.0 kl_loss = torch.clamp(kld, min=-10.0, max=10.0) masked_kl_loss = (loss_mask * kl_loss).mean() - policy_loss += self.kl_loss_coeff * masked_kl_loss + policy_loss += self.kl_loss_coef * masked_kl_loss return policy_loss diff --git a/src/cerebras/modelzoo/models/nlp/rl/model.py b/src/cerebras/modelzoo/models/nlp/rl/model.py index 7c689493..1d9a5265 100644 --- a/src/cerebras/modelzoo/models/nlp/rl/model.py +++ b/src/cerebras/modelzoo/models/nlp/rl/model.py @@ -1,45 +1,35 @@ from copy import deepcopy -from typing import Literal +from typing import Literal, Optional, Union import torch import cerebras.pytorch as cstorch from cerebras.modelzoo.losses.GRPOLoss import GRPOLoss -from cerebras.modelzoo.models.nlp.llama.model import LlamaModel, LlamaModelConfig +from cerebras.modelzoo.models.nlp.llama.model import LlamaModelConfig +from cerebras.modelzoo.models.nlp.gpt2.model import GPT2ModelConfig +from cerebras.modelzoo.config import ModelConfig +from typing_extensions import Annotated -class RLModelConfig(LlamaModelConfig): +class RLModelConfig(ModelConfig): name: Literal["RL"] + policy_model : Annotated[Union[LlamaModelConfig, GPT2ModelConfig], Field(discriminator="name"),] = ... + use_kl_loss : Optional[bool] = True + kl_loss_coef : Optional[float] = 0.005 + clip_ratio : Optional[float] + clip_ratio_low : Optional[float] = 0.2 + clip_ratio_high : Optional[float] = 0.28 class RLModel(torch.nn.Module): def __init__(self, config: RLModelConfig): super().__init__() - self.policy_model = LlamaModel(config) + self.policy_model = config.policy_model() - #self.ref_model = deepcopy(self.policy_model) - #self.ref_model.eval() - - self.loss_fn = GRPOLoss() + self.loss_fn = GRPOLoss(config.clip_ratio, config.clip_ratio_low, config.clip_ratio_high, config.use_kl_loss, config.kl_loss_coef) def forward(self, data): - '''with torch.no_grad(): - _, old_log_probs = self.ref_model( - data={ - "input_ids": data["input_ids"], - "attention_mask": data["attention_mask"], - "labels": data["input_ids"], - }, - output_logits=True, - ) - old_log_probs = torch.log_softmax(old_log_probs, dim=-1) - old_log_probs = torch.gather( - input=old_log_probs, - dim=-1, - index=data["input_ids"].to(torch.int64).unsqueeze(-1), - ).squeeze(-1)''' - _, curr_log_probs = self.policy_model( data={ "input_ids": data["input_ids"], From 555e72872e0153057738746fe896293428a61ee4 Mon Sep 17 00:00:00 2001 From: Thomas Kidd Date: Thu, 23 Oct 2025 11:46:22 +0000 Subject: [PATCH 4/6] Fix typos Signed-off-by: Thomas Kidd --- src/cerebras/modelzoo/losses/GRPOLoss.py | 11 +++++++---- src/cerebras/modelzoo/models/nlp/rl/model.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/cerebras/modelzoo/losses/GRPOLoss.py b/src/cerebras/modelzoo/losses/GRPOLoss.py index f42af387..1116414a 100644 --- a/src/cerebras/modelzoo/losses/GRPOLoss.py +++ b/src/cerebras/modelzoo/losses/GRPOLoss.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from typing import Optional class GRPOLoss(nn.Module): def __init__( @@ -8,7 +9,8 @@ def __init__( clip_ratio_low : float = 0.2, clip_ratio_high : float = 0.28, use_kl_loss : bool = True, - kl_loss_coef : float = 0.0) + kl_loss_coef : float = 0.0 + ): """ GRPO Loss with optional KL regularization. Args: @@ -21,6 +23,7 @@ def __init__( self.kl_loss_coef = kl_loss_coef self.clip_ratio_low = clip_ratio_low self.clip_ratio_high = clip_ratio_high + self.clip_ratio = clip_ratio def forward( self, @@ -42,13 +45,13 @@ def forward( """ cliprange=self.clip_ratio - cliprange_low=config.clip_ratio_low if config.clip_ratio_low is not None else cliprange - cliprange_high=config.clip_ratio_high if config.clip_ratio_high is not None else cliprange + cliprange_low=self.clip_ratio_low if self.clip_ratio_low is not None else cliprange + cliprange_high=self.clip_ratio_high if self.clip_ratio_high is not None else cliprange sampling_ratio = torch.exp(curr_log_probs - old_log_probs) # Clipping for stability clipped_ratio = torch.clamp( - sampling_ratio, 1.0 - cliprange_low, 1.0 + cliprange_high) + sampling_ratio, 1.0 - cliprange_low, 1.0 + cliprange_high ) unclipped_loss = sampling_ratio * advantages diff --git a/src/cerebras/modelzoo/models/nlp/rl/model.py b/src/cerebras/modelzoo/models/nlp/rl/model.py index 1d9a5265..191e8c93 100644 --- a/src/cerebras/modelzoo/models/nlp/rl/model.py +++ b/src/cerebras/modelzoo/models/nlp/rl/model.py @@ -1,5 +1,6 @@ from copy import deepcopy from typing import Literal, Optional, Union +from pydantic import Field import torch @@ -17,7 +18,7 @@ class RLModelConfig(ModelConfig): use_kl_loss : Optional[bool] = True kl_loss_coef : Optional[float] = 0.005 - clip_ratio : Optional[float] + clip_ratio : Optional[float] = None clip_ratio_low : Optional[float] = 0.2 clip_ratio_high : Optional[float] = 0.28 From 233018c5cb9cef240154b38b9719313160f63145 Mon Sep 17 00:00:00 2001 From: Thomas Kidd Date: Fri, 24 Oct 2025 07:04:18 +0000 Subject: [PATCH 5/6] MZ changes for old log prob offloading --- .../data/nlp/gpt/GptNpyRLDataProcessor.py | 71 +++++++++++++------ src/cerebras/modelzoo/models/nlp/rl/model.py | 30 ++++++++ .../modelzoo/trainer/callbacks/__init__.py | 2 + .../trainer/callbacks/rl_log_probs.py | 34 +++++++++ 4 files changed, 116 insertions(+), 21 deletions(-) create mode 100644 src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py diff --git a/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py b/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py index b591fc7d..f21bb8c8 100644 --- a/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py +++ b/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py @@ -22,26 +22,44 @@ def __init__( try: with open(rollouts_path, 'rb') as f: samples = np.load(f) - self.advantages = samples['first'] - self.attention_mask = samples['second'].astype(np.int32) - self.input_ids = samples['third'].astype(np.int32) - self.responses = samples['fourth'].astype(np.int32) - self.old_log_probs = samples['fifth'] - self.position_ids = samples['sixth'].astype(np.int32) - self.ref_log_probs = samples['seventh'] - self.dataset_size = len(self.advantages) + + # For this to work, we build the .npz file exactly as we've built for training. + # TODO: Revisit this later, to make it more robust. + + self.advantages = samples['first'] if 'first' in samples else None + self.attention_mask = samples['second'].astype(np.int32) if 'second' in samples else None + self.input_ids = samples['third'].astype(np.int32) if 'third' in samples else None + self.responses = samples['fourth'].astype(np.int32) if 'fourth' in samples else None + self.old_log_probs = samples['fifth'] if 'fifth' in samples else None + self.position_ids = samples['sixth'].astype(np.int32) if 'sixth' in samples else None + self.ref_log_probs = samples['seventh'] if 'seventh' in samples else None + + # If True, we're in training mode; else we're calculating old log probs. + self._has_log_prob = self.old_log_probs is not None + + if not self._has_log_prob: + self.dataset_size = len(self.input_ids) + else: + self.dataset_size = len(self.advantages) + #self.prompt = samples['eighth'] + # These keys are common to both the modes i.e training as well as old-log-prob calculation. + for name in ("attention_mask","input_ids","responses","position_ids"): + if getattr(self, name) is None: + raise KeyError(f"Missing required key for '{name}' in {rollouts_path}") + self.loss_mask = np.zeros((self.dataset_size, self.MSL), dtype=np.float32) self.loss_mask[:, self.prompt_len:self.prompt_len + len(self.responses[0])] = 1.0 self.input_ids = self.pad_length(self.input_ids) self.attention_mask = self.pad_length(self.attention_mask) self.position_ids = self.pad_length(self.position_ids) - self.advantages = self.pad_in_bw(self.advantages) - self.responses = self.pad_in_bw(self.responses) - self.old_log_probs = self.pad_in_bw(self.old_log_probs) - self.ref_log_probs = self.pad_in_bw(self.ref_log_probs) + if self._has_log_prob: + self.advantages = self.pad_in_bw(self.advantages) + self.responses = self.pad_in_bw(self.responses) + self.old_log_probs = self.pad_in_bw(self.old_log_probs) + self.ref_log_probs = self.pad_in_bw(self.ref_log_probs) except Exception as e: raise RuntimeError(f"Failed to read : {rollouts_path}") from e @@ -60,15 +78,26 @@ def __len__(self): return self.dataset_size def __getitem__(self, idx): - data = { - "input_ids": self.input_ids[idx], - "attention_mask": self.attention_mask[idx], - "advantages": self.advantages[idx], - "loss_mask": self.loss_mask[idx], - "ref_log_probs": self.ref_log_probs[idx], - "old_log_probs": self.old_log_probs[idx], - "position_ids": self.position_ids[idx], - } + data = {} + + if self._has_log_prob: + data = { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "advantages": self.advantages[idx], + "loss_mask": self.loss_mask[idx], + "ref_log_probs": self.ref_log_probs[idx], + "old_log_probs": self.old_log_probs[idx], + "position_ids": self.position_ids[idx], + } + else: + data = { + "responses": self.responses[idx], + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "position_ids": self.position_ids[idx], + } + return data class GptNpyRLDataProcessorConfig(DataConfig): diff --git a/src/cerebras/modelzoo/models/nlp/rl/model.py b/src/cerebras/modelzoo/models/nlp/rl/model.py index 191e8c93..8dbfb856 100644 --- a/src/cerebras/modelzoo/models/nlp/rl/model.py +++ b/src/cerebras/modelzoo/models/nlp/rl/model.py @@ -3,6 +3,7 @@ from pydantic import Field import torch +import torch.nn.functional as F import cerebras.pytorch as cstorch from cerebras.modelzoo.losses.GRPOLoss import GRPOLoss @@ -11,6 +12,16 @@ from cerebras.modelzoo.config import ModelConfig from typing_extensions import Annotated +def logprobs_from_logits(logits: torch.Tensor, labels) -> torch.Tensor: + """ + Implementation taken from verL; modified to fix compile issues on our stack. + """ + logp = F.log_softmax(logits, dim=-1) + one_hot = cstorch.nn.functional.one_hot( + labels.to(torch.int64), num_classes=logp.size(-1) + ).to(logp.dtype) + return (logp * one_hot).sum(dim=-1) + class RLModelConfig(ModelConfig): name: Literal["RL"] @@ -31,6 +42,25 @@ def __init__(self, config: RLModelConfig): self.loss_fn = GRPOLoss(config.clip_ratio, config.clip_ratio_low, config.clip_ratio_high, config.use_kl_loss, config.kl_loss_coef) def forward(self, data): + if "old_log_probs" not in data: + _, logits = self.policy_model( + data={ + "input_ids": data["input_ids"], + "attention_mask": data["attention_mask"], + "labels": data["input_ids"], + "position_ids": data["position_ids"], + }, + output_logits=True, + ) + + # TODO: We don't divide by temperature here. veRL does something like logits.div_(temperature), + # where temperature is present in data.meta_info. + + response_length = data["responses"].size(-1) + logits = logits[:, -response_length - 1 : -1, :] # [batch_size, response_length, vocab_size] + old_log_probs = logprobs_from_logits(logits, data["responses"]) + return old_log_probs + _, curr_log_probs = self.policy_model( data={ "input_ids": data["input_ids"], diff --git a/src/cerebras/modelzoo/trainer/callbacks/__init__.py b/src/cerebras/modelzoo/trainer/callbacks/__init__.py index 00466009..0392f3b0 100644 --- a/src/cerebras/modelzoo/trainer/callbacks/__init__.py +++ b/src/cerebras/modelzoo/trainer/callbacks/__init__.py @@ -81,6 +81,7 @@ from .selective_grad import SelectiveGrad from .model_stats import CountParams from .notification import EmailNotification, SlackNotification +from .rl_log_probs import SaveOldLogProbs # isort: on @@ -145,4 +146,5 @@ "SlackNotification", "RunSchedule", "NaNChecker", + "SaveOldLogProbs", ] diff --git a/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py b/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py new file mode 100644 index 00000000..d6caa978 --- /dev/null +++ b/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py @@ -0,0 +1,34 @@ +import os +import numpy as np +import torch +from cerebras.modelzoo.trainer.callbacks import Callback + +class SaveOldLogProbs(Callback): + """ + Saves model outputs to out_dir as old_log_probs when in calc-old mode. + This saves the log-probs, per batch. + """ + + def __init__(self, prefix: str = "oldlp"): + self.out_dir = os.getcwd() + self.prefix = prefix + self._count = 0 + + def on_fit_start(self, trainer, train_dataloader, val_dataloader, loop): + os.makedirs(self.out_dir, exist_ok=True) + self._ready = True + + def on_after_forward(self, trainer, model, outputs, batch): + if batch is not None and isinstance(batch, dict) and ("old_log_probs" in batch): + # Training step -- can be ignored. + return + + if not isinstance(outputs, torch.Tensor): + # When we're calculating old log probs, we should get a tensor. + return + + # Save tensor as old_log_probs shard + arr = outputs.detach().cpu().numpy() + path = os.path.join(self.out_dir, f"{self.prefix}_{self._count:07d}.npz") + np.savez(path, old_log_probs=arr) + self._count += 1 From 384d168132ca3c3f9f533bd8189ffc6acb387447 Mon Sep 17 00:00:00 2001 From: Thomas Kidd Date: Wed, 12 Nov 2025 11:05:16 +0000 Subject: [PATCH 6/6] - Remove left padding and add right padding - Model level changes to account for different padding for different prompts Signed-off-by: Thomas Kidd --- .../data/nlp/gpt/GptNpyRLDataProcessor.py | 81 +++++++++++++++++-- src/cerebras/modelzoo/models/nlp/rl/model.py | 33 ++++++-- .../trainer/callbacks/rl_log_probs.py | 51 +++++++++--- 3 files changed, 142 insertions(+), 23 deletions(-) diff --git a/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py b/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py index f21bb8c8..ab0e113b 100644 --- a/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py +++ b/src/cerebras/modelzoo/data/nlp/gpt/GptNpyRLDataProcessor.py @@ -8,6 +8,9 @@ from cerebras.modelzoo.config import DataConfig from cerebras.modelzoo.config.types import ValidatedPath from cerebras.modelzoo.data.common.input_utils import is_distributed +import logging +from transformers import AutoTokenizer + class NpyRLDataset(torch.utils.data.Dataset): @@ -17,8 +20,10 @@ def __init__( ): super().__init__() rollouts_path = data_dir + "/rollouts.npz" + print("rollouts path = ", rollouts_path) self.MSL = 131072 self.prompt_len = 256 + try: with open(rollouts_path, 'rb') as f: samples = np.load(f) @@ -51,22 +56,83 @@ def __init__( self.loss_mask = np.zeros((self.dataset_size, self.MSL), dtype=np.float32) self.loss_mask[:, self.prompt_len:self.prompt_len + len(self.responses[0])] = 1.0 - self.input_ids = self.pad_length(self.input_ids) - self.attention_mask = self.pad_length(self.attention_mask) + self.input_ids, self.attention_mask, self.prompts_len = self.shift_ones_left(self.input_ids, self.attention_mask) self.position_ids = self.pad_length(self.position_ids) if self._has_log_prob: - self.advantages = self.pad_in_bw(self.advantages) self.responses = self.pad_in_bw(self.responses) + self.advantages = self.pad_in_bw(self.advantages) self.old_log_probs = self.pad_in_bw(self.old_log_probs) self.ref_log_probs = self.pad_in_bw(self.ref_log_probs) except Exception as e: raise RuntimeError(f"Failed to read : {rollouts_path}") from e + def shift_ones_left(self, input_ids: np.ndarray, attention_mask: np.ndarray): + B, MSL = input_ids.shape + new_input_ids = np.empty_like(input_ids) + new_attention_mask = np.zeros_like(attention_mask) + prompt_lens = np.zeros(B, dtype=np.int32) + + for i in range(B): + mask = attention_mask[i] + tokens = input_ids[i] + + ones_idx = np.where(mask == 1)[0] + first_one_idx = ones_idx[0] + prompt_lens[i] = self.prompt_len - first_one_idx + + # indices of active tokens (1's) + active = mask == 1 + valid_tokens = tokens[active] + inactive_tokens = tokens[~active] + valid_len = len(valid_tokens) + + # Put valid tokens first, followed by the others + new_input_ids[i] = np.concatenate([valid_tokens, inactive_tokens]) + new_attention_mask[i, :valid_len] = 1 + + return new_input_ids, new_attention_mask, prompt_lens + + def pad_input_right(self, input_ids, pad_token_id, attention_mask): + batch_size, seq_len = input_ids.shape + new_input_ids = np.full_like(input_ids, pad_token_id) + new_attention_mask = np.zeros_like(attention_mask) + + for i in range(batch_size): + valid_len = int(attention_mask[i].sum()) + + new_input_ids[i, :valid_len] = input_ids[i, -valid_len:] if valid_len > 0 else [] + new_attention_mask[i, :valid_len] = 1 if valid_len > 0 else 0 + + pad_len = self.MSL - 4096 + last_tokens = new_input_ids[:, -1:] + + pad = np.repeat(last_tokens, pad_len, axis=1) + a = np.concatenate([new_input_ids, pad], axis=1) + + last_tokens = new_attention_mask[:, -1:] + pad = np.repeat(last_tokens, pad_len, axis=1) + b = np.concatenate([new_attention_mask, pad], axis=1) + + return a, b + + + def pad_inputs(self, input_ids, pad_token_id, attention_mask): + batch_size, seq_len = input_ids.shape + + padded_inputs = np.full((batch_size, self.MSL), pad_token_id, dtype=input_ids.dtype) + + padded_inputs[:, :seq_len] = input_ids + + mask = attention_mask.astype(bool) + + padded_inputs[:, :seq_len][~mask] = pad_token_id + return padded_inputs + def pad_length(self, batch): return np.array([np.pad(seq, (0, self.MSL - len(seq)), mode='constant', constant_values=0) for seq in batch]) - + def pad_in_bw(self, batch): batch_size = batch.shape[0] insert_len = batch.shape[1] @@ -80,6 +146,7 @@ def __len__(self): def __getitem__(self, idx): data = {} + logging.info(f"Rahul old log probs = {self._has_log_prob}") if self._has_log_prob: data = { "input_ids": self.input_ids[idx], @@ -95,8 +162,12 @@ def __getitem__(self, idx): "responses": self.responses[idx], "input_ids": self.input_ids[idx], "attention_mask": self.attention_mask[idx], - "position_ids": self.position_ids[idx], + "prompts_len": self.prompts_len[idx], + #"position_ids": self.position_ids[idx], } + #if idx < 10: + # for k,v in data.items(): + # logging.info(f"{k}: shape{v.shape}, first 20:{v[:20]}") return data diff --git a/src/cerebras/modelzoo/models/nlp/rl/model.py b/src/cerebras/modelzoo/models/nlp/rl/model.py index 8dbfb856..cf4eaf76 100644 --- a/src/cerebras/modelzoo/models/nlp/rl/model.py +++ b/src/cerebras/modelzoo/models/nlp/rl/model.py @@ -11,15 +11,16 @@ from cerebras.modelzoo.models.nlp.gpt2.model import GPT2ModelConfig from cerebras.modelzoo.config import ModelConfig from typing_extensions import Annotated +import logging def logprobs_from_logits(logits: torch.Tensor, labels) -> torch.Tensor: """ Implementation taken from verL; modified to fix compile issues on our stack. """ - logp = F.log_softmax(logits, dim=-1) + logp = F.log_softmax(logits, dim=-1) # batch_Size x 3840 X vocab_size one_hot = cstorch.nn.functional.one_hot( labels.to(torch.int64), num_classes=logp.size(-1) - ).to(logp.dtype) + ).to(logp.dtype) # batch_size x 3840 x vocab_size return (logp * one_hot).sum(dim=-1) @@ -43,24 +44,42 @@ def __init__(self, config: RLModelConfig): def forward(self, data): if "old_log_probs" not in data: + logging.info("Rahul inside old log probs") _, logits = self.policy_model( data={ "input_ids": data["input_ids"], "attention_mask": data["attention_mask"], "labels": data["input_ids"], - "position_ids": data["position_ids"], + #"position_ids": data["position_ids"], }, output_logits=True, ) # TODO: We don't divide by temperature here. veRL does something like logits.div_(temperature), # where temperature is present in data.meta_info. - + logging.info(f"rahul logits shape is {logits.shape}") response_length = data["responses"].size(-1) - logits = logits[:, -response_length - 1 : -1, :] # [batch_size, response_length, vocab_size] - old_log_probs = logprobs_from_logits(logits, data["responses"]) - return old_log_probs + logging.info(f"rahul response len is {response_length}") + + B, MSL, V = logits.shape + + # Create batch indices [0, 1, ..., B-1] + batch_idx = torch.arange(B, dtype=torch.int, device=logits.device).unsqueeze(1) # [B, 1] + # For each batch, compute the token indices to extract + token_idx = ( + torch.arange(response_length, dtype=torch.int, device=logits.device).unsqueeze(0) # [1, 3840] + + (data["prompts_len"].unsqueeze(1) - 1) # shift start position per batch + ) # shape [B, 3840] + # Gather logits using advanced indexing + selected_logits = logits[batch_idx, token_idx, :] # [B, 3840, vocab_size] + old_log_probs = logprobs_from_logits(selected_logits, data["responses"]) + return {"old_log_probs": old_log_probs, "input_ids": data["input_ids"], "attention_mask" : data["attention_mask"], "responses" : data["responses"], "prompts_len" : data["prompts_len"]} + + #logits = logits[:, 165:(165+response_length), :] # [batch_size, response_length, vocab_size] + #old_log_probs = logprobs_from_logits(logits, data["responses"]) + #return {"logits": logits.to(torch.float32), "input_ids": data["input_ids"], "attention_mask" : data["attention_mask"], "responses" : data["responses"], "prompts_len" : data["prompts_len"]} + logging.info("Rahul executing training") _, curr_log_probs = self.policy_model( data={ "input_ids": data["input_ids"], diff --git a/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py b/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py index d6caa978..32800938 100644 --- a/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py +++ b/src/cerebras/modelzoo/trainer/callbacks/rl_log_probs.py @@ -1,7 +1,20 @@ import os import numpy as np -import torch +import cerebras.pytorch as cstorch from cerebras.modelzoo.trainer.callbacks import Callback +import torch.nn.functional as F +import torch + +def logprobs_from_logits(logits: torch.Tensor, labels) -> torch.Tensor: + """ + Implementation taken from verL; modified to fix compile issues on our stack. + """ + logp = F.log_softmax(logits, dim=-1) # batch_Size x 3840 X vocab_size + one_hot = cstorch.nn.functional.one_hot( + labels.to(torch.int64), num_classes=logp.size(-1) + ).to(logp.dtype) # batch_size x 3840 x vocab_size + return (logp * one_hot).sum(dim=-1) + class SaveOldLogProbs(Callback): """ @@ -10,7 +23,7 @@ class SaveOldLogProbs(Callback): """ def __init__(self, prefix: str = "oldlp"): - self.out_dir = os.getcwd() + self.out_dir = "/n0/lab/sota-rl-inference/eval_rollouts" self.prefix = prefix self._count = 0 @@ -19,16 +32,32 @@ def on_fit_start(self, trainer, train_dataloader, val_dataloader, loop): self._ready = True def on_after_forward(self, trainer, model, outputs, batch): - if batch is not None and isinstance(batch, dict) and ("old_log_probs" in batch): - # Training step -- can be ignored. - return - - if not isinstance(outputs, torch.Tensor): - # When we're calculating old log probs, we should get a tensor. - return + self.post_process(outputs) + @cstorch.step_closure + def post_process(self, outputs): # Save tensor as old_log_probs shard - arr = outputs.detach().cpu().numpy() + #arr = outputs['output'].numpy() path = os.path.join(self.out_dir, f"{self.prefix}_{self._count:07d}.npz") - np.savez(path, old_log_probs=arr) + os.makedirs(os.path.dirname(path), exist_ok=True) + + #B, MSL, V = outputs['logits'].shape + #response_length = 3840 + + #batch_idx = torch.arange(B, dtype=torch.int).unsqueeze(1) # [B, 1] + + # For each batch, compute the token indices to extract + '''token_idx = ( + torch.arange(response_length, dtype=torch.int).unsqueeze(0) # [1, 3840] + + (outputs["prompts_len"].unsqueeze(1) - 1) # shift start position per batch + ) ''' # shape [B, 3840] + + # Gather logits using advanced indexing + #selected_logits = outputs['logits'][batch_idx, token_idx, :] # [B, 3840, vocab_size] + + #logits = logits[:, prompt_len:(prompt_len+response_length), :] # [batch_size, response_length, vocab_size] + #old_log_probs = logprobs_from_logits(selected_logits, outputs["responses"]) + #return {"old_log_probs": old_log_probs, "input_ids": data["input_ids"], "attention_mask" : data["attention_mask"], "responses" : data["responses"], "prompts_len" : data["prompts_len"]} + + np.savez(path, old_log_probs=outputs['old_log_probs'].numpy(), inputs=outputs['input_ids'].numpy(), mask=outputs['attention_mask'].numpy(), responses=outputs['responses'], promptlen=outputs['prompts_len']) self._count += 1