diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 2a0b0123b..c6cf0c3eb 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -298,7 +298,7 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices= dist.all_reduce(num_microbatches, op=dist.ReduceOp.MAX, group=dp_group) if vpp_size > 1: - # vpp requies the number of microbatches to be divisible by vpp_size + # vpp requires the number of microbatches to be divisible by vpp_size num_microbatches = torch.clamp( num_microbatches // microbatch_group_size_per_vp_stage * microbatch_group_size_per_vp_stage, min=1,