-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathencoder.py
More file actions
40 lines (35 loc) · 1.24 KB
/
encoder.py
File metadata and controls
40 lines (35 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import rnn_utils
class EncoderRNN(nn.Module):
    """Embedding layer followed by a recurrent encoder.

    The recurrent layer is built by ``rnn_utils.initRNN`` from the settings
    in ``model_config``; input and hidden widths are both
    ``model_config.hidden_size``.

    Args:
        model_config: Object exposing ``input_size``, ``hidden_size``,
            ``bidirectional``, ``rnn_type`` and ``num_layers_encoder``.
        device: Device on which the initial hidden states are allocated.
            It is only used by ``initHidden``; this class does not move its
            own parameters to it.
    """

    def __init__(self, model_config, device=None):
        super().__init__()
        self.device = device
        self.hidden_size = model_config.hidden_size
        # One learned vector of width hidden_size per input index.
        self.embedding = nn.Embedding(
            model_config.input_size,
            model_config.hidden_size)
        self.bidirectional = model_config.bidirectional
        self.rnn_type = model_config.rnn_type
        self.num_layers = model_config.num_layers_encoder
        # NOTE(review): presumably returns an nn.LSTM / nn.GRU style module
        # selected by rnn_type -- confirm in rnn_utils.
        self.rnn = rnn_utils.initRNN(
            model_config.rnn_type,
            self.hidden_size,
            self.hidden_size,
            self.num_layers,
            bidirectional=self.bidirectional)

    def forward(self, input, hidden):
        """Embed ``input`` and run it through the recurrent layer.

        Args:
            input: Tensor of integer indices to look up in the embedding.
            hidden: Initial recurrent state, e.g. the value returned by
                ``initEncoderHidden``.

        Returns:
            Whatever ``self.rnn`` returns for ``(embedded, hidden)``
            (presumably an ``(output, hidden)`` pair -- confirm against
            ``rnn_utils.initRNN``).
        """
        embedded = self.embedding(input)
        return self.rnn(embedded, hidden)

    def initHidden(self, batch_size=1):
        """Return one all-zero state tensor.

        Args:
            batch_size: Size of the batch dimension. Defaults to 1, the
                value that was previously hard-coded.

        Returns:
            Zero tensor of shape
            ``(num_directions * num_layers, batch_size, hidden_size)``
            allocated on ``self.device``.
        """
        num_directions = 2 if self.bidirectional else 1
        return torch.zeros(
            num_directions * self.num_layers,
            batch_size,
            self.hidden_size,
            device=self.device)

    def initEncoderHidden(self, batch_size=1):
        """Return the initial recurrent state matching ``self.rnn_type``.

        Args:
            batch_size: Size of the batch dimension of the state tensors.

        Returns:
            A pair of zero tensors for ``'lstm'``, a single zero tensor for
            ``'gru'``, and ``None`` for any other ``rnn_type``.
        """
        if self.rnn_type == 'lstm':
            # Two separate tensors, so they never alias each other.
            return (self.initHidden(batch_size), self.initHidden(batch_size))
        if self.rnn_type == 'gru':
            return self.initHidden(batch_size)
        # Kept from the original: unknown types yield None rather than
        # raising, since callers may already rely on that.
        return None