-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathevaluate.py
More file actions
179 lines (146 loc) · 7.3 KB
/
evaluate.py
File metadata and controls
179 lines (146 loc) · 7.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#evaluate.py
import torch
import argparse
import os
import time
import math
import dill
from model import MiniscanRBBaseline, WordToNumber
from util import get_episode_generator, timeSince, tabu_update, cuda_a_dict, get_supervised_batchsize
from train import gen_samples, train_batched_step, eval_ll
from test import batched_test_with_sampling
from generate_episode import exact_perm_doubled_rules
from interpret_grammar import Grammar
from collections import namedtuple
# Record for the outcome of one search episode: whether the query set was
# solved (hit), the best grammar found (solution), and search statistics.
SearchResult = namedtuple("SearchResult", ["hit", "solution", "stats"])
def compute_val_ll(model, samples_val=None):
    """Compute the validation log-likelihood of ``model``.

    Flattens each validation sample into one (state, rule) supervised
    sample per rule, stores the flattened list on ``model.val_states``
    (side effect), and scores it with ``eval_ll``.

    Args:
        model: a loaded model exposing ``sample_to_statelist`` and
            ``state_rule_to_sample`` (e.g. MiniscanRBBaseline).
        samples_val: optional explicit list of validation samples. When
            None, defaults to the first ``N_TEST_NEW`` entries of
            ``model.samples_val`` (``N_TEST_NEW`` is a module-level value
            set by the ``__main__`` block).

    Returns:
        The loss returned by ``eval_ll`` over the flattened states.
    """
    if samples_val is None:  # identity check, not `== None`
        samples_val = model.samples_val[:N_TEST_NEW]
    model.val_states = []
    # BUG FIX: iterate over `samples_val` (the resolved argument) instead of
    # re-slicing model.samples_val — previously an explicitly passed sample
    # list was silently ignored.
    for s in samples_val:
        states, rules = model.sample_to_statelist(s)
        # One supervised sample per rule; states may be longer than rules,
        # so index by the rules list as the original did.
        for i in range(len(rules)):
            model.val_states.append(model.state_rule_to_sample(states[i], rules[i]))
    val_loss = eval_ll(model.val_states, model)
    return val_loss
if __name__=='__main__':
    # Evaluation driver: load a trained model, build/load a set of held-out
    # test episodes, optionally report validation log-likelihood, then run
    # sampling-based search on every episode, printing progress and dumping
    # cumulative results to `--savefile` with dill after each task.
    #args
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_length_eval', type=int, default=15,
        help='maximum generated sequence length when evaluating the network')
    parser.add_argument('--max_num_rules', type=int, default=12, help='maximum generated num rules')
    parser.add_argument('--fn_out_model', type=str, default='', help='filename for saving the model')
    parser.add_argument('--dir_model', type=str, default='out_models', help='directory for saving model files')
    parser.add_argument('--gpu', type=int, default=0, help='set which GPU we want to use')
    parser.add_argument('--batchsize', type=int, default=64 )
    parser.add_argument('--timeout', type=int, default=30)
    parser.add_argument('--mode', type=str, default='sample', choices=['smc', 'sample']) #beam,
    parser.add_argument('--type', type=str, default="miniscanRBbase")
    parser.add_argument('--savefile', type=str, default="results/smcREPL.p")
    parser.add_argument('--use_large_support', action='store_true')
    parser.add_argument('--use_rules_hard', action='store_true')
    parser.add_argument('--use_scan_large_s', action='store_true')
    parser.add_argument('--new_test_ep', type=str, default='')
    parser.add_argument('--load_data', type=str, default='')
    parser.add_argument('--val_ll', action='store_true')
    parser.add_argument('--val_ll_only', action='store_true')
    parser.add_argument('--n_test', type=int, default=20)
    parser.add_argument('--duplicate_test', action='store_true')
    parser.add_argument('--positional', action='store_true', default=True) # positional rule encodings
    parser.add_argument('--hack_gt_g', action='store_true')
    parser.add_argument('--nosearch', action='store_true')
    parser.add_argument('--partial_credit', action='store_true', default=True)
    parser.add_argument('--seperate_query', action='store_true')
    parser.add_argument('--human_miniscan', action='store_true')
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu)
    # --val_ll_only implies --val_ll (compute it, then stop before search).
    if args.val_ll_only: args.val_ll = True
    path = os.path.join(args.dir_model, args.fn_out_model)
    filename = args.savefile
    # Module-level value read by compute_val_ll() above as a global.
    N_TEST_NEW = args.n_test
    # model = MiniscanModel.load('out_models/REPLMiniscan0')
    # samples_val = model.samples_val
    #load model
    if args.type == 'miniscanRBbase':
        model = MiniscanRBBaseline.load(path)
    elif args.type == 'WordToNumber':
        model = WordToNumber.load(path)
    else:
        # NOTE(review): assert for unsupported --type is stripped under
        # `python -O`; a raise would be safer — TODO confirm intent.
        assert False, "not implemented yet"
    #
    if args.new_test_ep:
        # Regenerate a fresh validation set from the named episode generator,
        # reusing the loaded model's vocabularies so encodings stay aligned.
        print("generating new test examples")
        generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang = get_episode_generator(
            args.new_test_ep, model_in_lang=model.input_lang,
            model_out_lang=model.output_lang,
            model_prog_lang=model.prog_lang)
        #model.tabu_episodes = set([])
        model.samples_val = []
        for i in range(N_TEST_NEW):
            sample = generate_episode_test(model.tabu_episodes)
            # Optionally replace the episode grammar with a hand-built
            # ground-truth grammar (debug hook).
            if args.hack_gt_g: sample['grammar'] = Grammar( exact_perm_doubled_rules() , model.input_lang.symbols)
            model.samples_val.append(sample)
            # Track generated episodes so later draws don't repeat them,
            # unless duplicates are explicitly allowed.
            if not args.duplicate_test: model.tabu_episodes = tabu_update(model.tabu_episodes, sample['identifier'])
        model.input_lang = input_lang
        model.output_lang = output_lang
        if not args.val_ll_only:
            model.prog_lang = prog_lang
    if args.load_data:
        # Persisted test set: load it if the file exists, otherwise save the
        # current samples there so future runs evaluate on identical data.
        if os.path.isfile(args.load_data):
            print('loading test data ... ')
            with open(args.load_data, 'rb') as h:
                test_samples = dill.load(h)
            model.samples_val = test_samples
        else:
            print("no test data found, so saving current test data as new")
            with open(args.load_data, 'wb') as h:
                dill.dump(model.samples_val, h)
    if args.val_ll:
        val_ll = compute_val_ll(model)
        print("val ll:", val_ll)
        # NOTE(review): `assert False` is used as an early exit after the
        # val-ll report; it disappears under `python -O` — consider sys.exit().
        if args.val_ll_only: assert False
    #print("batchsize:", args.batchsize)
    count = 0            # number of tasks fully solved
    results = []         # (sample, SearchResult) pairs, saved incrementally
    frac_exs_hits = []   # per-task fraction of query examples answered correctly
    for j, sample in enumerate(model.samples_val):
        print()
        print(f"Task {j+1} out of {len(model.samples_val)}")
        print("ground truth grammar:")
        print(sample['identifier'])
        if args.human_miniscan:
            # Override the episode's support/query sets with the fixed
            # human-study MiniSCAN splits.
            from miniscan_state import examples_train, examples_test
            examples = examples_train
            query_examples = examples_test
        else:
            examples, query_examples = None, None
        # Rule-based ('RB'/'Word') model types emit one rule per step, hence
        # max_len=1 and a larger max_rule_size; others use seq length 15.
        hit, solution, stats = batched_test_with_sampling(sample, model, max_len=1 if 'RB' in args.type or 'Word' in args.type else 15,
            examples=examples,
            query_examples=query_examples,
            timeout=args.timeout,
            verbose=True,
            min_len=0,
            batch_size=args.batchsize,
            nosearch=args.nosearch,
            partial_credit=args.partial_credit,
            seperate_query=args.seperate_query,
            max_rule_size= 100 if 'RB' in args.type or 'Word' in args.type else 15)
        #should be in stats
        frac_exs_hits.append(stats['fraction_query_hit'])
        if solution:
            if hit: print("SUCCESS!!!!!!!")
            print("found grammar:", flush=True)
            rules = solution.rules
            # SCAN-style rules are token lists; join them for display.
            for r in rules: print(' '.join(r) if 'scan' in args.type else r)
        if hit: count +=1
        results.append( (sample , SearchResult(hit, solution, stats)) )
        # Checkpoint results after every task so partial runs are recoverable.
        with open(filename, 'wb') as savefile:
            dill.dump(results, savefile)
        print("prelim results file saved at", filename)
    # Summary statistics over all tasks.
    print(f"HIT {count} out of {len(model.samples_val)}")
    avg = sum(frac_exs_hits)/len(frac_exs_hits)
    print(f"AVERAGE {avg*100} % examples")
    print(f"average nodes expanded: {sum(result.stats['nodes_expanded'] for samp, result in results )/len(results)}")
    variance = sum([(f - avg)**2 for f in frac_exs_hits ])/len(frac_exs_hits)
    print(f"standard error: {math.sqrt(variance)/math.sqrt(len(frac_exs_hits))*100}")
    #save results