Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions sequence_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from discriminator import Discriminator
from rollout import ROLLOUT
from target_lstm import TARGET_LSTM
import cPickle
import _pickle as pickle

# encoding
#########################################################################################
# Generator Hyper-parameters
######################################################################################
Expand Down Expand Up @@ -57,7 +58,7 @@ def target_loss(sess, target_lstm, data_loader):
nll = []
data_loader.reset_pointer()

for it in xrange(data_loader.num_batch):
for it in range(data_loader.num_batch):
batch = data_loader.next_batch()
g_loss = sess.run(target_lstm.pretrain_loss, {target_lstm.x: batch})
nll.append(g_loss)
Expand All @@ -70,7 +71,7 @@ def pre_train_epoch(sess, trainable_model, data_loader):
supervised_g_losses = []
data_loader.reset_pointer()

for it in xrange(data_loader.num_batch):
for it in range(data_loader.num_batch):
batch = data_loader.next_batch()
_, g_loss = trainable_model.pretrain_step(sess, batch)
supervised_g_losses.append(g_loss)
Expand All @@ -89,7 +90,7 @@ def main():
dis_data_loader = Dis_dataloader(BATCH_SIZE)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
target_params = cPickle.load(open('save/target_params.pkl'))
target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

discriminator = Discriminator(sequence_length=20, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
Expand All @@ -106,26 +107,26 @@ def main():

log = open('save/experiment-log.txt', 'w')
# pre-train generator
print 'Start pre-training...'
print('Start pre-training...')
log.write('pre-training...\n')
for epoch in xrange(PRE_EPOCH_NUM):
for epoch in range(PRE_EPOCH_NUM):
loss = pre_train_epoch(sess, generator, gen_data_loader)
if epoch % 5 == 0:
generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
likelihood_data_loader.create_batches(eval_file)
test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
print 'pre-train epoch ', epoch, 'test_loss ', test_loss
print('pre-train epoch ', epoch, 'test_loss ', test_loss)
buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
log.write(buffer)

print 'Start pre-training discriminator...'
print('Start pre-training discriminator...')
# Train for 3 epochs on the generated data, and repeat this 50 times
for _ in range(50):
generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
dis_data_loader.load_train_data(positive_file, negative_file)
for _ in range(3):
dis_data_loader.reset_pointer()
for it in xrange(dis_data_loader.num_batch):
for it in range(dis_data_loader.num_batch):
x_batch, y_batch = dis_data_loader.next_batch()
feed = {
discriminator.input_x: x_batch,
Expand All @@ -136,8 +137,8 @@ def main():

rollout = ROLLOUT(generator, 0.8)

print '#########################################################################'
print 'Start Adversarial Training...'
print('#########################################################################')
print('Start Adversarial Training...')
log.write('adversarial training...\n')
for total_batch in range(TOTAL_BATCH):
# Train the generator for one step
Expand All @@ -153,7 +154,7 @@ def main():
likelihood_data_loader.create_batches(eval_file)
test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
print 'total_batch: ', total_batch, 'test_loss: ', test_loss
print('total_batch: ', total_batch, 'test_loss: ', test_loss)
log.write(buffer)

# Update roll-out parameters
Expand All @@ -166,7 +167,7 @@ def main():

for _ in range(3):
dis_data_loader.reset_pointer()
for it in xrange(dis_data_loader.num_batch):
for it in range(dis_data_loader.num_batch):
x_batch, y_batch = dis_data_loader.next_batch()
feed = {
discriminator.input_x: x_batch,
Expand Down