Skip to content

Commit a034a36

Browse files
committed
Minor fixes to the evaluation
1 parent b144be2 commit a034a36

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

syncode/evaluation/fol_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def run_eval(syncode, out_path: Optional[str]=None, debug_task_id=None):
9898
for task_id, problem in enumerate(problems):
9999
results[task_id] = []
100100
full_prompt = FOLEval._prompt_folio(problem)
101-
completion = syncode.model.generate_batch_completion_grammar(
101+
completion = syncode.model.generate_grammar_constrained_completion(
102102
full_prompt,
103103
syncode.num_samples,
104104
stop_words=['\n\n', '------']

syncode/evaluation/json_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, tas
7070

7171
prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)
7272

73-
batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task)
73+
batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task)
7474
for completion_id, completion in enumerate(batch_completions):
7575
result = dict(
7676
task_id = task_id,

syncode/evaluation/math_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def run_math_eval(syncode, out_path: Optional[str], debug_task_id=None, logger=c
2222

2323
for task_id, problem in enumerate(problems):
2424
results[task_id] = []
25-
batch_completions = syncode.model.generate_batch_completion_grammar(
25+
batch_completions = syncode.model.generate_grammar_constrained_completion(
2626
problem['question'],
2727
syncode.num_samples
2828
)

tests/test_language_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def get_vocab(self) -> Dict[str, int]:
5555
return {v: i for i, v in enumerate(self.vocab)}
5656

5757
class TestHuggingFaceModel(unittest.TestCase):
58-
def test_generate_batch_completion_grammar(self):
58+
def test_generate_grammar_constrained_completion(self):
5959
torch.manual_seed(0)
6060
model = TestModel()
6161
tokenizer = TestTokenizer()
@@ -65,7 +65,7 @@ def test_generate_batch_completion_grammar(self):
6565
output = lm.generate_grammar_constrained_completion(prompt, 1)
6666
self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.")
6767

68-
def test_generate_batch_completion_grammar2(self):
68+
def test_generate_grammar_constrained_completion2(self):
6969
torch.manual_seed(0)
7070
model = TestModel()
7171
tokenizer = TestTokenizer()

0 commit comments

Comments
 (0)