2 changes: 1 addition & 1 deletion PyTorch/SpeechSynthesis/FastPitch/Dockerfile
@@ -1,4 +1,4 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.09-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.10-py3
FROM ${FROM_IMAGE_NAME}

ADD requirements.txt .
52 changes: 34 additions & 18 deletions PyTorch/SpeechSynthesis/FastPitch/train.py
@@ -46,7 +46,7 @@
from torch.utils.data.distributed import DistributedSampler

import common.tb_dllogger as logger
from apex import amp
#from apex import amp
from apex.optimizers import FusedAdam, FusedLAMB

import common
@@ -172,8 +172,7 @@ def corrupted(fpath):
return None


def save_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
config, amp_run, filepath):
def save_checkpoint(local_rank, model, ema_model, optimizer, scaler, epoch, total_iter, config, amp_run, filepath):
if local_rank != 0:
return

@@ -186,11 +185,15 @@ def save_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
'ema_state_dict': ema_dict,
'optimizer': optimizer.state_dict()}
if amp_run:
checkpoint['amp'] = amp.state_dict()
#checkpoint['amp'] = amp.state_dict()
checkpoint = {"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scaler": scaler.state_dict()}

torch.save(checkpoint, filepath)


def load_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
def load_checkpoint(local_rank, model, ema_model, optimizer, scaler, epoch, total_iter,
config, amp_run, filepath, world_size):
if local_rank == 0:
print(f'Loading model and optimizer state from {filepath}')
@@ -205,7 +208,10 @@ def load_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
optimizer.load_state_dict(checkpoint['optimizer'])

if amp_run:
amp.load_state_dict(checkpoint['amp'])
#amp.load_state_dict(checkpoint['amp'])
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

if ema_model is not None:
ema_model.load_state_dict(checkpoint['ema_state_dict'])
@@ -336,8 +342,10 @@ def main():
else:
raise ValueError

if args.amp:
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

#if args.amp:
#model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.ema_decay > 0:
ema_model = copy.deepcopy(model)
@@ -426,16 +434,20 @@ def main():
model.zero_grad()

x, y, num_frames = batch_to_gpu(batch)
y_pred = model(x, use_gt_durations=True)
loss, meta = criterion(y_pred, y)

loss /= args.gradient_accumulation_steps
#AMP upstream autocast
with torch.cuda.amp.autocast(enabled=args.amp):
y_pred = model(x, use_gt_durations=True)
loss, meta = criterion(y_pred, y)

loss /= args.gradient_accumulation_steps
meta = {k: v / args.gradient_accumulation_steps
for k, v in meta.items()}

if args.amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
#with amp.scale_loss(loss, optimizer) as scaled_loss:
#scaled_loss.backward()
scaler.scale(loss).backward()
else:
loss.backward()

@@ -458,13 +470,17 @@ def main():

logger.log_grads_tb(total_iter, model)
if args.amp:
torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), args.grad_clip_thresh)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_thresh)
scaler.step(optimizer)
scaler.update()
#optimizer.zero_grad(set_to_none=True)
optimizer.zero_grad()
else:
torch.nn.utils.clip_grad_norm_(
model.parameters(), args.grad_clip_thresh)

optimizer.step()
optimizer.step()
apply_ema_decay(model, ema_model, args.ema_decay)

iter_time = time.perf_counter() - iter_start_time
@@ -517,7 +533,7 @@ def main():

checkpoint_path = os.path.join(
args.output, f"FastPitch_checkpoint_{epoch}.pt")
save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
save_checkpoint(args.local_rank, model, ema_model, optimizer, scaler, epoch,
total_iter, model_config, args.amp, checkpoint_path)
logger.flush()

@@ -538,7 +554,7 @@
(epoch % args.epochs_per_checkpoint != 0) and args.local_rank == 0):
checkpoint_path = os.path.join(
args.output, f"FastPitch_checkpoint_{epoch}.pt")
save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
save_checkpoint(args.local_rank, model, ema_model, optimizer, scaler, epoch,
total_iter, model_config, args.amp, checkpoint_path)


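For orientation, the core pattern the FastPitch changes above converge on is the native torch.cuda.amp API: run the forward pass and loss under autocast, scale the loss for backward, unscale before gradient clipping, then let the GradScaler drive the optimizer step and its loss-scale update. The sketch below is a condensed, hypothetical illustration of that pattern, not the repository's exact loop; the linear model, random data, and grad_clip_thresh value are placeholders standing in for FastPitch, its dataloader, and args.grad_clip_thresh.

# Minimal sketch of the native-AMP training step this diff moves to;
# placeholder model/data, not the repository's code.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = (device == "cuda")                          # stands in for args.amp

model = torch.nn.Linear(80, 80).to(device)            # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)   # replaces apex amp.initialize
grad_clip_thresh = 1000.0                              # illustrative value

for step in range(10):
    x = torch.randn(16, 80, device=device)
    y = torch.randn(16, 80, device=device)

    optimizer.zero_grad()

    # Forward pass and loss under autocast (a no-op when AMP is disabled).
    with torch.cuda.amp.autocast(enabled=use_amp):
        y_pred = model(x)
        loss = criterion(y_pred, y)

    # Scaled backward; unscale before clipping so the threshold applies to the
    # true gradients; then the scaler runs the step and updates its loss scale.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
    scaler.step(optimizer)
    scaler.update()

Compared with apex amp.initialize, this keeps the model and optimizer unmodified and confines mixed precision to the autocast region and the scaler calls, which is why the diff can drop the opt_level plumbing.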
2 changes: 1 addition & 1 deletion PyTorch/SpeechSynthesis/Tacotron2/Dockerfile
@@ -1,4 +1,4 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.06-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.10-py3
FROM ${FROM_IMAGE_NAME}

ADD . /workspace/tacotron2
1 change: 0 additions & 1 deletion PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py
@@ -661,7 +661,6 @@ def forward(self, inputs):
input_lengths, output_lengths = input_lengths.data, output_lengths.data

embedded_inputs = self.embedding(inputs).transpose(1, 2)

encoder_outputs = self.encoder(embedded_inputs, input_lengths)

mel_outputs, gate_outputs, alignments = self.decoder(
46 changes: 22 additions & 24 deletions PyTorch/SpeechSynthesis/Tacotron2/train.py
@@ -39,7 +39,7 @@
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

from apex.parallel import DistributedDataParallel as DDP
from torch.nn.parallel import DistributedDataParallel as DDP

import models
import loss_functions
@@ -51,11 +51,6 @@

from scipy.io.wavfile import write as write_wav

from apex import amp
amp.lists.functional_overrides.FP32_FUNCS.remove('softmax')
amp.lists.functional_overrides.FP16_FUNCS.append('softmax')


def parse_args(parser):
"""
Parse commandline arguments.
@@ -188,7 +183,7 @@ def init_distributed(args, world_size, rank, group_name):
print("Done initializing distributed")


def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_name,
def save_checkpoint(model, optimizer, scaler, epoch, config, amp_run, output_dir, model_name,
local_rank, world_size):

random_rng_state = torch.random.get_rng_state().cuda()
@@ -215,7 +210,9 @@ def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()}
if amp_run:
checkpoint['amp'] = amp.state_dict()
checkpoint = {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scaler': scaler.state_dict()}

checkpoint_filename = "checkpoint_{}_{}.pt".format(model_name, epoch)
checkpoint_path = os.path.join(output_dir, checkpoint_filename)
@@ -256,8 +253,9 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
optimizer.load_state_dict(checkpoint['optimizer'])

if amp_run:
amp.load_state_dict(checkpoint['amp'])

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scaler.load_state_dict(checkpoint['scaler'])

# adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
# Following snippet is licensed under MIT license
@@ -384,16 +382,13 @@ def main():
cpu_run=False,
uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

if not args.amp and distributed_run:
model = DDP(model)
if distributed_run:
model = DDP(model,device_ids=[local_rank],output_device=local_rank)

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
weight_decay=args.weight_decay)

if args.amp:
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
if distributed_run:
model = DDP(model)
scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

try:
sigma = args.sigma
@@ -475,9 +470,11 @@ def main():
model.zero_grad()
x, y, num_items = batch_to_gpu(batch)

y_pred = model(x)
loss = criterion(y_pred, y)

#AMP upstream autocast
with torch.cuda.amp.autocast(enabled=args.amp):
y_pred = model(x)
loss = criterion(y_pred, y)

if distributed_run:
reduced_loss = reduce_tensor(loss.data, world_size).item()
reduced_num_items = reduce_tensor(num_items.data, 1).item()
@@ -495,10 +492,11 @@
reduced_num_items_epoch += reduced_num_items

if args.amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), args.grad_clip_thresh)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)

else:
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -532,7 +530,7 @@ def main():
batch_to_gpu)

if (epoch % args.epochs_per_checkpoint == 0) and args.bench_class == "":
save_checkpoint(model, optimizer, epoch, model_config,
save_checkpoint(model, optimizer, scaler, epoch, model_config,
args.amp, args.output, args.model_name,
local_rank, world_size)
if local_rank == 0:
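The Tacotron2 changes follow the same migration: torch.nn.parallel.DistributedDataParallel replaces the apex DDP wrapper (now constructed with device_ids and output_device), and checkpoints carry the GradScaler state alongside the model and optimizer state so a resumed mixed-precision run keeps its loss-scale history. Below is a minimal, hypothetical sketch of that checkpoint round trip; the tiny model and file name are placeholders, not the repository's code.

# Sketch of checkpointing with GradScaler state, as the diff above does;
# placeholder model and path, single process, no DDP setup shown.
import torch

model = torch.nn.Linear(10, 10)                       # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

def save_checkpoint(path):
    # Scaler state travels with model/optimizer state.
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict()}, path)

def load_checkpoint(path):
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scaler.load_state_dict(checkpoint["scaler"])      # restores the loss scale

save_checkpoint("demo_checkpoint.pt")
load_checkpoint("demo_checkpoint.pt")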