Description
I ran into this problem while running Conformer.
```
Traceback (most recent call last):
  File "run.py", line 134, in <module>
    main(cmd_args, params, expdir)
  File "run.py", line 72, in main
    trainer.train(train_loader=train_loader)
  File "/media/shiyanshi/E/2021_XZW/OpenTransformer/otrans/train/trainer.py", line 90, in train
    train_loss = self.train_one_epoch(epoch, train_loader.loader)
  File "/media/shiyanshi/E/2021_XZW/OpenTransformer/otrans/train/trainer.py", line 152, in train_one_epoch
    loss, aux_loss = self.model(inputs, targets)
  File "/home/shiyanshi/anaconda3/envs/work/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/shiyanshi/E/2021_XZW/OpenTransformer/otrans/model/speech2text.py", line 47, in forward
    enc_inputs, enc_mask = self.frontend(enc_inputs, enc_mask)
  File "/home/shiyanshi/anaconda3/envs/work/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/shiyanshi/E/2021_XZW/OpenTransformer/otrans/frontend/conv.py", line 142, in forward
    x, mask = self.conv2(x, mask)
  File "/home/shiyanshi/anaconda3/envs/work/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/shiyanshi/E/2021_XZW/OpenTransformer/otrans/frontend/conv.py", line 63, in forward
    out = self.conv_layer(x)
  File "/home/shiyanshi/anaconda3/envs/work/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shiyanshi/anaconda3/envs/work/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 349, in forward
    return self._conv_forward(input, self.weight)
  File "/home/shiyanshi/anaconda3/envs/work/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 346, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Calculated padded input size per channel: (2 x 42). Kernel size: (3 x 3). Kernel size can't be greater than actual input size
```
How should I change things to fix this?
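The error is raised inside `self.conv2`: the tensor reaching the second convolution is only 2 elements long along its first spatial axis (likely the time axis, for an input laid out as batch x channel x time x feature), while the kernel needs 3. That usually means the utterance was too short to survive the first subsampling convolution. The sketch below illustrates the length arithmetic using the standard `torch.nn.Conv2d` output-size formula. It assumes the frontend stacks two convolutions with kernel 3, stride 2 and no padding; these values are an assumption, so check the real ones in `otrans/frontend/conv.py`. The helper names `conv_out_len`, `min_frames` and `filter_short` are hypothetical and not part of OpenTransformer.

```python
def conv_out_len(length, kernel=3, stride=2, padding=0, dilation=1):
    """Output size along one axis of a Conv2d (formula from the PyTorch docs)."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def min_frames(num_layers=2, kernel=3, stride=2, padding=0):
    """Smallest input length for which every conv layer still sees >= kernel frames.

    Assumes (hypothetically) identical layers and dilation 1.
    """
    n = 1
    while True:
        length, ok = n, True
        for _ in range(num_layers):
            if length + 2 * padding < kernel:  # this layer would raise the RuntimeError
                ok = False
                break
            length = conv_out_len(length, kernel, stride, padding)
        if ok:
            return n
        n += 1


def filter_short(lengths, **conv_kwargs):
    """Keep only utterance lengths long enough for the assumed conv stack."""
    limit = min_frames(**conv_kwargs)
    return [n for n in lengths if n >= limit]


if __name__ == "__main__":
    print(conv_out_len(6))                   # a 6-frame input shrinks to 2 after conv1
    print(min_frames())                      # minimum input length under these assumptions
    print(filter_short([4, 6, 7, 120]))
```

Under these assumed settings, a 5- or 6-frame utterance comes out of the first layer with length 2, which matches the `2` in the error message, and at least 7 input frames are needed. One possible workaround is to drop (or pad) utterances below that length when building the training set; another is to reduce the frontend's stride or add padding, at the cost of changing the model's subsampling rate.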