diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 952f65a7c7..6c83851076 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -18,10 +18,8 @@ import torch.distributed as dist from megatron.bridge.training.state import GlobalState from megatron.core.models.gpt import GPTModel -from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( get_context_parallel_group, - get_context_parallel_rank, get_context_parallel_world_size, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, @@ -31,11 +29,8 @@ get_moe_layer_wise_logging_tracker, reduce_aux_losses_tracker_across_ranks, ) -from megatron.training.utils import get_ltor_masks_and_position_ids - from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank def _round_up_to_multiple(value: int, multiple: int) -> int: @@ -45,312 +40,6 @@ def _round_up_to_multiple(value: int, multiple: int) -> int: else value ) - -def _pack_sequences_for_megatron( - input_ids: torch.Tensor, - seq_lengths: torch.Tensor, - pad_individual_seqs_to_multiple_of: int = 1, - pad_packed_seq_to_multiple_of: int = 1, - pad_packed_seq_to: Optional[int] = None, - cp_rank: int = 0, - cp_size: int = 1, -) -> tuple[torch.Tensor, PackedSeqParams, torch.Tensor, Optional[torch.Tensor]]: - """Pack sequences for Megatron model processing with optional context parallelism. - - Args: - input_ids: Input token IDs [batch_size, seq_length] - seq_lengths: Actual sequence lengths for each sample [batch_size] - pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value - pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value - pad_packed_seq_to: Pad packed sequences to this value (before CP) - - The three parameters above can be calculated using _get_pack_sequence_parameters_for_megatron, we do not recommend users to set these parameters manually. 
- cp_size: Context parallelism size - - Returns: - Tuple of: - - packed_input_ids: Packed input tensor [1, T] - - input_ids_cp_sharded: Sharded input tensor [cp_size, T // cp_size] - - packed_seq_params: PackedSeqParams object - - cu_seqlens: Cumulative sequence lengths - - cu_seqlens_padded: Padded cumulative sequence lengths - """ - batch_size = input_ids.shape[0] - - # Build cumulative sequence lengths (cu_seqlens) and extract valid tokens - needs_padding = ( - pad_individual_seqs_to_multiple_of > 1 - or pad_packed_seq_to_multiple_of > 1 - or pad_packed_seq_to is not None - ) - - cu_seqlens = [0] - cu_seqlens_padded = [0] if needs_padding else None - valid_tokens = [] - - # Round up the pad_packed_seq_to to the nearest multiple of pad_packed_seq_to_multiple_of - if pad_packed_seq_to is not None: - pad_packed_seq_to = _round_up_to_multiple( - pad_packed_seq_to, pad_packed_seq_to_multiple_of - ) - - pad_factor = pad_individual_seqs_to_multiple_of - - for b in range(batch_size): - seq_len = ( - seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] - ) - - # Extract valid tokens for this sequence - valid_tokens.append(input_ids[b, :seq_len]) - - # Update cumulative sequence lengths - cu_seqlens.append(cu_seqlens[-1] + seq_len) - - # For context parallelism, track padded sequence lengths - if needs_padding: - # Pad sequence length to multiple of (cp_size * 2) - padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) - cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_seq_len) - - # Convert to tensors - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=input_ids.device) - if needs_padding: - cu_seqlens_padded = torch.tensor( - cu_seqlens_padded, dtype=torch.int32, device=input_ids.device - ) - if pad_packed_seq_to is not None: - cu_seqlens_padded[-1] = pad_packed_seq_to - elif pad_packed_seq_to_multiple_of > 1: - cu_seqlens_padded[-1] = _round_up_to_multiple( - cu_seqlens_padded[-1], pad_packed_seq_to_multiple_of - ) - - # Calculate max sequence length (padded if using CP) - if needs_padding: - seq_lens_padded = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] - max_seqlen = seq_lens_padded.max().item() - else: - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = seq_lens.max().item() - - # Concatenate all valid tokens - # If using individual padding, we need to pad individual sequences - # CP will always need padding (of at least cp_size * 2) - running_seq_len = 0 - if pad_factor > 1: - all_input_ids = [] - padded_tokens = [] - for b in range(batch_size): - seq_len = ( - seq_lengths[b].item() - if torch.is_tensor(seq_lengths[b]) - else seq_lengths[b] - ) - # if last element, pad to the max sequence length - if b == batch_size - 1 and needs_padding: - if pad_packed_seq_to is not None: - padded_seq_len = pad_packed_seq_to - running_seq_len - elif pad_packed_seq_to_multiple_of > 1: - padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) - padded_seq_len = ( - _round_up_to_multiple( - running_seq_len + padded_seq_len, - pad_packed_seq_to_multiple_of, - ) - - running_seq_len - ) - else: - padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) - else: - padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) - - running_seq_len += padded_seq_len - - # Pad this sequence to the required length - seq_tokens = input_ids[b, :seq_len] - if padded_seq_len > seq_len: - # Pad with zeros (or use a padding token if available) - seq_tokens = torch.nn.functional.pad( - seq_tokens, (0, padded_seq_len - seq_len), value=0 - ) - 
all_input_ids.append(seq_tokens) - - if cp_size > 1: - seq_tokens = _get_tokens_on_this_cp_rank( - seq_tokens, cp_rank, cp_size, seq_dim=0 - ) - - padded_tokens.append(seq_tokens) - - # Concatenate all padded tokens - # For 'thd' format, the shape should be [1, T] where T is total tokens - packed_input_ids = torch.cat(padded_tokens, dim=0).unsqueeze(0) - all_input_ids = torch.cat(all_input_ids, dim=0).unsqueeze(0) - else: - # No individual padding, just concatenate valid tokens - # For 'thd' format, the shape should be [1, T] where T is total tokens - packed_input_ids = torch.cat(valid_tokens, dim=0).unsqueeze(0) - all_input_ids = packed_input_ids - if needs_padding: - if pad_packed_seq_to is not None: - pad_len = pad_packed_seq_to - packed_input_ids.shape[1] - elif pad_packed_seq_to_multiple_of > 1: - current_seq_len = packed_input_ids.shape[1] - pad_this_seq_to = _round_up_to_multiple( - current_seq_len, pad_packed_seq_to_multiple_of - ) - pad_len = pad_this_seq_to - current_seq_len - else: - pad_len = 0 - if pad_len > 0: - packed_input_ids = torch.nn.functional.pad( - packed_input_ids, (0, pad_len), value=0 - ) - all_input_ids = torch.nn.functional.pad( - all_input_ids, (0, pad_len), value=0 - ) - - if cu_seqlens_padded is None: - cu_seqlens_padded = cu_seqlens.clone() - - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens_padded, - cu_seqlens_kv=cu_seqlens_padded, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - max_seqlen_q=int(max_seqlen), - max_seqlen_kv=int(max_seqlen), - qkv_format="thd", - ) - - return ( - all_input_ids.contiguous(), - packed_input_ids.contiguous(), - packed_seq_params, - cu_seqlens, - cu_seqlens_padded, - ) - - -def _get_pack_sequence_parameters_for_megatron( - megatron_cfg: dict, - max_seq_len_in_batch: int, -): - """Get pack sequence parameters for Megatron model processing with optional context parallelism. - - Args: - megatron_cfg: Megatron configuration - max_seq_len_in_batch: Maximum sequence length in batch - - Returns: - Tuple of: - - pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value - - pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value - - pad_packed_seq_to: Pad packed sequences to this value (before CP) - """ - tp_size = megatron_cfg["tensor_model_parallel_size"] - sp = megatron_cfg["sequence_parallel"] - pp_size = megatron_cfg["pipeline_model_parallel_size"] - cp_size = megatron_cfg["context_parallel_size"] - fp8_cfg = megatron_cfg.get("fp8_cfg", None) or {} - use_fp8 = fp8_cfg.get("enabled", False) - use_blockwise_fp8 = fp8_cfg.get("fp8_recipe", None) == "blockwise" - - # individual sequence needs to be splitted to CP domain, and to TP domain when SP is enabled. - pad_individual_seqs_to_multiple_of = 1 - if cp_size > 1: - pad_individual_seqs_to_multiple_of *= cp_size * 2 - if tp_size > 1 and sp: - pad_individual_seqs_to_multiple_of *= tp_size - - # packed sequence length, after splitted to TP and CP domains, needs to be divisible by 128 if using blockwise FP8, and divisible by 16 if using other FP8 recipes. - if use_fp8: - divisor = 128 if use_blockwise_fp8 else 16 - pad_packed_seq_to_multiple_of = divisor - if cp_size > 1: - pad_packed_seq_to_multiple_of *= cp_size * 2 - if tp_size > 1 and sp: - pad_packed_seq_to_multiple_of *= tp_size - else: - pad_packed_seq_to_multiple_of = 1 - - # when PP is used, all sequences must have the same length, so we need to pad the packed sequence to the max sequence length in the batch. 
- if pp_size > 1: - pad_packed_seq_to = max_seq_len_in_batch - else: - pad_packed_seq_to = None - - return ( - pad_individual_seqs_to_multiple_of, - pad_packed_seq_to_multiple_of, - pad_packed_seq_to, - ) - - -def _unpack_sequences_from_megatron( - output_tensor: torch.Tensor, - seq_lengths: torch.Tensor, - cu_seqlens: torch.Tensor, - cu_seqlens_padded: Optional[torch.Tensor], - original_batch_size: int, - original_seq_length: int, -) -> torch.Tensor: - """Unpack sequences from Megatron output format. - - Args: - output_tensor: Packed output tensor [1, T, vocab_size] - seq_lengths: Actual sequence lengths for each sample - cu_seqlens: Cumulative sequence lengths - cu_seqlens_padded: Padded cumulative sequence lengths (if CP was used) - original_batch_size: Original batch size - original_seq_length: Original maximum sequence length - - Returns: - Unpacked output tensor [batch_size, seq_length, vocab_size] - """ - # Remove the batch dimension to get [T, vocab_size] - output_tensor = output_tensor.squeeze(0) - - # Create a padded output tensor with original shape - vocab_size = output_tensor.shape[-1] - unpacked_output = torch.zeros( - (original_batch_size, original_seq_length, vocab_size), - dtype=output_tensor.dtype, - device=output_tensor.device, - ) - - # Get context parallel size to determine which cu_seqlens to use - cp_size = get_context_parallel_world_size() - - # Fill in the unpacked output tensor with valid tokens - for b in range(original_batch_size): - # Get actual sequence length for this sample - seq_len = ( - seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] - ) - - if cp_size > 1 and cu_seqlens_padded is not None: - # When using CP, we need to account for padding - # Calculate the padded sequence boundaries - pad_factor = cp_size * 2 - padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor - start_idx = cu_seqlens_padded[b].item() - - # Only copy the valid tokens (not the padding) - unpacked_output[b, :seq_len] = output_tensor[ - start_idx : start_idx + seq_len - ] - else: - # No CP, use regular cu_seqlens - start_idx = cu_seqlens[b].item() - end_idx = cu_seqlens[b + 1].item() - - # Copy the valid tokens to the unpacked tensor - unpacked_output[b, :seq_len] = output_tensor[start_idx:end_idx] - - return unpacked_output - - def forward_step_arbitrary_loss( state: GlobalState, global_valid_seqs: torch.Tensor, @@ -359,10 +48,6 @@ def forward_step_arbitrary_loss( model: GPTModel, loss_fn: LossFunction, pack_sequences: bool = False, - seq_length_key: Optional[str] = None, - pad_individual_seqs_to_multiple_of: int = 1, - pad_packed_seq_to_multiple_of: int = 1, - pad_full_seq_to: Optional[int] = None, defer_fp32_logits: Optional[bool] = None, cp_normalize: bool = True, policy_cfg: Optional[dict] = None, @@ -377,9 +62,6 @@ def forward_step_arbitrary_loss( model (GPTModel): The GPT Model loss_fn (LossFunction): Loss function to apply pack_sequences (bool): Whether to pack sequences for efficiency - seq_length_key (Optional[str]): Key in data_dict containing actual sequence lengths - pad_individual_seqs_to_multiple_of (int): Pad individual sequences to a multiple of this value - pad_full_seq_to (Optional[int]): Pad packed sequences to this value defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32 cp_normalize (bool): Whether to normalize the loss by the cp_size policy_cfg (Optional[dict]): Policy configuration containing generation parameters @@ -393,63 +75,17 @@ def forward_step_arbitrary_loss( """ 
straggler_timer = state.straggler_timer - with straggler_timer(bdata=True): - data_dict = next(data_iterator).to("cuda") - input_ids = data_dict["input_ids"] - attention_mask = None - position_ids = None - packed_seq_params = None - - original_batch_size = input_ids.shape[0] - original_seq_length = input_ids.shape[1] - seq_lengths = None # Will be set if using packed sequences - cu_seqlens = None - cu_seqlens_padded = None - - if pack_sequences: - # For packed sequences with padded input, we need sequence lengths - assert seq_length_key is not None, ( - "seq_length_key must be provided for packed sequences" - ) - assert seq_length_key in data_dict, ( - f"{seq_length_key} not found in data_dict" - ) - - # Get sequence lengths and context parallel size - seq_lengths = data_dict[seq_length_key] - - # Pack sequences - ( - input_ids, - input_ids_cp_sharded, - packed_seq_params, - cu_seqlens, - cu_seqlens_padded, - ) = _pack_sequences_for_megatron( - input_ids, - seq_lengths, - pad_individual_seqs_to_multiple_of, - pad_packed_seq_to_multiple_of, - pad_full_seq_to, - cp_rank=get_context_parallel_rank(), - cp_size=get_context_parallel_world_size(), - ) - - # For packed sequences, position_ids and attention_mask are typically None - # The PackedSeqParams handles all necessary sequence information - position_ids = None - attention_mask = None - else: - input_ids_cp_sharded = input_ids - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - data=input_ids, - eod_token=0, # used for loss_mask, which we don't use - pad_token=0, # used for loss_mask, which we don't use - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=False, - pad_mask_loss=False, - ) + # Get the pre-processed microbatch from the iterator + processed_mb = next(data_iterator) + + # Extract the processed components + data_dict = processed_mb.data_dict + input_ids = processed_mb.input_ids + input_ids_cp_sharded = processed_mb.input_ids_cp_sharded + attention_mask = processed_mb.attention_mask + position_ids = processed_mb.position_ids + packed_seq_params = processed_mb.packed_seq_params + cu_seqlens_padded = processed_mb.cu_seqlens_padded multimodal_data = data_dict.get_multimodal_dict( as_tensors=True, device=input_ids_cp_sharded.device diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py new file mode 100644 index 0000000000..3537cb9b12 --- /dev/null +++ b/nemo_rl/models/megatron/data.py @@ -0,0 +1,603 @@ +from dataclasses import dataclass +from typing import Any, Iterator, Optional, Tuple + +import torch + +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_context_parallel_rank, + get_context_parallel_world_size, +) +from megatron.training.utils import get_ltor_masks_and_position_ids +from nemo_rl.models.megatron.common import _round_up_to_multiple +from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank + + +@dataclass +class ProcessedMicrobatch: + """Container for a processed microbatch ready for model forward pass. + + This dataclass holds both the original data dictionary and the processed + tensors needed for the Megatron model forward pass. 
+ + Attributes: + data_dict: The original BatchedDataDict containing raw batch data + input_ids: Processed input token IDs (may be packed for sequence packing) + input_ids_cp_sharded: Context-parallel sharded input token IDs + attention_mask: Attention mask tensor (None for packed sequences) + position_ids: Position IDs tensor (None for packed sequences) + packed_seq_params: PackedSeqParams for sequence packing (None if not packing) + cu_seqlens_padded: Padded cumulative sequence lengths (None if not packing) + """ + data_dict: BatchedDataDict[Any] + input_ids: torch.Tensor + input_ids_cp_sharded: torch.Tensor + attention_mask: Optional[torch.Tensor] + position_ids: Optional[torch.Tensor] + packed_seq_params: Optional[PackedSeqParams] + cu_seqlens_padded: Optional[torch.Tensor] + + +def make_processed_microbatch_iterator( + raw_iterator: Iterator[BatchedDataDict[Any]], + cfg: dict[str, Any], + seq_length_key: Optional[str], + pad_individual_seqs_to_multiple_of: int, + pad_packed_seq_to_multiple_of: int, + pad_full_seq_to: Optional[int], +) -> Iterator[ProcessedMicrobatch]: + """Wrap a raw microbatch iterator to yield processed microbatches. + + This function takes a raw iterator that yields BatchedDataDict objects and + wraps it to yield ProcessedMicrobatch objects that contain both the original + data and the processed tensors ready for model forward pass. + + Args: + raw_iterator: Iterator yielding raw BatchedDataDict microbatches + cfg: Configuration dictionary containing sequence_packing settings + seq_length_key: Key for sequence length in data dict (required for packing) + pad_individual_seqs_to_multiple_of: Padding multiple for individual sequences + pad_packed_seq_to_multiple_of: Padding multiple for packed sequences + pad_full_seq_to: Target length for full sequence padding (optional) + + Yields: + ProcessedMicrobatch objects containing processed tensors ready for model forward + """ + pack_sequences = cfg["sequence_packing"]["enabled"] + + for data_dict in raw_iterator: + # Move to GPU + data_dict = data_dict.to("cuda") + + # Process the microbatch + ( + input_ids, + input_ids_cp_sharded, + attention_mask, + position_ids, + packed_seq_params, + cu_seqlens_padded, + ) = process_microbatch( + data_dict=data_dict, + seq_length_key=seq_length_key, + pad_individual_seqs_to_multiple_of=pad_individual_seqs_to_multiple_of, + pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of, + pad_full_seq_to=pad_full_seq_to, + pack_sequences=pack_sequences, + ) + + yield ProcessedMicrobatch( + data_dict=data_dict, + input_ids=input_ids, + input_ids_cp_sharded=input_ids_cp_sharded, + attention_mask=attention_mask, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + cu_seqlens_padded=cu_seqlens_padded, + ) + + +def get_microbatch_iterator( + data: BatchedDataDict[Any], + cfg: dict[str, Any], + mbs: int, + seq_length_key: Optional[str] = None, +) -> Tuple[Iterator[ProcessedMicrobatch], int, int, int, int]: + """Create a processed microbatch iterator from a batch of data. + + This function creates an iterator that yields ProcessedMicrobatch objects, + which contain both the original data dictionary and the processed tensors + ready for model forward pass. 
+ + Args: + data: The batch data to create microbatches from + cfg: Configuration dictionary + mbs: Microbatch size + seq_length_key: Key for sequence lengths in data dict (auto-detected if None) + + Returns: + Tuple containing the iterator and metadata + - iterator: Iterator yielding ProcessedMicrobatch objects + - data_iterator_len: Number of microbatches in the iterator + - micro_batch_size: Size of each microbatch + - seq_dim_size: Sequence length dimension size + - padded_seq_length: Padded sequence length for pipeline parallelism (may differ from seq_length) + """ + micro_batch_size = mbs + pad_factor = 1 + pad_full_seq_to = None + pad_packed_seq_to_multiple_of = 1 + + # Auto-detect seq_length_key if not provided + if seq_length_key is None and cfg["sequence_packing"]["enabled"]: + seq_length_key = "input_lengths" + + if cfg["dynamic_batching"]["enabled"]: + raw_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + data_iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif cfg["sequence_packing"]["enabled"]: + raw_iterator = data.make_microbatch_iterator_for_packable_sequences() + data_iterator_len, pack_seq_dim_size = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + ( + pad_factor, + pad_packed_seq_to_multiple_of, + pad_full_seq_to, + ) = _get_pack_sequence_parameters_for_megatron( + cfg["megatron_cfg"], + pack_seq_dim_size, + ) + micro_batch_size = 1 + else: + raw_iterator = data.make_microbatch_iterator(mbs) + data_iterator_len = data.size // mbs + + _, seq_dim_size = check_sequence_dim(data) + + # Wrap the raw iterator with processing + processed_iterator = make_processed_microbatch_iterator( + raw_iterator=raw_iterator, + cfg=cfg, + seq_length_key=seq_length_key, + pad_individual_seqs_to_multiple_of=pad_factor, + pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of, + pad_full_seq_to=pad_full_seq_to, + ) + + # Compute padded sequence length for pipeline parallelism + padded_seq_length = pad_full_seq_to if pad_full_seq_to is not None else seq_dim_size + + return ( + processed_iterator, + data_iterator_len, + micro_batch_size, + seq_dim_size, + padded_seq_length, + ) + +def process_microbatch( + data_dict: BatchedDataDict[Any], + seq_length_key: Optional[str] = None, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to_multiple_of: int = 1, + pad_full_seq_to: Optional[int] = None, + pack_sequences: bool = False, +): + #with straggler_timer(bdata=True): + input_ids = data_dict["input_ids"] + attention_mask = None + position_ids = None + packed_seq_params = None + + original_batch_size = input_ids.shape[0] + original_seq_length = input_ids.shape[1] + seq_lengths = None # Will be set if using packed sequences + cu_seqlens = None + cu_seqlens_padded = None + + if pack_sequences: + # For packed sequences with padded input, we need sequence lengths + assert seq_length_key is not None, ( + "seq_length_key must be provided for packed sequences" + ) + assert seq_length_key in data_dict, ( + f"{seq_length_key} not found in data_dict" + ) + + # Get sequence lengths and context parallel size + seq_lengths = data_dict[seq_length_key] + + # Pack sequences + ( + input_ids, + input_ids_cp_sharded, + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of, + pad_packed_seq_to_multiple_of, + pad_full_seq_to, + cp_rank=get_context_parallel_rank(), + cp_size=get_context_parallel_world_size(), + ) + + # For packed sequences, 
position_ids and attention_mask are typically None + # The PackedSeqParams handles all necessary sequence information + position_ids = None + attention_mask = None + else: + input_ids_cp_sharded = input_ids + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + data=input_ids, + eod_token=0, # used for loss_mask, which we don't use + pad_token=0, # used for loss_mask, which we don't use + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + pad_mask_loss=False, + ) + return ( + input_ids, + input_ids_cp_sharded, + attention_mask, + position_ids, + packed_seq_params, + cu_seqlens_padded, + ) + +def process_global_batch( + data: BatchedDataDict[Any], + batch_idx: int, + batch_size: int, + loss_fn: LossFunction, + dp_group: torch.distributed.ProcessGroup, +) -> dict[str, Any]: + batch = data.get_batch(batch_idx=batch_idx, batch_size=batch_size) + + assert "sample_mask" in batch, "sample_mask must be present in the data!" + + # Get the normalization factor for the loss + local_valid_seqs = torch.sum(batch["sample_mask"]) + + if "token_mask" not in batch: + local_valid_toks = local_valid_seqs * batch["input_ids"].shape[1] + else: + local_valid_toks = torch.sum( + batch["token_mask"][:, 1:] * batch["sample_mask"].unsqueeze(-1) + ) + + to_reduce = torch.tensor([local_valid_seqs, local_valid_toks]).cuda() + torch.distributed.all_reduce(to_reduce, group=dp_group) + global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1] + + if hasattr(loss_fn, "loss_type") and loss_fn.loss_type == LossType.TOKEN_LEVEL: + assert "token_mask" in batch, ( + "token_mask must be present in the data when using token-level loss" + ) + + return ( + batch, + global_valid_seqs, + global_valid_toks, + ) + +def _pack_sequences_for_megatron( + input_ids: torch.Tensor, + seq_lengths: torch.Tensor, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to_multiple_of: int = 1, + pad_packed_seq_to: Optional[int] = None, + cp_rank: int = 0, + cp_size: int = 1, +) -> tuple[torch.Tensor, PackedSeqParams, torch.Tensor, Optional[torch.Tensor]]: + """Pack sequences for Megatron model processing with optional context parallelism. + + Args: + input_ids: Input token IDs [batch_size, seq_length] + seq_lengths: Actual sequence lengths for each sample [batch_size] + pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value + pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value + pad_packed_seq_to: Pad packed sequences to this value (before CP) + - The three parameters above can be calculated using _get_pack_sequence_parameters_for_megatron, we do not recommend users to set these parameters manually. 
+ cp_size: Context parallelism size + + Returns: + Tuple of: + - packed_input_ids: Packed input tensor [1, T] + - input_ids_cp_sharded: Sharded input tensor [cp_size, T // cp_size] + - packed_seq_params: PackedSeqParams object + - cu_seqlens: Cumulative sequence lengths + - cu_seqlens_padded: Padded cumulative sequence lengths + """ + batch_size = input_ids.shape[0] + + # Build cumulative sequence lengths (cu_seqlens) and extract valid tokens + needs_padding = ( + pad_individual_seqs_to_multiple_of > 1 + or pad_packed_seq_to_multiple_of > 1 + or pad_packed_seq_to is not None + ) + + cu_seqlens = [0] + cu_seqlens_padded = [0] if needs_padding else None + valid_tokens = [] + + # Round up the pad_packed_seq_to to the nearest multiple of pad_packed_seq_to_multiple_of + if pad_packed_seq_to is not None: + pad_packed_seq_to = _round_up_to_multiple( + pad_packed_seq_to, pad_packed_seq_to_multiple_of + ) + + pad_factor = pad_individual_seqs_to_multiple_of + + for b in range(batch_size): + seq_len = ( + seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] + ) + + # Extract valid tokens for this sequence + valid_tokens.append(input_ids[b, :seq_len]) + + # Update cumulative sequence lengths + cu_seqlens.append(cu_seqlens[-1] + seq_len) + + # For context parallelism, track padded sequence lengths + if needs_padding: + # Pad sequence length to multiple of (cp_size * 2) + padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) + cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_seq_len) + + # Convert to tensors + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=input_ids.device) + if needs_padding: + cu_seqlens_padded = torch.tensor( + cu_seqlens_padded, dtype=torch.int32, device=input_ids.device + ) + if pad_packed_seq_to is not None: + cu_seqlens_padded[-1] = pad_packed_seq_to + elif pad_packed_seq_to_multiple_of > 1: + cu_seqlens_padded[-1] = _round_up_to_multiple( + cu_seqlens_padded[-1], pad_packed_seq_to_multiple_of + ) + + # Calculate max sequence length (padded if using CP) + if needs_padding: + seq_lens_padded = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + max_seqlen = seq_lens_padded.max().item() + else: + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + + # Concatenate all valid tokens + # If using individual padding, we need to pad individual sequences + # CP will always need padding (of at least cp_size * 2) + running_seq_len = 0 + if pad_factor > 1: + all_input_ids = [] + padded_tokens = [] + for b in range(batch_size): + seq_len = ( + seq_lengths[b].item() + if torch.is_tensor(seq_lengths[b]) + else seq_lengths[b] + ) + # if last element, pad to the max sequence length + if b == batch_size - 1 and needs_padding: + if pad_packed_seq_to is not None: + padded_seq_len = pad_packed_seq_to - running_seq_len + elif pad_packed_seq_to_multiple_of > 1: + padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) + padded_seq_len = ( + _round_up_to_multiple( + running_seq_len + padded_seq_len, + pad_packed_seq_to_multiple_of, + ) + - running_seq_len + ) + else: + padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) + else: + padded_seq_len = _round_up_to_multiple(seq_len, pad_factor) + + running_seq_len += padded_seq_len + + # Pad this sequence to the required length + seq_tokens = input_ids[b, :seq_len] + if padded_seq_len > seq_len: + # Pad with zeros (or use a padding token if available) + seq_tokens = torch.nn.functional.pad( + seq_tokens, (0, padded_seq_len - seq_len), value=0 + ) + 
all_input_ids.append(seq_tokens) + + if cp_size > 1: + seq_tokens = _get_tokens_on_this_cp_rank( + seq_tokens, cp_rank, cp_size, seq_dim=0 + ) + + padded_tokens.append(seq_tokens) + + # Concatenate all padded tokens + # For 'thd' format, the shape should be [1, T] where T is total tokens + packed_input_ids = torch.cat(padded_tokens, dim=0).unsqueeze(0) + all_input_ids = torch.cat(all_input_ids, dim=0).unsqueeze(0) + else: + # No individual padding, just concatenate valid tokens + # For 'thd' format, the shape should be [1, T] where T is total tokens + packed_input_ids = torch.cat(valid_tokens, dim=0).unsqueeze(0) + all_input_ids = packed_input_ids + if needs_padding: + if pad_packed_seq_to is not None: + pad_len = pad_packed_seq_to - packed_input_ids.shape[1] + elif pad_packed_seq_to_multiple_of > 1: + current_seq_len = packed_input_ids.shape[1] + pad_this_seq_to = _round_up_to_multiple( + current_seq_len, pad_packed_seq_to_multiple_of + ) + pad_len = pad_this_seq_to - current_seq_len + else: + pad_len = 0 + if pad_len > 0: + packed_input_ids = torch.nn.functional.pad( + packed_input_ids, (0, pad_len), value=0 + ) + all_input_ids = torch.nn.functional.pad( + all_input_ids, (0, pad_len), value=0 + ) + + if cu_seqlens_padded is None: + cu_seqlens_padded = cu_seqlens.clone() + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=int(max_seqlen), + max_seqlen_kv=int(max_seqlen), + qkv_format="thd", + ) + + return ( + all_input_ids.contiguous(), + packed_input_ids.contiguous(), + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) + + +def _get_pack_sequence_parameters_for_megatron( + megatron_cfg: dict, + max_seq_len_in_batch: int, +): + """Get pack sequence parameters for Megatron model processing with optional context parallelism. + + Args: + megatron_cfg: Megatron configuration + max_seq_len_in_batch: Maximum sequence length in batch + + Returns: + Tuple of: + - pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value + - pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value + - pad_packed_seq_to: Pad packed sequences to this value (before CP) + """ + tp_size = megatron_cfg["tensor_model_parallel_size"] + sp = megatron_cfg["sequence_parallel"] + pp_size = megatron_cfg["pipeline_model_parallel_size"] + cp_size = megatron_cfg["context_parallel_size"] + fp8_cfg = megatron_cfg.get("fp8_cfg", None) or {} + use_fp8 = fp8_cfg.get("enabled", False) + use_blockwise_fp8 = fp8_cfg.get("fp8_recipe", None) == "blockwise" + + # individual sequence needs to be splitted to CP domain, and to TP domain when SP is enabled. + pad_individual_seqs_to_multiple_of = 1 + if cp_size > 1: + pad_individual_seqs_to_multiple_of *= cp_size * 2 + if tp_size > 1 and sp: + pad_individual_seqs_to_multiple_of *= tp_size + + # packed sequence length, after splitted to TP and CP domains, needs to be divisible by 128 if using blockwise FP8, and divisible by 16 if using other FP8 recipes. + if use_fp8: + divisor = 128 if use_blockwise_fp8 else 16 + pad_packed_seq_to_multiple_of = divisor + if cp_size > 1: + pad_packed_seq_to_multiple_of *= cp_size * 2 + if tp_size > 1 and sp: + pad_packed_seq_to_multiple_of *= tp_size + else: + pad_packed_seq_to_multiple_of = 1 + + # when PP is used, all sequences must have the same length, so we need to pad the packed sequence to the max sequence length in the batch. 
+ if pp_size > 1: + pad_packed_seq_to = max_seq_len_in_batch + else: + pad_packed_seq_to = None + + return ( + pad_individual_seqs_to_multiple_of, + pad_packed_seq_to_multiple_of, + pad_packed_seq_to, + ) + + +def _unpack_sequences_from_megatron( + output_tensor: torch.Tensor, + seq_lengths: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqlens_padded: Optional[torch.Tensor], + original_batch_size: int, + original_seq_length: int, +) -> torch.Tensor: + """Unpack sequences from Megatron output format. + + Args: + output_tensor: Packed output tensor [1, T, vocab_size] + seq_lengths: Actual sequence lengths for each sample + cu_seqlens: Cumulative sequence lengths + cu_seqlens_padded: Padded cumulative sequence lengths (if CP was used) + original_batch_size: Original batch size + original_seq_length: Original maximum sequence length + + Returns: + Unpacked output tensor [batch_size, seq_length, vocab_size] + """ + # Remove the batch dimension to get [T, vocab_size] + output_tensor = output_tensor.squeeze(0) + + # Create a padded output tensor with original shape + vocab_size = output_tensor.shape[-1] + unpacked_output = torch.zeros( + (original_batch_size, original_seq_length, vocab_size), + dtype=output_tensor.dtype, + device=output_tensor.device, + ) + + # Get context parallel size to determine which cu_seqlens to use + cp_size = get_context_parallel_world_size() + + # Fill in the unpacked output tensor with valid tokens + for b in range(original_batch_size): + # Get actual sequence length for this sample + seq_len = ( + seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] + ) + + if cp_size > 1 and cu_seqlens_padded is not None: + # When using CP, we need to account for padding + # Calculate the padded sequence boundaries + pad_factor = cp_size * 2 + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + start_idx = cu_seqlens_padded[b].item() + + # Only copy the valid tokens (not the padding) + unpacked_output[b, :seq_len] = output_tensor[ + start_idx : start_idx + seq_len + ] + else: + # No CP, use regular cu_seqlens + start_idx = cu_seqlens[b].item() + end_idx = cu_seqlens[b + 1].item() + + # Copy the valid tokens to the unpacked tensor + unpacked_output[b, :seq_len] = output_tensor[start_idx:end_idx] + + return unpacked_output + +def check_sequence_dim(data: BatchedDataDict[Any]): + # dim 1 is always assumed to be the sequence dim, sanity check this here + sequence_dim = 1 + seq_dim_size = data["input_ids"].shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) + return sequence_dim, seq_dim_size \ No newline at end of file diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 9d03f6c06b..4560dc2ed2 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -75,7 +75,6 @@ from megatron.core.optimizer import ChainedOptimizer from megatron.core.parallel_state import ( get_context_parallel_group, - get_context_parallel_rank, get_pipeline_model_parallel_group, get_pipeline_model_parallel_last_rank, get_pipeline_model_parallel_world_size, @@ -87,11 +86,10 @@ from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.transformer.module import Float16Module from 
megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.utils import get_ltor_masks_and_position_ids from ray.util.queue import Queue from transformers import PreTrainedTokenizerBase -from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -111,12 +109,14 @@ ) from nemo_rl.models.generation.vllm.config import VllmConfig from nemo_rl.models.megatron.common import ( - _get_pack_sequence_parameters_for_megatron, - _pack_sequences_for_megatron, broadcast_tensor, forward_step_arbitrary_loss, get_moe_metrics, ) +from nemo_rl.models.megatron.data import ( + get_microbatch_iterator, + process_global_batch, +) from nemo_rl.models.megatron.community_import import import_model_from_hf_name from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( @@ -975,15 +975,6 @@ def train( self.model.train() with ctx: - # dim 1 is always assumed to be the sequence dim, sanity check this here - sequence_dim = 1 - seq_dim_size = data["input_ids"].shape[sequence_dim] - for k, v in data.items(): - if torch.is_tensor(v) and len(v.shape) > 1: - assert v.shape[sequence_dim] == seq_dim_size, ( - f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" - ) - forward_step = partial( forward_step_arbitrary_loss, loss_fn=loss_fn, policy_cfg=self.cfg ) @@ -991,73 +982,27 @@ def train( losses = [] total_num_microbatches = 0 for gb_idx in range(num_global_batches): - global_batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) - - assert "sample_mask" in global_batch, ( - "sample_mask must be present in the data!" 
- ) - ## get the normalization factor for the loss - local_valid_seqs = torch.sum(global_batch["sample_mask"]) - - if not "token_mask" in global_batch: - local_valid_toks = ( - local_valid_seqs * global_batch["input_ids"].shape[1] - ) - else: - local_valid_toks = torch.sum( - global_batch["token_mask"][:, 1:] - * global_batch["sample_mask"].unsqueeze(-1) - ) - - to_reduce = torch.tensor([local_valid_seqs, local_valid_toks]).cuda() - torch.distributed.all_reduce( - to_reduce, group=parallel_state.get_data_parallel_group() + ( + batch, + global_valid_seqs, + global_valid_toks, + ) = process_global_batch( + data, + gb_idx, + local_gbs, + loss_fn, + dp_group=parallel_state.get_data_parallel_group(), ) - global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1] - - if ( - hasattr(loss_fn, "loss_type") - and loss_fn.loss_type == LossType.TOKEN_LEVEL - ): - assert "token_mask" in global_batch, ( - "token_mask must be present in the data when using token-level loss" - ) - - batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) - pack_seqs = False - seqlen_key = None - pad_factor = 1 - pad_full_seq_to = None - pad_packed_seq_to_multiple_of = 1 - if self.cfg["dynamic_batching"]["enabled"]: - data_iterator = batch.make_microbatch_iterator_with_dynamic_shapes() - data_iterator_len = ( - batch.get_microbatch_iterator_dynamic_shapes_len() - ) - elif self.cfg["sequence_packing"]["enabled"]: - data_iterator = ( - batch.make_microbatch_iterator_for_packable_sequences() - ) - data_iterator_len, seq_dim_size = ( - batch.get_microbatch_iterator_for_packable_sequences_len() - ) - mbs = 1 - pack_seqs = True - seqlen_key = "input_lengths" - ( - pad_factor, - pad_packed_seq_to_multiple_of, - pad_full_seq_to, - ) = _get_pack_sequence_parameters_for_megatron( - self.cfg["megatron_cfg"], - seq_dim_size, - ) - else: - data_iterator = batch.make_microbatch_iterator(mbs) - data_iterator_len = local_gbs // mbs + ( + data_iterator, + num_microbatches, + micro_batch_size, + seq_length, + padded_seq_length, + ) = get_microbatch_iterator(batch, self.cfg, mbs) # Track total microbatches for MoE aux-loss averaging - total_num_microbatches += int(data_iterator_len) + total_num_microbatches += int(num_microbatches) rerun_state_machine = get_rerun_state_machine() while rerun_state_machine.should_run_forward_backward(data_iterator): @@ -1073,16 +1018,12 @@ def train( self.mcore_state, global_valid_seqs, global_valid_toks, - pack_sequences=pack_seqs, - seq_length_key=seqlen_key, - pad_individual_seqs_to_multiple_of=pad_factor, - pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of, - pad_full_seq_to=pad_full_seq_to, + pack_sequences=self.cfg["sequence_packing"]["enabled"], defer_fp32_logits=self.defer_fp32_logits, ), data_iterator=data_iterator, model=self.model, - num_microbatches=data_iterator_len, + num_microbatches=num_microbatches, - seq_length=seq_dim_size, - micro_batch_size=mbs, - decoder_seq_length=seq_dim_size, + seq_length=padded_seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=padded_seq_length, @@ -1225,74 +1166,31 @@ def get_logprobs( else self.cfg["logprob_batch_size"] ) - # dim 1 is always assumed to be the sequence dim, sanity check this here - sequence_dim = 1 - input_seq_dim_size = data["input_ids"].shape[sequence_dim] - for k, v in data.items(): - if torch.is_tensor(v) and len(v.shape) > 1: - assert v.shape[sequence_dim] == input_seq_dim_size, ( - f"Dim 1 must be the sequence dim, expected dim 1={input_seq_dim_size} but got shape {v.shape}" - ) - self.model.eval() - pp_seq_dim_size = input_seq_dim_size pp_grp = get_pipeline_model_parallel_group() - pad_factor = 1 -
pad_packed_seq_to_multiple_of = 1 - pad_full_seq_to = None - if self.cfg["sequence_packing"]["enabled"]: - _, seq_dim_size = data.get_microbatch_iterator_for_packable_sequences_len() - ( - pad_factor, - pad_packed_seq_to_multiple_of, - pad_full_seq_to, - ) = _get_pack_sequence_parameters_for_megatron( - self.cfg["megatron_cfg"], - seq_dim_size, - ) - pp_seq_dim_size = pad_full_seq_to or pp_seq_dim_size + + ( + mb_iterator, + num_microbatches, + micro_batch_size, + seq_length, + padded_seq_length, + ) = get_microbatch_iterator(data, self.cfg, logprob_batch_size) def forward_step_fn( data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel ): - nonlocal pad_full_seq_to, pad_packed_seq_to_multiple_of, pad_factor - data_dict = next(data_iterator).to("cuda") - if self.cfg["sequence_packing"]["enabled"]: - original_seq_length = data_dict["input_ids"].shape[1] - cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] - cp_rank = get_context_parallel_rank() - ( - input_ids, - input_ids_cp_sharded, - packed_seq_params, - cu_seqlens, - cu_seqlens_padded, - ) = _pack_sequences_for_megatron( - data_dict["input_ids"].clone(), - data_dict["input_lengths"], - pad_individual_seqs_to_multiple_of=pad_factor, - pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of, - pad_packed_seq_to=pad_full_seq_to, - cp_rank=cp_rank, - cp_size=cp_size, - ) - attention_mask, position_ids = None, None - unpacked_input_ids = data_dict["input_ids"] - else: - input_ids = data_dict["input_ids"] - input_ids_cp_sharded = input_ids - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - data=input_ids, - eod_token=0, # used for loss_mask, which we don't use - pad_token=0, # used for loss_mask, which we don't use - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=False, - pad_mask_loss=False, - ) - packed_seq_params = None - unpacked_input_ids = input_ids + processed_mb = next(data_iterator) + # Extract the processed components + data_dict = processed_mb.data_dict + input_ids = processed_mb.input_ids + input_ids_cp_sharded = processed_mb.input_ids_cp_sharded + attention_mask = processed_mb.attention_mask + position_ids = processed_mb.position_ids + packed_seq_params = processed_mb.packed_seq_params + cu_seqlens_padded = processed_mb.cu_seqlens_padded + unpacked_input_ids = data_dict["input_ids"] multimodal_data = data_dict.get_multimodal_dict( as_tensors=True, device=input_ids.device @@ -1331,7 +1229,7 @@ def collection_fn(output_tensor): output_tensor, target=input_ids, cu_seqlens_padded=cu_seqlens_padded, - unpacked_seqlen=original_seq_length, + unpacked_seqlen=seq_length, vocab_start_index=tp_rank * output_tensor.shape[-1], vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], group=tp_grp, @@ -1360,37 +1258,22 @@ def collection_fn(output_tensor): return output_tensor, collection_fn - if self.cfg["dynamic_batching"]["enabled"]: - mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() - data_iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() - micro_batch_size = logprob_batch_size - elif self.cfg["sequence_packing"]["enabled"]: - mb_iterator = data.make_microbatch_iterator_for_packable_sequences() - data_iterator_len, _ = ( - data.get_microbatch_iterator_for_packable_sequences_len() - ) - micro_batch_size = 1 - else: - mb_iterator = data.make_microbatch_iterator(logprob_batch_size) - data_iterator_len = max(1, data.size // logprob_batch_size) - micro_batch_size = logprob_batch_size - forward_backward_func = get_forward_backward_func() 
list_of_logprobs = forward_backward_func( forward_step_func=forward_step_fn, data_iterator=mb_iterator, model=self.model, - num_microbatches=data_iterator_len, - seq_length=pp_seq_dim_size, + num_microbatches=num_microbatches, + seq_length=padded_seq_length, micro_batch_size=micro_batch_size, - decoder_seq_length=pp_seq_dim_size, + decoder_seq_length=padded_seq_length, forward_only=True, ) if is_pipeline_last_stage(ignore_virtual=True): all_log_probs_padded = [] all_logprobs = [l["logprobs"] for l in list_of_logprobs] for lp in all_logprobs: - padding_needed = input_seq_dim_size - lp.shape[1] + padding_needed = seq_length - lp.shape[1] if padding_needed > 0: lp = torch.nn.functional.pad( lp, (0, padding_needed), mode="constant", value=0.0 @@ -1485,77 +1368,31 @@ def get_topk_logits( else self.cfg["logprob_batch_size"] ) - sequence_dim = 1 - input_seq_dim_size = data["input_ids"].shape[sequence_dim] - # Avoid shadowing the function argument `k` by using a distinct variable name - for tensor_name, v in data.items(): - if torch.is_tensor(v) and len(v.shape) > 1: - assert v.shape[sequence_dim] == input_seq_dim_size, ( - f"Tensor {tensor_name} must have sequence dimension {sequence_dim} of size {input_seq_dim_size}, but got shape {v.shape}" - ) - self.model.eval() - pp_seq_dim_size = input_seq_dim_size pp_grp = get_pipeline_model_parallel_group() - pad_factor = 1 - pad_packed_seq_to_multiple_of = 1 - pad_full_seq_to = None - if self.cfg["sequence_packing"]["enabled"]: - _, seq_dim_size = data.get_microbatch_iterator_for_packable_sequences_len() - ( - pad_factor, - pad_packed_seq_to_multiple_of, - pad_full_seq_to, - ) = _get_pack_sequence_parameters_for_megatron( - self.cfg["megatron_cfg"], - seq_dim_size, - ) - pp_seq_dim_size = pad_full_seq_to or pp_seq_dim_size + ( + mb_iterator, + num_microbatches, + micro_batch_size, + seq_length, + padded_seq_length, + ) = get_microbatch_iterator(data, self.cfg, logprob_batch_size) def forward_step_fn( data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel ): - nonlocal pad_full_seq_to, pad_packed_seq_to_multiple_of, pad_factor - data_dict = next(data_iterator).to("cuda") - - pack = self.cfg["sequence_packing"]["enabled"] - if pack: - original_seq_length = data_dict["input_ids"].shape[1] - cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] - cp_rank = get_context_parallel_rank() - - ( - input_ids_unpacked, - input_ids_cp_sharded, - packed_seq_params, - cu_seqlens, - cu_seqlens_padded, - ) = _pack_sequences_for_megatron( - data_dict["input_ids"].clone(), - data_dict["input_lengths"], - pad_individual_seqs_to_multiple_of=pad_factor, - pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of, - pad_packed_seq_to=pad_full_seq_to, - cp_rank=cp_rank, - cp_size=cp_size, - ) - attention_mask, position_ids = None, None - seq_lengths = data_dict["input_lengths"] - unpacked_seqlen = original_seq_length - else: - input_ids_cp_sharded = data_dict["input_ids"] - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - data=input_ids_cp_sharded, - eod_token=0, - pad_token=0, - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=False, - pad_mask_loss=False, - ) - packed_seq_params = None + processed_mb = next(data_iterator) + # Extract the processed components + data_dict = processed_mb.data_dict + input_ids = processed_mb.input_ids + input_ids_cp_sharded = processed_mb.input_ids_cp_sharded + attention_mask = processed_mb.attention_mask + position_ids = processed_mb.position_ids + packed_seq_params = 
processed_mb.packed_seq_params + cu_seqlens_padded = processed_mb.cu_seqlens_padded + unpacked_input_ids = data_dict["input_ids"] multimodal_data = data_dict.get_multimodal_dict( as_tensors=True, device=input_ids_cp_sharded.device @@ -1603,7 +1440,8 @@ def collection_fn(_): if self.cfg["megatron_cfg"]["context_parallel_size"] > 1: cp_grp = get_context_parallel_group() - if pack: + if self.cfg["sequence_packing"]["enabled"]: + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] # Per-sequence CP allgather following packed-sequence logic batch_size = data_dict["input_ids"].shape[0] total_packed_len = int(cu_seqlens_padded[-1].item()) @@ -1663,15 +1501,16 @@ def collection_fn(_): topk_vals_full = topk_vals_local topk_idx_full = topk_idx_local - if pack: + if self.cfg["sequence_packing"]["enabled"]: batch_size = data_dict["input_ids"].shape[0] + seq_lengths = data_dict["input_lengths"] out_vals = torch.zeros( - (batch_size, unpacked_seqlen, k), + (batch_size, seq_length, k), dtype=topk_vals_full.dtype, device=topk_vals_full.device, ) out_idx = torch.zeros( - (batch_size, unpacked_seqlen, k), + (batch_size, seq_length, k), dtype=topk_idx_full.dtype, device=topk_idx_full.device, ) @@ -1697,30 +1536,15 @@ def collection_fn(_): return output_tensor, collection_fn - if self.cfg["dynamic_batching"]["enabled"]: - mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() - data_iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() - micro_batch = logprob_batch_size - elif self.cfg["sequence_packing"]["enabled"]: - mb_iterator = data.make_microbatch_iterator_for_packable_sequences() - data_iterator_len, _ = ( - data.get_microbatch_iterator_for_packable_sequences_len() - ) - micro_batch = 1 - else: - mb_iterator = data.make_microbatch_iterator(logprob_batch_size) - data_iterator_len = max(1, data.size // logprob_batch_size) - micro_batch = logprob_batch_size - forward_backward_func = get_forward_backward_func() list_of_outputs = forward_backward_func( forward_step_func=forward_step_fn, data_iterator=mb_iterator, model=self.model, - num_microbatches=data_iterator_len, - seq_length=pp_seq_dim_size, - micro_batch_size=micro_batch, - decoder_seq_length=pp_seq_dim_size, + num_microbatches=num_microbatches, + seq_length=padded_seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=padded_seq_length, forward_only=True, ) @@ -1730,7 +1554,7 @@ def collection_fn(_): for out in list_of_outputs: tk = out["topk_logits"] ti = out["topk_indices"] - pad_len = input_seq_dim_size - tk.shape[1] + pad_len = padded_seq_length - tk.shape[1] if pad_len > 0: tk = torch.nn.functional.pad(tk, (0, 0, 0, pad_len), value=0.0) ti = torch.nn.functional.pad(ti, (0, 0, 0, pad_len), value=0) diff --git a/tests/unit/algorithms/test_sequence_packing_gradients.py b/tests/unit/algorithms/test_sequence_packing_gradients.py index f586e999c0..7d46252e7c 100644 --- a/tests/unit/algorithms/test_sequence_packing_gradients.py +++ b/tests/unit/algorithms/test_sequence_packing_gradients.py @@ -42,9 +42,9 @@ def __init__(self, cp_size): def test_sequence_packing_gradients(self): from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank from nemo_rl.models.megatron.common import ( - _pack_sequences_for_megatron, forward_step_arbitrary_loss, ) + from nemo_rl.models.megatron.data import _pack_sequences_for_megatron # Initialize process group torch.distributed.init_process_group(backend="nccl") diff --git a/tests/unit/models/megatron/test_common.py b/tests/unit/models/megatron/test_megatron_data.py 
similarity index 66% rename from tests/unit/models/megatron/test_common.py rename to tests/unit/models/megatron/test_megatron_data.py index 3ee2685af5..04878c5208 100644 --- a/tests/unit/models/megatron/test_common.py +++ b/tests/unit/models/megatron/test_megatron_data.py @@ -11,12 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Unit tests for Megatron data utilities. + +This module tests the data processing functions in nemo_rl.models.megatron.data, +focusing on: +- Microbatch processing and iteration +- Sequence packing and unpacking +- Global batch processing +- Sequence dimension validation +""" + import os +from unittest.mock import MagicMock, patch import pytest import ray import torch +from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.distributed.ray_actor_environment_registry import ( ACTOR_ENVIRONMENT_REGISTRY, @@ -26,6 +40,564 @@ from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +class TestProcessedMicrobatchDataclass: + """Tests for ProcessedMicrobatch dataclass.""" + + def test_processed_microbatch_fields(self): + """Test that ProcessedMicrobatch has all expected fields.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + + mock_data_dict = MagicMock() + mock_input_ids = torch.tensor([[1, 2, 3]]) + mock_input_ids_cp_sharded = torch.tensor([[1, 2, 3]]) + mock_attention_mask = torch.tensor([[1, 1, 1]]) + mock_position_ids = torch.tensor([[0, 1, 2]]) + mock_packed_seq_params = MagicMock() + mock_cu_seqlens_padded = torch.tensor([0, 3]) + + microbatch = ProcessedMicrobatch( + data_dict=mock_data_dict, + input_ids=mock_input_ids, + input_ids_cp_sharded=mock_input_ids_cp_sharded, + attention_mask=mock_attention_mask, + position_ids=mock_position_ids, + packed_seq_params=mock_packed_seq_params, + cu_seqlens_padded=mock_cu_seqlens_padded, + ) + + assert microbatch.data_dict == mock_data_dict + assert torch.equal(microbatch.input_ids, mock_input_ids) + assert torch.equal(microbatch.input_ids_cp_sharded, mock_input_ids_cp_sharded) + assert torch.equal(microbatch.attention_mask, mock_attention_mask) + assert torch.equal(microbatch.position_ids, mock_position_ids) + assert microbatch.packed_seq_params == mock_packed_seq_params + assert torch.equal(microbatch.cu_seqlens_padded, mock_cu_seqlens_padded) + +class TestCheckSequenceDim: + """Tests for check_sequence_dim function.""" + + def test_check_sequence_dim_valid(self): + """Test check_sequence_dim with valid data.""" + from nemo_rl.models.megatron.data import check_sequence_dim + + # Create mock data with consistent sequence dimension + data = MagicMock() + data.__getitem__ = MagicMock( + side_effect=lambda k: torch.zeros(2, 10) if k == "input_ids" else None + ) + data.items = MagicMock( + return_value=[ + ("input_ids", torch.zeros(2, 10)), + ("attention_mask", torch.zeros(2, 10)), + ] + ) + + sequence_dim, seq_dim_size = check_sequence_dim(data) + + assert sequence_dim == 1 + assert seq_dim_size == 10 + + def test_check_sequence_dim_mismatch(self): + """Test check_sequence_dim with mismatched sequence dimensions.""" + from nemo_rl.models.megatron.data import check_sequence_dim + + # Create mock data with mismatched sequence dimension + data = MagicMock() + data.__getitem__ = MagicMock( + side_effect=lambda k: torch.zeros(2, 10) if k == "input_ids" else None + 
) + data.items = MagicMock( + return_value=[ + ("input_ids", torch.zeros(2, 10)), + ("other_tensor", torch.zeros(2, 15)), # Mismatched! + ] + ) + + with pytest.raises(AssertionError) as exc_info: + check_sequence_dim(data) + + assert "Dim 1 must be the sequence dim" in str(exc_info.value) + + def test_check_sequence_dim_skips_1d_tensors(self): + """Test that check_sequence_dim skips 1D tensors.""" + from nemo_rl.models.megatron.data import check_sequence_dim + + # Create mock data with 1D tensor (should be skipped) + data = MagicMock() + data.__getitem__ = MagicMock( + side_effect=lambda k: torch.zeros(2, 10) if k == "input_ids" else None + ) + data.items = MagicMock( + return_value=[ + ("input_ids", torch.zeros(2, 10)), + ("seq_lengths", torch.zeros(2)), # 1D tensor, should be skipped + ] + ) + + # Should not raise + sequence_dim, seq_dim_size = check_sequence_dim(data) + assert seq_dim_size == 10 + + +class TestProcessMicrobatch: + """Tests for process_microbatch function.""" + + @patch("nemo_rl.models.megatron.data.get_ltor_masks_and_position_ids") + def test_process_microbatch_no_packing(self, mock_get_masks): + """Test process_microbatch without sequence packing.""" + from nemo_rl.models.megatron.data import process_microbatch + + # Setup mock + mock_attention_mask = torch.ones(2, 10) + mock_position_ids = torch.arange(10).unsqueeze(0).expand(2, -1) + mock_get_masks.return_value = (mock_attention_mask, None, mock_position_ids) + + # Create test data + data_dict = MagicMock() + input_ids = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0], [6, 7, 8, 9, 10, 11, 12, 0, 0, 0]]) + data_dict.__getitem__ = MagicMock(return_value=input_ids) + + ( + result_input_ids, + input_ids_cp_sharded, + attention_mask, + position_ids, + packed_seq_params, + cu_seqlens_padded, + ) = process_microbatch(data_dict, pack_sequences=False) + + # Verify results + assert torch.equal(result_input_ids, input_ids) + assert torch.equal(input_ids_cp_sharded, input_ids) + assert attention_mask is not None + assert position_ids is not None + assert packed_seq_params is None + assert cu_seqlens_padded is None + + # Verify get_ltor_masks_and_position_ids was called + mock_get_masks.assert_called_once() + + @patch("nemo_rl.models.megatron.data.get_context_parallel_rank", return_value=0) + @patch("nemo_rl.models.megatron.data.get_context_parallel_world_size", return_value=1) + @patch("nemo_rl.models.megatron.data._pack_sequences_for_megatron") + def test_process_microbatch_with_packing( + self, mock_pack, mock_cp_world, mock_cp_rank + ): + """Test process_microbatch with sequence packing.""" + from nemo_rl.models.megatron.data import process_microbatch + + # Setup mocks + mock_packed_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + mock_packed_seq_params = MagicMock() + mock_cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32) + mock_cu_seqlens_padded = torch.tensor([0, 5, 8], dtype=torch.int32) + mock_pack.return_value = ( + mock_packed_input_ids, + mock_packed_input_ids, + mock_packed_seq_params, + mock_cu_seqlens, + mock_cu_seqlens_padded, + ) + + # Create test data + data_dict = MagicMock() + input_ids = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]]) + seq_lengths = torch.tensor([5, 3]) + data_dict.__getitem__ = MagicMock( + side_effect=lambda k: input_ids if k == "input_ids" else seq_lengths + ) + data_dict.__contains__ = MagicMock(return_value=True) + + ( + result_input_ids, + input_ids_cp_sharded, + attention_mask, + position_ids, + packed_seq_params, + cu_seqlens_padded, + ) = 
process_microbatch( + data_dict, seq_length_key="input_lengths", pack_sequences=True + ) + + # Verify results + assert torch.equal(result_input_ids, mock_packed_input_ids) + assert packed_seq_params == mock_packed_seq_params + # For packed sequences, attention_mask and position_ids are None + assert attention_mask is None + assert position_ids is None + assert cu_seqlens_padded is not None + + # Verify pack was called + mock_pack.assert_called_once() + + def test_process_microbatch_packing_requires_seq_length_key(self): + """Test that packing requires seq_length_key.""" + from nemo_rl.models.megatron.data import process_microbatch + + data_dict = MagicMock() + input_ids = torch.tensor([[1, 2, 3]]) + data_dict.__getitem__ = MagicMock(return_value=input_ids) + + with pytest.raises(AssertionError) as exc_info: + process_microbatch(data_dict, seq_length_key=None, pack_sequences=True) + + assert "seq_length_key must be provided" in str(exc_info.value) + + def test_process_microbatch_packing_requires_seq_length_in_data(self): + """Test that packing requires seq_length_key to be in data_dict.""" + from nemo_rl.models.megatron.data import process_microbatch + + data_dict = MagicMock() + input_ids = torch.tensor([[1, 2, 3]]) + data_dict.__getitem__ = MagicMock(return_value=input_ids) + data_dict.__contains__ = MagicMock(return_value=False) + + with pytest.raises(AssertionError) as exc_info: + process_microbatch( + data_dict, seq_length_key="input_lengths", pack_sequences=True + ) + + assert "input_lengths not found in data_dict" in str(exc_info.value) + + +class TestProcessGlobalBatch: + """Tests for process_global_batch function.""" + + def test_process_global_batch_basic(self): + """Test basic process_global_batch functionality.""" + from nemo_rl.models.megatron.data import process_global_batch + + # Create mock data + sample_mask = torch.tensor([1.0, 1.0, 0.0]) + input_ids = torch.zeros(3, 10) + mock_batch = BatchedDataDict( + { + "sample_mask": sample_mask, + "input_ids": input_ids, + } + ) + + mock_data = MagicMock() + mock_data.get_batch.return_value = mock_batch + + mock_dp_group = MagicMock() + + # Mock torch.distributed.all_reduce + with patch("torch.distributed.all_reduce") as mock_all_reduce: + batch, global_valid_seqs, global_valid_toks = process_global_batch( + data=mock_data, + batch_idx=0, + batch_size=3, + loss_fn=MagicMock(), + dp_group=mock_dp_group, + ) + + assert torch.equal(batch["sample_mask"], mock_batch["sample_mask"]) + assert torch.equal(batch["input_ids"], mock_batch["input_ids"]) + + # Verify get_batch was called + mock_data.get_batch.assert_called_once_with(batch_idx=0, batch_size=3) + + # Verify all_reduce was called + mock_all_reduce.assert_called_once() + + def test_process_global_batch_requires_sample_mask(self): + """Test that process_global_batch requires sample_mask.""" + from nemo_rl.models.megatron.data import process_global_batch + + # Create mock data without sample_mask + mock_batch = MagicMock() + mock_batch.__contains__ = MagicMock(return_value=False) + + mock_data = MagicMock() + mock_data.get_batch.return_value = mock_batch + + with pytest.raises(AssertionError) as exc_info: + process_global_batch( + data=mock_data, + batch_idx=0, + batch_size=3, + loss_fn=MagicMock(), + dp_group=MagicMock(), + ) + + assert "sample_mask must be present" in str(exc_info.value) + +class TestGetMicrobatchIterator: + """Tests for get_microbatch_iterator function.""" + + @patch("nemo_rl.models.megatron.data.check_sequence_dim") + 
@patch("nemo_rl.models.megatron.data.make_processed_microbatch_iterator") + def test_get_microbatch_iterator_dynamic_batching( + self, mock_make_iterator, mock_check_seq_dim + ): + """Test get_microbatch_iterator with dynamic batching.""" + from nemo_rl.models.megatron.data import get_microbatch_iterator + + # Setup mocks + mock_check_seq_dim.return_value = (1, 128) + mock_iterator = iter([MagicMock()]) + mock_make_iterator.return_value = mock_iterator + + mock_data = MagicMock() + mock_data.make_microbatch_iterator_with_dynamic_shapes.return_value = iter([]) + mock_data.get_microbatch_iterator_dynamic_shapes_len.return_value = 5 + + cfg = { + "dynamic_batching": {"enabled": True}, + "sequence_packing": {"enabled": False}, + } + + ( + iterator, + data_iterator_len, + micro_batch_size, + seq_dim_size, + padded_seq_length, + ) = get_microbatch_iterator( + data=mock_data, + cfg=cfg, + mbs=4, + ) + + # Verify dynamic batching path was taken + mock_data.make_microbatch_iterator_with_dynamic_shapes.assert_called_once() + mock_data.get_microbatch_iterator_dynamic_shapes_len.assert_called_once() + + assert data_iterator_len == 5 + assert seq_dim_size == 128 + + @patch("nemo_rl.models.megatron.data.check_sequence_dim") + @patch("nemo_rl.models.megatron.data.make_processed_microbatch_iterator") + @patch("nemo_rl.models.megatron.data._get_pack_sequence_parameters_for_megatron") + def test_get_microbatch_iterator_sequence_packing( + self, mock_get_params, mock_make_iterator, mock_check_seq_dim + ): + """Test get_microbatch_iterator with sequence packing.""" + from nemo_rl.models.megatron.data import get_microbatch_iterator + + # Setup mocks + mock_check_seq_dim.return_value = (1, 256) + mock_get_params.return_value = (8, 16, None) + mock_iterator = iter([MagicMock()]) + mock_make_iterator.return_value = mock_iterator + + mock_data = MagicMock() + mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter([]) + mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (10, 512) + + cfg = { + "dynamic_batching": {"enabled": False}, + "sequence_packing": {"enabled": True}, + "megatron_cfg": { + "tensor_model_parallel_size": 1, + "sequence_parallel": False, + "pipeline_model_parallel_size": 1, + "context_parallel_size": 1, + }, + } + + ( + iterator, + data_iterator_len, + micro_batch_size, + seq_dim_size, + padded_seq_length, + ) = get_microbatch_iterator( + data=mock_data, + cfg=cfg, + mbs=4, + ) + + # Verify sequence packing path was taken + mock_data.make_microbatch_iterator_for_packable_sequences.assert_called_once() + + # With sequence packing, micro_batch_size should be 1 + assert micro_batch_size == 1 + assert data_iterator_len == 10 + + @patch("nemo_rl.models.megatron.data.check_sequence_dim") + @patch("nemo_rl.models.megatron.data.make_processed_microbatch_iterator") + def test_get_microbatch_iterator_regular( + self, mock_make_iterator, mock_check_seq_dim + ): + """Test get_microbatch_iterator with regular batching.""" + from nemo_rl.models.megatron.data import get_microbatch_iterator + + # Setup mocks + mock_check_seq_dim.return_value = (1, 64) + mock_iterator = iter([MagicMock()]) + mock_make_iterator.return_value = mock_iterator + + mock_data = MagicMock() + mock_data.size = 16 + mock_data.make_microbatch_iterator.return_value = iter([]) + + cfg = { + "dynamic_batching": {"enabled": False}, + "sequence_packing": {"enabled": False}, + } + + mbs = 4 + + ( + iterator, + data_iterator_len, + micro_batch_size, + seq_dim_size, + padded_seq_length, + ) = 
get_microbatch_iterator( + data=mock_data, + cfg=cfg, + mbs=mbs, + ) + + # Verify regular batching path was taken + mock_data.make_microbatch_iterator.assert_called_once_with(mbs) + + assert micro_batch_size == mbs + assert data_iterator_len == 16 // mbs + assert seq_dim_size == 64 + + @patch("nemo_rl.models.megatron.data.check_sequence_dim") + @patch("nemo_rl.models.megatron.data.make_processed_microbatch_iterator") + def test_get_microbatch_iterator_auto_detects_seq_length_key( + self, mock_make_iterator, mock_check_seq_dim + ): + """Test that get_microbatch_iterator auto-detects seq_length_key for packing.""" + from nemo_rl.models.megatron.data import get_microbatch_iterator + + # Setup mocks + mock_check_seq_dim.return_value = (1, 128) + mock_iterator = iter([MagicMock()]) + mock_make_iterator.return_value = mock_iterator + + mock_data = MagicMock() + mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter([]) + mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (5, 256) + + cfg = { + "dynamic_batching": {"enabled": False}, + "sequence_packing": {"enabled": True}, + "megatron_cfg": { + "tensor_model_parallel_size": 1, + "sequence_parallel": False, + "pipeline_model_parallel_size": 1, + "context_parallel_size": 1, + }, + } + + get_microbatch_iterator( + data=mock_data, + cfg=cfg, + mbs=4, + seq_length_key=None, # Should be auto-detected + ) + + # Verify make_processed_microbatch_iterator was called with "input_lengths" + call_kwargs = mock_make_iterator.call_args[1] + assert call_kwargs["seq_length_key"] == "input_lengths" + + +class TestMakeProcessedMicrobatchIterator: + """Tests for make_processed_microbatch_iterator function.""" + + @patch("nemo_rl.models.megatron.data.process_microbatch") + def test_make_processed_microbatch_iterator_basic(self, mock_process): + """Test make_processed_microbatch_iterator yields ProcessedMicrobatch.""" + from nemo_rl.models.megatron.data import ( + ProcessedMicrobatch, + make_processed_microbatch_iterator, + ) + + # Setup mocks + mock_input_ids = MagicMock() + mock_input_ids_cp_sharded = MagicMock() + mock_attention_mask = MagicMock() + mock_position_ids = MagicMock() + mock_packed_seq_params = None + mock_cu_seqlens_padded = None + + mock_process.return_value = ( + mock_input_ids, + mock_input_ids_cp_sharded, + mock_attention_mask, + mock_position_ids, + mock_packed_seq_params, + mock_cu_seqlens_padded, + ) + + # Create mock data dict + mock_data_dict = MagicMock() + mock_data_dict.to.return_value = mock_data_dict + + raw_iterator = iter([mock_data_dict]) + + cfg = {"sequence_packing": {"enabled": False}} + + processed_iterator = make_processed_microbatch_iterator( + raw_iterator=raw_iterator, + cfg=cfg, + seq_length_key=None, + pad_individual_seqs_to_multiple_of=1, + pad_packed_seq_to_multiple_of=1, + pad_full_seq_to=None, + ) + + # Get first item from iterator + microbatch = next(processed_iterator) + + # Verify it's a ProcessedMicrobatch + assert isinstance(microbatch, ProcessedMicrobatch) + assert microbatch.data_dict == mock_data_dict + assert microbatch.input_ids == mock_input_ids + + # Verify data was moved to CUDA + mock_data_dict.to.assert_called_once_with("cuda") + + @patch("nemo_rl.models.megatron.data.process_microbatch") + def test_make_processed_microbatch_iterator_with_packing(self, mock_process): + """Test make_processed_microbatch_iterator with sequence packing.""" + from nemo_rl.models.megatron.data import make_processed_microbatch_iterator + + # Setup mocks + mock_process.return_value 
= ( + MagicMock(), # input_ids + MagicMock(), # input_ids_cp_sharded + None, # attention_mask (None for packed) + None, # position_ids (None for packed) + MagicMock(), # packed_seq_params + MagicMock(), # cu_seqlens_padded + ) + + mock_data_dict = MagicMock() + mock_data_dict.to.return_value = mock_data_dict + + raw_iterator = iter([mock_data_dict]) + + cfg = {"sequence_packing": {"enabled": True}} + + processed_iterator = make_processed_microbatch_iterator( + raw_iterator=raw_iterator, + cfg=cfg, + seq_length_key="input_lengths", + pad_individual_seqs_to_multiple_of=8, + pad_packed_seq_to_multiple_of=16, + pad_full_seq_to=1024, + ) + + microbatch = next(processed_iterator) + + # Verify process_microbatch was called with pack_sequences=True + mock_process.assert_called_once() + call_kwargs = mock_process.call_args[1] + assert call_kwargs["pack_sequences"] is True + assert call_kwargs["seq_length_key"] == "input_lengths" + assert call_kwargs["pad_individual_seqs_to_multiple_of"] == 8 + assert call_kwargs["pad_packed_seq_to_multiple_of"] == 16 + assert call_kwargs["pad_full_seq_to"] == 1024 + + @ray.remote(num_gpus=1) class PackSequencesTestActor: def __init__(self, cp_size): @@ -35,7 +607,7 @@ def __init__(self, cp_size): def run_all_pack_sequences_tests(self): """Run all sequence packing tests in a single call to avoid expensive reinitializations.""" from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank - from nemo_rl.models.megatron.common import _pack_sequences_for_megatron + from nemo_rl.models.megatron.data import _pack_sequences_for_megatron # Initialize process group if CP > 1 if self.cp_size > 1: @@ -714,7 +1286,7 @@ def __init__(self): def run_all_get_pack_sequence_parameters_for_megatron_tests(self): """Test _get_pack_sequence_parameters_for_megatron function with various configurations.""" - from nemo_rl.models.megatron.common import ( + from nemo_rl.models.megatron.data import ( _get_pack_sequence_parameters_for_megatron, ) @@ -1105,3 +1677,4 @@ def test_get_pack_sequence_parameters_for_megatron(get_pack_sequence_parameters_ # Check that all workers succeeded for i, result in enumerate(results): assert result["success"], f"Worker {i} failed: {result['error']}" +
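# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): the new tests exercise
# check_sequence_dim and the ProcessedMicrobatch helpers through MagicMock
# objects only. Below is a minimal, hedged example of the same contract with a
# real BatchedDataDict; it assumes check_sequence_dim inspects tensor shapes
# only (no Megatron/GPU setup needed) and that BatchedDataDict exposes
# dict-style items()/__getitem__, as the mocks in the tests imply.
import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.megatron.data import check_sequence_dim

batch = BatchedDataDict(
    {
        "input_ids": torch.zeros(2, 10, dtype=torch.long),  # [batch, seq]
        "attention_mask": torch.ones(2, 10),  # shares the sequence dim
        "seq_lengths": torch.tensor([7, 10]),  # 1D tensor, skipped by the check
    }
)

# Dim 1 is validated as the sequence dim across all 2D+ tensors in the batch.
sequence_dim, seq_dim_size = check_sequence_dim(batch)
assert sequence_dim == 1 and seq_dim_size == 10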