2 changes: 2 additions & 0 deletions 3rdparty/Megatron-Bridge-workspace/setup.py
@@ -44,6 +44,8 @@
"mamba-ssm",
"nvidia-resiliency-ext",
"causal-conv1d",
"timm",
"open-clip-torch>=3.2.0",
]

# If the bridge source exists, compare cached dependencies with the submodule's pyproject
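The comment above describes a dependency sanity check. The snippet below is only an illustrative sketch of that kind of comparison, not the repo's actual logic; the submodule path, key lookup, and warning behavior are assumptions (it also uses tomllib, so Python 3.11+).

import tomllib
from pathlib import Path

# Cached dependency list from this setup.py (including the two new entries).
CACHED_BRIDGE_DEPS = [
    "mamba-ssm",
    "nvidia-resiliency-ext",
    "causal-conv1d",
    "timm",
    "open-clip-torch>=3.2.0",
]

# Hypothetical submodule location; adjust to wherever Megatron-Bridge is checked out.
bridge_pyproject = Path("Megatron-Bridge/pyproject.toml")
if bridge_pyproject.exists():
    with bridge_pyproject.open("rb") as f:
        submodule_deps = tomllib.load(f).get("project", {}).get("dependencies", [])
    # Warn if the submodule declares dependencies the cached list does not mention.
    stale = [dep for dep in submodule_deps if dep not in CACHED_BRIDGE_DEPS]
    if stale:
        print(f"Warning: cached dependency list may be out of date; not cached: {stale}")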
216 changes: 216 additions & 0 deletions examples/configs/sft_nanov3.yaml
@@ -0,0 +1,216 @@
# SFT Algorithm Configuration
sft:
  ## total number of steps to train will equal
  ## min((max_num_epochs * len(train_dataloader)), max_num_steps)
  max_num_epochs: 1
  max_num_steps: 60

  val_period: 10
  val_batches: 8
  val_global_batch_size: 32
  val_micro_batch_size: 1
  val_at_start: true
  seed: 42

checkpointing:
  enabled: true
  checkpoint_dir: "results/sft_nanov3_lora_permute_fusion_true"
  metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
  higher_is_better: false
  keep_top_k: 3
  save_period: 10
  checkpoint_must_save_by: null

policy:
  model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
    # chat_template can be a Jinja template string or path to a .jinja file (a rendering sketch follows this config)
    chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
    chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true
  train_global_batch_size: 32
  train_micro_batch_size: 1
  max_total_sequence_length: 1024
  precision: "bfloat16"

  offload_optimizer_for_logprob: false

  dtensor_cfg:
    enabled: false
    env_vars: {}
    cpu_offload: False
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

  dynamic_batching:
    enabled: false
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} # 1024 * 1 = 1024 tokens per microbatch (see the resolver sketch after this config)
    sequence_length_round: 64

  sequence_packing:
    enabled: False
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  # makes the training sequence length divisible by the tensor parallel size
  # this is useful for sequence parallel training
  make_sequence_length_divisible_by: 8
  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 5.0e-6
      weight_decay: 0.1
      betas: [0.9, 0.98]
      eps: 1e-5
      # when using Dtensor, we need to set foreach
      # and fused to False
      foreach: False
      fused: False

  ## active backend for this config (enabled: true below; the dtensor_cfg above is disabled)
  megatron_cfg:
    enabled: true
    env_vars: {}
    empty_unused_memory_level: 1
    activation_checkpointing: false
    tensor_model_parallel_size: 8
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 8
    pipeline_model_parallel_size: 1
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    sequence_parallel: true
    freeze_moe_router: false
    moe_router_dtype: null
    moe_router_load_balancing_type: "none"
    moe_router_bias_update_rate: 1e-3
    moe_permute_fusion: true # try this as true
    # gives ~20% training perf speedup with sequence packing
    apply_rope_fusion: True
    # gives ~25% training perf speedup with sequence packing and apply_rope_fusion
    bias_activation_fusion: False
    defer_fp32_logits: False
    moe_per_layer_logging: False

    peft:
      enabled: false
      target_modules: []
      dim: 32
      alpha: 32
      dropout: 0.0
      dropout_position: "pre"
      lora_A_init_method: "xavier"
      lora_B_init_method: "zero"
      a2a_experimental: false
      lora_dtype: null


    optimizer:
      optimizer: "adam"
      lr: 5.0e-6
      min_lr: 4.9999e-6
      weight_decay: 0.1
      bf16: false
      fp16: false
      params_dtype: "float32"

      # adam
      adam_beta1: 0.9
      adam_beta2: 0.98
      adam_eps: 1e-5

      # sgd
      sgd_momentum: 0.9

      # distributed optimizer
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true

      clip_grad: ${policy.max_grad_norm}

      # optimizer cpu offload
      optimizer_cpu_offload: false
      optimizer_offload_fraction: 0.0

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: 1000
      lr_warmup_iters: 50
      lr_warmup_init: 4.9999e-6

    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: false
      overlap_param_gather: true
      data_parallel_sharding_strategy: "optim_grads_params"
      use_custom_fsdp: false

data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  add_bos: true
  add_eos: true
  add_generation_prompt: false
  shuffle: false
  num_workers: 1

  dataset_name: "squad"
  # You can use custom response datasets for training and validation. For example:
  # data:
  #   dataset_name: ResponseDataset
  #   train_data_path: <PathToTrainingDataset> # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
  #   val_data_path: <PathToValidationDataset>
  #   input_key: <QuestionKey>, default is "input"
  #   output_key: <AnswerKey>, default is "output"
  #   train_split: <TrainSplit>, default is None # used for HuggingFace datasets
  #   val_split: <ValSplit>, default is None # used for HuggingFace datasets
  # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.

  ## unused with squad dataset
  prompt_file: null
  split: null
  output_key: null
  seed: null


  ## OpenAI format specific configs
  # train_data_path: "/path/to/train.jsonl" # Path to training data
  # val_data_path: "/path/to/val.jsonl" # Path to validation data
  # chat_key: "messages" # Key for messages in the data
  # system_key: null # Key for system message (optional)
  # system_prompt: null # Default system prompt (optional)
  # tool_key: "tools" # Key for tools in the data
  # use_preserving_dataset: false # If true, uses PreservingDataset to preserve heterogeneous schemas (e.g., tool calls with varying argument structures)

logger:
  log_dir: "logs" # Base directory for all logs
  wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running
  tensorboard_enabled: true
  mlflow_enabled: false
  swanlab_enabled: false # Disable SwanLab logging
  monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "sft-dev"
    name: "sft-dev-${data.dataset_name}"
  tensorboard:
    log_dir: "tb_logs-sft-dev-${data.dataset_name}"
  mlflow:
    experiment_name: "sft-dev"
    run_name: "sft-dev-${data.dataset_name}"
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8
  num_nodes: 1
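Not part of the config: a quick way to sanity-check the chat_template defined under policy.tokenizer is to render it with Jinja2 over a SQuAD-style example. The messages below are made up for illustration.

from jinja2 import Template

chat_template = (
    "{% for message in messages %}"
    "{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}"
    "{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}"
    "{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}"
    "{%- endif %}{% endfor %}"
)

messages = [
    {"role": "system", "content": "The Amazon is the largest rainforest on Earth."},
    {"role": "user", "content": "What is the largest rainforest?"},
    {"role": "assistant", "content": "The Amazon."},
]

# Prints: Context: The Amazon is the largest rainforest on Earth. Question: What is the largest rainforest? Answer: The Amazon.
print(Template(chat_template).render(messages=messages))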
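Also not part of the config: the ${mul:...} entries under dynamic_batching and sequence_packing rely on a config resolver, so train_mb_tokens resolves to 1024 * 1 = 1024 tokens per microbatch here. A minimal OmegaConf sketch of how such a resolver behaves; NeMo RL registers its own resolvers, so the registration below is only for illustration.

from omegaconf import OmegaConf

OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True)

cfg = OmegaConf.create({
    "policy": {
        "max_total_sequence_length": 1024,
        "train_micro_batch_size": 1,
        "sequence_packing": {
            "train_mb_tokens": "${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}",
        },
    },
})

print(cfg.policy.sequence_packing.train_mb_tokens)  # 1024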
50 changes: 40 additions & 10 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -115,7 +115,7 @@ def _apply_transformer_engine_patch():
except Exception as e:
    print(f"Error checking/patching transformer_engine: {e}")


from megatron.bridge.peft.lora import LoRA
from megatron.bridge import AutoBridge
from megatron.bridge.models.model_provider import get_model
from megatron.bridge.training import fault_tolerance
@@ -143,6 +143,7 @@ def _apply_transformer_engine_patch():
from megatron.bridge.training.optim import setup_optimizer
from megatron.bridge.training.setup import (
    _update_model_config_funcs,
    _create_peft_pre_wrap_hook,
)
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer
@@ -178,6 +179,7 @@ def _apply_transformer_engine_patch():
)
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.training.utils import get_ltor_masks_and_position_ids
@@ -366,6 +368,32 @@ def freeze_moe_router(megatron_model):

mixed_precision_wrapper = CustomFloat16Module
pre_wrap_hook.extend([freeze_moe_router])

if policy_cfg["megatron_cfg"].get("peft", {}).get("enabled", False):
    lora_cfg = policy_cfg["megatron_cfg"].get("peft", {})
    peft_cfg = LoRA(
        target_modules=lora_cfg["target_modules"],
        dim=lora_cfg["dim"],
        alpha=lora_cfg["alpha"],
        dropout=lora_cfg["dropout"],
        dropout_position=lora_cfg["dropout_position"],
        lora_A_init_method=lora_cfg["lora_A_init_method"],
        lora_B_init_method=lora_cfg["lora_B_init_method"],
        a2a_experimental=lora_cfg["a2a_experimental"],
        lora_dtype=lora_cfg["lora_dtype"],
    )
else:
    peft_cfg = None
cfg.lora_cfg = peft_cfg

if cfg.lora_cfg is not None:
    pre_peft_hook = _create_peft_pre_wrap_hook(cfg, state)
    cfg.model.register_pre_wrap_hook(pre_peft_hook)

    def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
        # For now this only applies the PEFT pre-wrap hook; see the TODO on the
        # get_model() call below about composing it with the MoE-router freeze hook.
        model = pre_peft_hook(model)
        return model

    peft_hook = composed_peft_hook
else:
    peft_hook = []

# Model, optimizer, and learning rate.
model = get_model(
@@ -374,7 +402,7 @@ def freeze_moe_router(megatron_model):
    use_torch_fsdp2=cfg.dist.use_torch_fsdp2,
    overlap_param_gather_with_optimizer_step=cfg.optimizer.overlap_param_gather_with_optimizer_step,
    data_parallel_random_init=cfg.rng.data_parallel_random_init,
    pre_wrap_hook=pre_wrap_hook,
    pre_wrap_hook=peft_hook,  # TODO(@adithyare): integrate with the pre_wrap_hook defined above for freezing the MoE router (see the sketch after this file's diff)
    mixed_precision_wrapper=mixed_precision_wrapper,
)
if load_optimizer:
@@ -391,14 +419,16 @@ def freeze_moe_router(megatron_model):
print("Model, optimizer, and learning rate scheduler built")
torch.distributed.barrier()

# Load checkpoint if applicable
if (
    cfg.checkpoint.load is not None
    or cfg.checkpoint.pretrained_checkpoint is not None
) and (
    checkpoint_exists(cfg.checkpoint.load)
    or checkpoint_exists(cfg.checkpoint.pretrained_checkpoint)
):
if cfg.lora_cfg is not None:
    should_load_checkpoint = (cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load))
    if should_load_checkpoint:
        # checkpoint.finetune is explicitly set to True elsewhere to avoid loading optimizer and RNG states;
        # switch it off here so those states are restored from the checkpoint when resuming LoRA training.
        cfg.checkpoint.finetune = False
else:
    should_load_checkpoint = (cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load)) or (
        cfg.checkpoint.pretrained_checkpoint is not None and checkpoint_exists(cfg.checkpoint.pretrained_checkpoint)
    )

if should_load_checkpoint:
    load_checkpoint(
        state,
        model,
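Not part of the diff: one way to resolve the TODO on the get_model() call above is to compose the LoRA pre-wrap hook with the existing pre_wrap_hook list that freezes the MoE router. A minimal sketch, assuming each hook takes the list of Megatron model chunks and either mutates it in place or returns a new list, and that get_model() accepts a single callable here as it does in the diff.

def compose_pre_wrap_hooks(hooks):
    """Return one hook that applies `hooks` left to right."""

    def composed(model_chunks):
        for hook in hooks:
            result = hook(model_chunks)
            # Hooks such as freeze_moe_router may mutate in place and return None.
            model_chunks = result if result is not None else model_chunks
        return model_chunks

    return composed


# Hypothetical usage under the assumptions above:
#   hooks = list(pre_wrap_hook)          # e.g. [freeze_moe_router]
#   if cfg.lora_cfg is not None:
#       hooks.append(peft_hook)          # the LoRA hook built from cfg.lora_cfg
#   model = get_model(..., pre_wrap_hook=compose_pre_wrap_hooks(hooks), ...)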
1 change: 1 addition & 0 deletions pyproject.toml
@@ -218,6 +218,7 @@ override-dependencies = [
"transformer-engine[pytorch]==2.8.0",
"opencv-python-headless>=4.11.0",
"nvidia-modelopt[torch]>=0.39.0",
"timm<=1.0.22",
]
# CVE fixes
constraint-dependencies = [