diff --git a/minillm/finetune.py b/minillm/finetune.py
index 56ff7584..ecf71a7f 100644
--- a/minillm/finetune.py
+++ b/minillm/finetune.py
@@ -159,6 +159,10 @@ def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model
         teacher_model.eval()
         teacher_outputs = teacher_model(**model_batch, use_cache=False)
         teacher_logits = teacher_outputs.logits
+        if args.model_type == 'qwen2':
+            # Under ZeRO the student is wrapped, so read vocab_size from model.module.config; otherwise read it from model.config.
+            student_vocab_size = model.module.config.vocab_size if hasattr(model, "module") else model.config.vocab_size
+            teacher_logits = teacher_logits[:, :, :student_vocab_size]
     if args.model_parallel:
         distil_losses = mpu.parallel_soft_cross_entropy_loss(logits.float(), teacher_logits.float())
         distil_losses = distil_losses.view(-1)
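
The patch handles a student/teacher vocabulary mismatch: when the teacher's output head carries more vocabulary entries than the student produces (e.g. a padded embedding), the distillation loss would otherwise operate on mismatched logit shapes, so the teacher logits are sliced down to the student's vocab size first. Below is a minimal, self-contained sketch of that idea outside the repository, not the repo's own loss; the function name `soft_ce_with_vocab_truncation`, the toy shapes, and the masking are illustrative assumptions, and it assumes the first `student_vocab` teacher entries correspond one-to-one to the student's vocabulary (the extra entries being padding).

```python
# Minimal sketch (assumed helper, not the repository's implementation):
# truncate padded teacher logits to the student's vocab, then compute a
# token-level soft cross-entropy distillation loss.
import torch
import torch.nn.functional as F

def soft_ce_with_vocab_truncation(student_logits, teacher_logits, loss_mask=None):
    # Assumption: the first `student_vocab` teacher entries map 1:1 to the
    # student's vocabulary; anything beyond that is padding in the teacher head.
    student_vocab = student_logits.size(-1)
    teacher_logits = teacher_logits[..., :student_vocab]

    teacher_probs = F.softmax(teacher_logits.float(), dim=-1)
    student_logprobs = F.log_softmax(student_logits.float(), dim=-1)
    # Soft cross-entropy per token; equals forward KL up to the teacher entropy term.
    losses = -(teacher_probs * student_logprobs).sum(dim=-1)  # [batch, seq]

    if loss_mask is not None:
        return (losses * loss_mask).sum() / loss_mask.sum()
    return losses.mean()

if __name__ == "__main__":
    # Toy shapes: teacher head padded to 40 entries, student vocab is 32.
    student_logits = torch.randn(2, 5, 32)
    teacher_logits = torch.randn(2, 5, 40)
    print(soft_ce_with_vocab_truncation(student_logits, teacher_logits).item())
```

Note that slicing only works when the two vocabularies share a common prefix; if the student and teacher use different tokenizers, truncation alone would silently misalign the distributions.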