
Multi-GPU Training Support #52

Description

@without-ordinary

I'm not sure whether multi-GPU training is unsupported or just broken. From my casual research into PEFT, it looks like multi-GPU may not have been supported in the past but should be now with accelerate.

The current torchrun implementation suggests that multi-GPU training is intended to work, but it throws this error when more than one GPU is visible:

# torchrun --standalone --nproc_per_node=2 train.py \
--output-dir /mnt/datasets/checkpoints \
--device-batch-size 4 --dataset /mnt/datasets/resized-384-squish.json --max-samples 1800 \
--images-path /mnt/datasets/resized-384-squish --test-every 2000 --test-size 128
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-07-15_09:38:14
  host      : fbb-kohya-bmaltais.tailc6aca4.ts.net
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 138162)
  error_file: /tmp/torchelastic_vn1tu68b/none_65nq3eb2/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/home/without/miniforge3/envs/joycaption/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/home/without/joycaption/finetuning/train.py", line 122, in main
      trainer.train()
    File "/home/without/joycaption/finetuning/train.py", line 344, in train
      self.build_dataset()
    File "/home/without/joycaption/finetuning/train.py", line 235, in build_dataset
      image_token_id=self.model.config.image_token_index,
                     ^^^^^^^^^^^^^^^^^
    File "/home/without/miniforge3/envs/joycaption/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1940, in __getattr__
      raise AttributeError(
  AttributeError: 'DistributedDataParallel' object has no attribute 'config'

============================================================

Line numbers in train.py are off by +1 because we add an import pillow_jxl line at the top; most of our datasets are kept as JXL.
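
For context, the failure appears to come from torchrun running the model under DistributedDataParallel, which does not proxy attributes like .config through to the wrapped model; those have to be read via the wrapper's .module attribute instead. A minimal sketch of the pattern (the unwrap_model helper is my own illustration, not something in train.py):

import torch.nn as nn

def unwrap_model(model: nn.Module) -> nn.Module:
    # DDP only forwards calls to forward(); attributes such as
    # .config still live on the wrapped model at ddp_model.module
    if isinstance(model, nn.parallel.DistributedDataParallel):
        return model.module
    return model

# In build_dataset(), something like:
#   image_token_id = unwrap_model(self.model).config.image_token_index
# instead of:
#   image_token_id = self.model.config.image_token_index  # fails under DDP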
