diff --git a/RNN_utils.py b/RNN_utils.py index abba2ca..f555da4 100644 --- a/RNN_utils.py +++ b/RNN_utils.py @@ -27,17 +27,17 @@ def load_data(data_dir, seq_length): ix_to_char = {ix:char for ix, char in enumerate(chars)} char_to_ix = {char:ix for ix, char in enumerate(chars)} - X = np.zeros((len(data)/seq_length, seq_length, VOCAB_SIZE)) - y = np.zeros((len(data)/seq_length, seq_length, VOCAB_SIZE)) - for i in range(0, len(data)/seq_length): - X_sequence = data[i*seq_length:(i+1)*seq_length] + X = np.zeros((len(data) - (seq_length-1), seq_length, VOCAB_SIZE)) + y = np.zeros((len(data) - (seq_length-1), seq_length, VOCAB_SIZE)) + for i in range(0, len(data) - (seq_length-1)): + X_sequence = data[i+seq_length:i+1+seq_length] X_sequence_ix = [char_to_ix[value] for value in X_sequence] input_sequence = np.zeros((seq_length, VOCAB_SIZE)) for j in range(seq_length): input_sequence[j][X_sequence_ix[j]] = 1. X[i] = input_sequence - y_sequence = data[i*seq_length+1:(i+1)*seq_length+1] + y_sequence = data[i+seq_length+1:(i+1+seq_length+1] y_sequence_ix = [char_to_ix[value] for value in y_sequence] target_sequence = np.zeros((seq_length, VOCAB_SIZE)) for j in range(seq_length):