| lang | title | subtitle | author | date | description | tags |
|---|---|---|---|---|---|---|
| en | Unifying Parallelization Schemes to Train Deep Nets | | | Jan 2025 | | |
There are several strategies for training large neural networks over multiple GPUs. The best practice today is to employ a mix of Data Parallelism, Pipeline Parallelism, and Tensor Parallelism simultaneously. This document describes a master schedule that unifies all these schemes under one code base. From this unified view, we derive a new distributed schedule that can mix and match the tradeoffs between these schemes, obviating the need to deploy all three schemes simultaneously.
There are two main reasons to use multiple GPUs when training large models: First, it speeds up training by making more FLOPS available. Second, it supplies the memory needed to store the weights, activations, and optimizer state required to train a large model.

In Distributed Data Parallelism (DDP), each worker is assigned a subset of the training batch called a micro-batch. Each worker computes the gradient on its micro-batch using backpropagation. The main advantage of this scheme is that all workers operate independently of each other, so they are almost fully utilized during backpropagation. One downside of this method is that each worker must retain the activations of all the layers of the network, which is infeasible for large models. Another downside is that the number of workers must exactly match the number of micro-batches. Large batch sizes improve compute utilization, but they also cause each epoch of training to make less progress on the training loss than with smaller batches (see Section 2.1 of this folk wisdom paper and Theorem 5.8 of this more theoretical paper). In other words, the training loss drops more slowly per training datum as the number of micro-batches in a batch grows, and therefore as the number of workers increases. Yet another disadvantage of this method is that each worker must store the weights of all the layers, which is again infeasible for large models. Fully Sharded Data Parallelism (FSDP) addresses this latter issue by having each worker hold the weights for only one stage, and page in the weights of the other stages from other workers as needed.

Pipeline Parallelism (PP) addresses some of these problems by assigning to each worker a subset of the stages. Each worker still processes all micro-batches in the batch, but only for the stages assigned to it. A micro-batch works its way through the stages by passing from one worker to the next. An advantage of this method is that each worker need only store the weights of the stages assigned to it. Another advantage is that it accommodates more workers than there are micro-batches, offering a way to add more workers to the task. A problem with it is that it has higher latency than DDP: some workers remain idle until the first micro-batch has made it through the final stage, and some workers become idle again as the last micro-batch winds its way through the stages. Surprisingly, the amount of activations each worker stores still scales with the number of micro-batches, because each worker must store the activations of all the micro-batches that traverse it (though see the table below for a cap on this number).

To address these issues, Looped Pipeline Parallelism (LPP) operates several Pipeline Parallel processes in independent groups and distributes the micro-batches evenly across these groups. This makes LPP run almost as fast as DDP when the number of groups is large. Fully Sharded Looped Pipeline Parallelism (FSLPP) is a variant of LPP where each worker stores the weights of only one stage and pages in the weights of the other stages it must process from other workers as needed.
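To make the contrast concrete, here is a small illustrative sketch (the toy sizes and dictionaries below are hypothetical, not taken from this document's code) of which (stage, micro-batch) work items each worker owns under DDP and under PP:

```python
# Toy sizes, chosen only for illustration.
S, B = 4, 8   # S pipeline stages, B micro-batches

# DDP: one worker per micro-batch. Worker b runs every stage of micro-batch b,
# so it needs the weights of all S stages but only its own micro-batch's activations.
ddp_work = {worker: [(s, worker) for s in range(S)] for worker in range(B)}

# PP (GPipe): one worker per stage. Worker s runs stage s for every micro-batch,
# so it needs only one stage's weights but sees activations of up to B micro-batches.
pp_work = {worker: [(worker, b) for b in range(B)] for worker in range(S)}

print(ddp_work[0])   # [(0, 0), (1, 0), (2, 0), (3, 0)]
print(pp_work[0])    # [(0, 0), (0, 1), ..., (0, 7)]
```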
The table below summarizes the performance tradeoff between these schemes. To keep the table simple, we assume the stages of the pipeline are identical in the amount of computation they perform and in the shapes of their inputs and outputs.
| Scheme | Constraints | Latency | Activation Xmit per Worker | Weight Xmit per Worker | Activation Storage per Worker | Weight Storage per Worker |
|---|---|---|---|---|---|---|
| Distributed Data Parallel (DDP) | 0 | 0 | ||||
| Fully Sharded DDP (FSDP) | 0 | 1 | ||||
| Pipeline Parallel (PP) | 0 | 1 | ||||
| Looped PP (LPP) | 0 | |||||
| Fully Sharded LPP (FSLPP) | 1 | | | | | |
Table 1: Latency, network, and memory usage for the schedules unified in this document. $W$ is the number of workers, $S$ is the depth of the network being trained, and $B$ is the number of micro-batches. Looped models break $W$ into $G$ groups of $R$ workers each. The formulas assume that computing a forward and backward step on any stage takes one unit of time, that the weights for all stages together take one unit of weight storage and one unit of weight transmission, and that the activations for all stages take one unit (of a different kind) of activation storage and transmission. To keep the latency computation simple, we assume that transmission latencies can be hidden by overlapping them with compute.
Some useful patterns emerge from the table above. LPP's performance smoothly
interpolates between that of PP and DDP as
On the surface, these forms of parallelism appear so different that they're traditionally implemented with different code bases. For example, PyTorch implements Data Parallelism under torch.nn.parallel.DistributedDataParallel, Fully Sharded Data Parallelism under torch.distributed.FSDP, variants of Pipeline Parallelism under torch.distributed.pipelining, and Tensor Parallelism under torch.distributed.tensor.parallel. Large model training code bases rely on all these packages simultaneously. But all these schemes turn out to be instances of the same master schedule. This master schedule takes as input two functions: a function that maps a work unit to a compute worker, and a second function that maps each stage of the pipeline to a worker that stores the source of truth for that stage's weights. Together, these two functions specify the schedule of computation and the implied transfer of weight and activation data. Changing these two functions produces the entire spectrum of distributed training algorithms. With the master scheduler offered in this document, all these different schedules and their hybrids can be implemented with just one compute function and one weight storage function. The resulting training recipe would be easier to debug and maintain, and would make it easier to explore the larger space of possible schedules.
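As a sketch of what this interface might look like in code (the names and signatures here are illustrative assumptions, not the document's actual API), a schedule is fully described by a pair of functions:

```python
from typing import Callable, Tuple

# A work unit: (stage, micro-batch, direction), matching the pseudocode in the appendix.
Work = Tuple[int, int, str]

# The two functions that define a distributed schedule.
ComputeFn = Callable[[Work], int]   # work unit -> id of the worker that computes it
WeightFn = Callable[[int], int]     # stage -> id of the worker holding the source of truth for its weights

def make_schedule(w_compute: ComputeFn, w_weights: WeightFn):
    """Stand-in for the master scheduler: the pair (w_compute, w_weights) alone
    determines which worker runs each job and which transfers of weights and
    activations are implied."""
    return (w_compute, w_weights)
```

The per-scheme sketches later in this document show how each classical scheme instantiates such a pair.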
In summary, this document offers the following contributions:
- A way to implement several distributed training schemes in a unified code base. Under this abstraction, a parallel training scheme is defined by a compute function and a weight storage function. This in turn simplifies the client code by obviating the need to import different APIs for each distribution scheme. The client code can instead supply the compute and weight storage functions that describe a mix of parallel schemes.
- A variant of Pipeline Parallelism, called Fully Sharded Looped Pipeline Parallelism (FSLPP), whose behavior smoothly interpolates between FSDP and PP, allowing it to enjoy the best of both worlds. Along the way, we also modify Looped Pipeline Parallelism to interpolate between DDP and PP.
- Guidelines for setting the parameters of LPP. We show that scaling the number of groups in LPP proportionally to the number of micro-batches, and setting the number of workers in each group to limit the amount of activation memory required of each worker, causes LPP to nearly attain the optimal throughput achievable by any distributed training scheme. This implies that, under the assumptions of this document, not much improvement is possible over the LPP family of schedules.
Our job is to compute, with the help of
Here,
To specify which worker should compute which stage, define a function
This formulation assumes that the smallest unit of work that can be assigned is
a micro-batch at a particular stage. But finer-grained units of work are
possible and in common use. One could, for example, split each
We'll illustrate the various training schedules with pipeline diagrams. In these diagrams, the columns are time indices and each row corresponds to a worker. The cells in the diagram describe the activity of the worker during the corresponding time step. The color of each cell indicates the id of the batch being processed. The cell value indicates whether the work is for the forward or backward pass and which pipeline stage is being processed. For example, here is the pipeline diagram for Distributed Data Parallelism when there are 8 workers, 8 batches, and 4 stages:
In Distributed Data Parallelism, there is a one-to-one correspondence between workers and batches. Each worker processes every stage of the batch assigned to it in sequence. Distributed Data Parallelism (DDP) can be formally defined as:
The micro-batch index dictates where compute happens and where the parameters are stored, so
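A hedged sketch of DDP in the two-function form (the exact formula is an illustrative rendering, not necessarily the document's): the micro-batch index alone decides both where compute happens and where that worker's copy of the parameters lives.

```python
def ddp_w_compute(stage: int, batch: int, direction: str) -> int:
    # One worker per micro-batch: worker b runs every stage of micro-batch b.
    return batch

def ddp_w_weights(stage: int, batch: int) -> int:
    # Weights are replicated: the worker computing micro-batch b uses its own copy.
    return batch
```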
Fully Sharded Data Parallelism (FSDP) overcomes some of these difficulties by sharding model parameters and the optimizer state across workers. The pipeline diagram for FSDP looks identical to that of plain DDP. Formally, FSDP is:
Like DDP, the choice of compute worker only depends on
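In the same spirit, here is a hedged FSDP sketch (the modulo placement below is an illustrative assumption): compute is still assigned by micro-batch, but the source of truth for each stage's weights is sharded across the workers and paged in on demand.

```python
W = 8  # number of workers; illustrative

def fsdp_w_compute(stage: int, batch: int, direction: str) -> int:
    # Same compute placement as DDP: one worker per micro-batch.
    return batch

def fsdp_w_weights(stage: int) -> int:
    # The weights (and optimizer state) of stage s live on one designated worker;
    # every other worker fetches them when it reaches that stage.
    return stage % W
```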
In Pipeline Parallelism, aka GPipe, there is a one-to-one correspondence between workers and stages. Each worker processes every batch, but only at the stage assigned to the worker:
GPipe is specified by
In GPipe, each stage
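A hedged sketch of GPipe in the same form: the stage index alone determines both the compute worker and the weight owner.

```python
def gpipe_w_compute(stage: int, batch: int, direction: str) -> int:
    # One worker per stage: worker s runs stage s for every micro-batch.
    return stage

def gpipe_w_weights(stage: int) -> int:
    # The same worker is the source of truth for that stage's weights.
    return stage
```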
An improvement on GPipe is Looped Pipeline Parallelism (LPP). In LPP, each worker can process more than one stage of the pipeline. It does this by organizing the
LPP is specified by
The first term of
In LPP, workers store the weights of all the stages they process because we
forced
The hash function
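A hedged sketch of LPP and FSLPP with $W$ workers split into $G$ groups of $R = W/G$ workers; the index arithmetic and the hash below are illustrative choices consistent with the description above, not the document's exact formulas.

```python
W, G = 8, 2
R = W // G   # workers per group

def lpp_w_compute(stage: int, batch: int, direction: str) -> int:
    group = batch % G              # micro-batches are spread evenly across the groups
    return group * R + stage % R   # within a group, each worker loops over every R-th stage

def lpp_w_weights(stage: int, batch: int) -> int:
    # In LPP, a worker keeps the weights of every stage it processes.
    return lpp_w_compute(stage, batch, 'forward')

def fslpp_w_weights(stage: int) -> int:
    # In FSLPP, each stage has a single owner chosen by a hash; workers that
    # process the stage page its weights in from that owner.
    return hash(('stage', stage)) % W
```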
The appendix shows how one scheduler can implement the full variety of
schedules described above by just specifying their corresponding
Backpropagation on a pipeline requires lower stages to be computed before
higher stages, and the entire forward pass to be computed before the backward
pass on the same batch. We'll formalize the dependencies between these
operations with a predicate
The first condition says that stage
This formalism simplifies the code, and will be particularly handy when we upgrade the model from a pipeline structure to a tree or DAG structure, as is the case in models with multiple encoder legs.
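Here is a hedged sketch of such a predicate (our own rendering, consistent with the two conditions above and with the sentinel jobs used in the appendix pseudocode): each job has a single scheduling predecessor, so the jobs of one micro-batch form a forward chain followed by a backward chain. The backward pass's additional need for stage $s$'s own activations is handled separately by await_activations in the appendix code.

```python
S = 4  # number of stages; illustrative

def pred(job):
    stage, batch, direction = job
    if direction == 'forward':
        # Lower stages must be computed before higher stages.
        # (-1, batch, 'forward') is the sentinel meaning "the input is available".
        return (stage - 1, batch, 'forward')
    if stage == S - 1:
        # The backward pass starts only after the entire forward pass has finished.
        return (stage, batch, 'forward')
    # Gradients then flow back down the stages.
    return (stage + 1, batch, 'backward')
```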
Some schedules differ in how workers break ties when the preconditions for more than one job
Under the foregoing model, DDP and FSDP make optimal use of the workers: they are occupied at all times, performing useful computations. But when activation memory is constrained, the DDP family of schemes can't be applied. It turns out that for any bound on the activation memory, there is a configuration of LPP that also nearly optimally occupies the workers, which makes the LPP family of schedules more versatile than the DDP family. The implication is that any improvement over the LPP family of schedules is due to effects not modeled in Table 1 above (for example, because the stages are not identical, or because not all communication latencies can be hidden). To formalize this claim, we'll need the following definition and upper bound:
Definition: The throughput per worker
Definition: The peak activation storage,
For example, for DDP,
The peak activation imposes an upper bound on the throughput per worker:
Theorem: Any schedule whose communication latencies can be hidden on a pipeline with identical stages must satisfy
The proof appears in the appendix. Since with DDP,
Remarkably, LPP offers more freedom in setting
The quantities in Table 1 were computed manually for each strategy. All of these
strategies have a setup phase, where the workers gradually fill up with work, a
steady state phase during which all workers are busy, and a drainage phase where
the pipeline gradually gets less busy. DDP and FSDP enter the steady state phase
immediately. For GPipe, the setup phase takes
But we can compute these metrics mechanically for arbitrary strategies by
inspecting
- $j$: A job $(s,b,d)$. We'll use $w_\text{compute}(j)$ as a short-hand for $w_\text{compute}(s,b,d)$.
- $J$: The set of jobs that have executed so far. We'll define the "frontier" to be the set of jobs that have successors that have not yet been executed. Formally, we say $j\in J$ is in the frontier if $\text{succ}(j) \notin J$.
- $t(j)$: When the job $j$ will finish, expressed in units of time. Only valid for jobs in $J$.
- $t(w)$: When worker $w$ finishes its latest job in $J$. This is a shorthand for $\max_{j\in J : w(j) = w} t(j)$.
The following operation identifies a job
It does this by examining every job
This process mimics almost exactly the greedy schedule followed by the multi-threaded code snippet in the appendix. The difference is that we have not modeled the tie-breaking logic that uses
After
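For concreteness, here is a hedged sketch of that mechanical simulation (the greedy rule and the one-unit-per-step assumption follow Table 1; `succ` and `w_compute` are assumed to be defined as in the sketches above, with `w_compute` taking a (stage, batch, direction) triple). As noted, the tie-breaking priority is not modeled.

```python
from collections import defaultdict

def simulate(w_compute, succ, sentinels):
    """Greedy list scheduling: repeatedly execute the successor of the frontier
    job that can finish earliest, tracking t(j) for jobs and t(w) for workers."""
    t_job = {j: 0 for j in sentinels}        # sentinel jobs "finish" at time 0
    t_worker = defaultdict(int)              # t(w): when worker w becomes free
    frontier = {j for j in sentinels if succ(j) is not None}
    while frontier:
        def finish(j):
            # Finish time of j's successor if it were scheduled next.
            nxt = succ(j)
            return max(t_worker[w_compute(*nxt)], t_job[j]) + 1
        best = min(frontier, key=finish)
        new_job = succ(best)
        t_job[new_job] = finish(best)
        t_worker[w_compute(*new_job)] = t_job[new_job]
        frontier.discard(best)
        if succ(new_job) is not None:
            frontier.add(new_job)
    return t_job, dict(t_worker)             # latency = max(t_job.values())
```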
Some statistics of the schedule can be computed more directly from
This is the number of times the worker has to process a work item
The number of times worker
This is the number of times the worker has to process work for which it doesn't host the weights.
The amount of activation storage required by a particular worker depends on both
This is just the total number of items the worker processes.
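A hedged sketch of these direct computations (assuming, as in the earlier sketches, that `w_compute` takes a (stage, batch, direction) triple, that `w_weights` takes just the stage, and that `pred` is the predicate sketched above; the document's exact accounting may differ):

```python
from collections import Counter

def per_worker_counts(jobs, pred, w_compute, w_weights):
    activation_xmit = Counter()   # jobs whose predecessor ran on a different worker
    weight_xmit = Counter()       # jobs whose stage weights are hosted elsewhere
    items_processed = Counter()   # total work items handled by each worker
    for job in jobs:
        stage, batch, direction = job
        worker = w_compute(*job)
        items_processed[worker] += 1
        prev = pred(job)
        if prev[0] >= 0 and w_compute(*prev) != worker:   # skip the input sentinel
            activation_xmit[worker] += 1
        if w_weights(stage) != worker:
            weight_xmit[worker] += 1
    return activation_xmit, weight_xmit, items_processed
```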
The proof relies on the concept of memory-time for a variable. This is the
number of time steps during which a variable must stay live during a
computation. The memory-time of a computation is the sum of memory-times of all
the variables it computes. A computer with
Backpropagation on a pipeline requires the activations
flowchart-elk LR
f0
b0[∇0]
f1
b1[∇1]
f2
b2[∇2]
f3
b3[∇3]
f0 --> f1 --> f2 --> f3
b3 --> b2 --> b1 --> b0
f0 --> b0
f1 --> b1
f2 --> b2
f3 --> b3
Each batch therefore requires
To train a neural network, we compute the gradient of a loss function of the
output of the neural network with respect to the parameters of each stage of the
neural network's pipeline. We'd like to compute the gradients of the output of an
The input of the pipeline is
From this, each stage computes the gradient of the loss with respect to the stage's parameters:
These operations can be written in a compact operator form: $$\begin{align*} x_{s+1} &= F_s \; x_s \\ z_s &= B_s(x_s) \; z_{s+1} \end{align*}$$
The forward operator
We are to apply the above pipeline on
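As a concrete, hedged rendering of the operator form above, here is a small PyTorch sketch with four illustrative linear stages: the forward sweep applies $x_{s+1} = F_s \, x_s$ and retains every activation, and the backward sweep applies $z_s = B_s(x_s) \, z_{s+1}$ while also producing each stage's parameter gradients. The sizes and layers are assumptions for illustration only.

```python
import torch

S, width, batch = 4, 8, 2                    # illustrative sizes
stages = [torch.nn.Linear(width, width) for _ in range(S)]

# Forward sweep: x_{s+1} = F_s x_s, keeping every activation for the backward pass.
x = [torch.randn(batch, width, requires_grad=True)]
for s in range(S):
    x.append(stages[s](x[s]))
loss = x[-1].sum()                           # so dloss/dx_S is all ones

# Backward sweep: z_S is the gradient of the loss w.r.t. the pipeline output, then
# z_s = B_s(x_s) z_{s+1}; the same vector-Jacobian product also yields the gradient
# of the loss w.r.t. stage s's parameters.
z = torch.ones_like(x[-1])
for s in reversed(range(S)):
    params = list(stages[s].parameters())
    grads = torch.autograd.grad(x[s + 1], [x[s]] + params, grad_outputs=z)
    z = grads[0]                             # z_s, handed to the stage below
    stage_param_grads = grads[1:]            # gradients for stage s's weight and bias
```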
This section sketches out pseudocode for the generic task scheduler. The
pseudocode is written in a math-inspired version of Python. Like in Python,
ranges are inclusive of their start and exclusive of their end (so 0..W
includes 0 but not W). The match operator is inspired by Haskell and Rust's
match statement. The set syntax follows established mathematical notation.
Before the round of work starts, we globally initialize a list of available work:
# A unit of work is a stage, a batch id, and a direction, i.e. an application of F[s] or B[s] to one batch.
Direction = 'forward' | 'backward'
Stage = 0..S
Batch = 0..B
Work = (Stage, Batch, Direction)
Worker = 0..W
# Initially, the input of every batch is available (stage -1 is a sentinel),
# so the first stage of each batch is ready to execute.
ready : Set[Work] = {(-1, b, 'forward') : b ∈ Batch}
# Nothing is currently executing.
working_on : Set[Work] = {}
Then each worker goes in a loop picking off available work that it can process:
def available_work(worker: Worker, ready: Set[Work]) → (Work, Work):
# Find jobs that are ready to be executed by the worker. This blocks
# until a job is available.
candidates = {(job, new_job) ∈ ready x Work :
w_compute(new_job) = worker, job = pred(new_job)}
# Among the candidates, return the pair with the highest priority.
return find_smallest(candidates, key=lambda (job, new_job): n(new_job))
def fetch_work(w: Worker, ready: Set[Work]) → (Work, Work):
atomically:
finished_work, work_to_do = available_work(w, ready)
remove finished_work from ready
return (finished_work, work_to_do)
def worker_compute_thread(w: Worker):
while True:
finished_work, work_to_do = fetch_work(w, ready)
match work_to_do:
stage, batch, 'forward':
# Initiate prefetch
insert work_to_do into working_on
# Start computing
new_activations = F[stage](
await_weights(stage),
await_activations(stage - 1, batch)
)
store_activations(stage, batch, new_activations)
insert work_to_do into ready
stage, batch, 'backward':
# Initiate prefetch
insert work_to_do into working_on
# Start computing
new_grad = B[stage](
(stage + 1 < S) and await_gradients(stage + 1, batch),
await_weights(stage),
await_activations(stage, batch)
)
store_gradient(stage, batch, new_grad)
insert work_to_do into ready
def worker_prefetch_thread(w: Worker):
# Optional thread that hides the fetching latency in the compute thread
# by prefetching weights and activations.
while True:
finished_work, work_to_do = fetch_work(w, working_on)
match work_to_do:
stage, batch, 'forward':
prefetch_weights(stage)
prefetch_activations(stage - 1, batch)
stage, batch, 'backward':
(stage + 1 < S) and prefetch_gradients(stage + 1, batch)
prefetch_weights(stage)
prefetch_activations(stage, batch)
The code above depends on the available_work function. It also depends on await_weights, await_gradients, await_activations, and their prefetch variants.


