Commit db2c121

Fix Python parsing issue

1 parent: 2c542d2

4 files changed: +160 -13


notebooks/tests/python.ipynb

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.26it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from syncode.grammar_decoder import SyncodeLogitsProcessor\n",
+    "from syncode.parsers.grammars import Grammar\n",
+    "import torch\n",
+    "from transformers import AutoModelForCausalLM\n",
+    "from transformers import AutoTokenizer\n",
+    "\n",
+    "# Step 1. Load model and tokenizer\n",
+    "model_name = \"microsoft/phi-2\"\n",
+    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Generated token id: 25 (:)\n",
+      "Score of this token: 28.345928192138672\n",
+      "Colon token id: 25 (:)\n",
+      "Score of colon token: 28.345928192138672\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Step 2. Set prompt\n",
+    "prompt = \"\"\"def is_palindrome(n):\n",
+    "    if str(n) == str(n)[::-1]:\n",
+    "        return True\n",
+    "    else\"\"\"\n",
+    "\n",
+    "# Step 3. Initialize SyncodeLogitsProcessor\n",
+    "syncode_processor = SyncodeLogitsProcessor(grammar=Grammar(\"python\"),\n",
+    "                                           tokenizer=tokenizer,\n",
+    "                                           parse_output_only=False)\n",
+    "syncode_processor.reset(prompt)\n",
+    "\n",
+    "# Step 4. Generate scores with SyncodeLogitsProcessor\n",
+    "inputs = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
+    "outputs = model.generate(inputs,\n",
+    "                         attention_mask=torch.ones_like(inputs),\n",
+    "                         do_sample=True,\n",
+    "                         logits_processor=[syncode_processor],\n",
+    "                         return_dict_in_generate=True,\n",
+    "                         output_scores=True,\n",
+    "                         max_new_tokens=1)\n",
+    "\n",
+    "# Step 5. Print scores\n",
+    "scores = outputs.scores\n",
+    "generated_id = outputs.sequences[0, -1]\n",
+    "generated_str = tokenizer.decode(generated_id, skip_special_tokens=True)\n",
+    "print(f\"Generated token id: {generated_id} ({generated_str})\")\n",
+    "print(f\"Score of this token: {scores[0][-1][generated_id]}\")\n",
+    "\n",
+    "colon_id = tokenizer.encode(':')[0]\n",
+    "colon_str = tokenizer.decode(colon_id)\n",
+    "print(f\"Colon token id: {colon_id} ({colon_str})\")\n",
+    "print(f\"Score of colon token: {scores[0][-1][colon_id]}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "codex",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
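Note on the new notebook: it checks the fix end-to-end. The prompt ends in a dangling `else`, and in the recorded run the single sampled token (id 25) is exactly the colon token, so the two printed scores coincide. Below is a small sketch (not part of the commit) for inspecting which tokens survive the grammar mask, reusing the notebook's `scores`, `tokenizer`, and `torch`, and assuming SyncodeLogitsProcessor follows the usual Hugging Face convention of setting disallowed logits to -inf:

# Tokens the grammar mask left viable at the final step (assumes masked
# logits are -inf, the standard logits-processor convention).
mask = torch.isfinite(scores[0][0])
legal_ids = torch.nonzero(mask).flatten().tolist()
print([tokenizer.decode([i]) for i in legal_ids])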

syncode/evaluation/sql_eval.py

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@
 from typing import Optional
 from tqdm import tqdm
 from mxeval.data import write_jsonl
-
+from datasets import load_dataset
 
 class SQLEval:
     """
@@ -36,8 +36,9 @@ def run_eval(syncode, out_path: Optional[str], num_tasks: Optional[int]=None, de
         for task_id, problem in enumerate(problems):
             results[task_id] = []
             start_time = time.time()
-            batch_completions = syncode.model.generate_batch_completion_grammar(
-                problem['prompt'],
+            prompt = problem['prompt']
+            batch_completions = syncode.model.generate_grammar_constrained_completion(
+                prompt,
                 syncode.num_samples
             )
             end_time = time.time()
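Beyond the rename of `generate_batch_completion_grammar` to `generate_grammar_constrained_completion`, the new `load_dataset` import suggests SQLEval now reads its problems from the Hugging Face hub. A hypothetical sketch of the loading side; the dataset name is an assumption, since the diff only shows that `problems` is iterable and that each problem carries a 'prompt' field:

from datasets import load_dataset

# Hypothetical: "spider" is a guess at the text-to-SQL benchmark being
# loaded; the actual dataset name is not visible in this diff.
problems = load_dataset("spider", split="validation")
print(len(problems))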

syncode/parsers/python_parser.py

Lines changed: 19 additions & 6 deletions
@@ -58,12 +58,13 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
                 self.dedent_queue.append(token)
                 continue
             else:
-                self.parsed_lexer_tokens.append(token) # parser_token_seq holds all tokens except _INDENT and _DEDENT
+                self.parsed_lexer_tokens.append(token) # parsed_token_seq holds all tokens except _INDENT and _DEDENT
 
             while not len(self.dedent_queue)==0: # Shoot all the dedent tokens that are in the queue
                 self.indent_level.pop()
                 dedent_token = self.dedent_queue.pop()
                 interactive.feed_token(dedent_token)
+                self.cur_ac_terminals, self.next_ac_terminals = self.next_ac_terminals, self._accepts(interactive)
 
             interactive.feed_token(token)
 
@@ -83,8 +84,11 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
 
         # Compute current terminal string
         remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
-
+
+        cur_ac_terminals = self.cur_ac_terminals
+        next_ac_terminals = self.next_ac_terminals
         next_ac_indents = None
+
         if remainder_state == RemainderState.MAYBE_COMPLETE or remainder_state == RemainderState.COMPLETE:
             if len(self.parsed_lexer_tokens) > 0 and self.parsed_lexer_tokens[-1].type == '_NL':
                 last_indent_str = self.parsed_lexer_tokens[-1].value.split('\n')[-1]
@@ -99,10 +103,19 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
                 next_ac_indents = IndentationConstraint(accept_indents=next_ac_indents)
 
                 # '_NL' is always accepted in this case
-                self.cur_ac_terminals.add('_NL')
-                self.next_ac_terminals.add('_NL')
-
-        return ParseResult.from_accept_terminals(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, remainder_state, next_ac_indents=next_ac_indents, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)
+                cur_ac_terminals.add('_NL')
+                next_ac_terminals.add('_NL')
+
+        # feed _DEDENT tokens in the interactive parser
+        # See test_grammar_python.test_parser25
+        while not len(self.dedent_queue)==0 and '_DEDENT' in self.next_ac_terminals:
+            dedent_token = self.dedent_queue.pop()
+            interactive.feed_token(dedent_token)
+            self.cur_ac_terminals = self.next_ac_terminals
+            self.next_ac_terminals = self._accepts(interactive)
+            next_ac_terminals |= self.next_ac_terminals
+
+        return ParseResult.from_accept_terminals(cur_ac_terminals, next_ac_terminals, current_term_str, remainder_state, next_ac_indents=next_ac_indents, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)
 
     def _update_indent_levels(self, indent_level, indent):
         # if self.cur_pos != len(lexer_tokens): # Store previous indentation levels except the last one
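The heart of the fix: when the partial code ends right after a dedent, the pending `_DEDENT` tokens sit in `dedent_queue` and were previously never fed to the interactive parser, so terminals that only become legal after dedenting (`ELSE`, `ELIF`) were missing from `next_ac_terminals`. The new loop drains the queue while `_DEDENT` remains acceptable and unions in each newly exposed accept set. A toy sketch of the pattern, with `feed` and `accepts` as stand-ins for `interactive.feed_token` and `self._accepts(interactive)`:

def drain_dedents(dedent_queue, next_ac_terminals, feed, accepts):
    # Mirror of the loop added above: speculatively consume buffered
    # _DEDENT tokens and accumulate the terminals they expose.
    merged = set(next_ac_terminals)
    current = set(next_ac_terminals)
    while dedent_queue and '_DEDENT' in current:
        feed(dedent_queue.pop())
        current = accepts()   # accept set after this dedent
        merged |= current
    return merged

demo_accepts = iter([{'ELSE', 'ELIF'}])
print(drain_dedents(
    dedent_queue=['<DEDENT>'],
    next_ac_terminals={'_DEDENT', '_NL'},
    feed=lambda tok: None,             # no-op parser stand-in
    accepts=lambda: next(demo_accepts),
))  # e.g. {'ELSE', 'ELIF', '_NL', '_DEDENT'} (set order varies)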

tests/test_grammar_python.py

Lines changed: 25 additions & 4 deletions
@@ -1,9 +1,7 @@
 import unittest
 import sys, os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
-from syncode.parsers.python_parser import PythonIncrementalParser
 from syncode.parsers import create_parser
-import syncode.common
 from transformers import (
     LlamaTokenizer,
 )
@@ -317,8 +315,31 @@ def test_parser24(self):
         print(r)
         assert r.remainder == 'i'
         assert AcceptSequence(['IN']) in r.accept_sequences
-        # TODO: FIX THIS TEST.
-        # assert r.remainder_state == RemainderState.INCOMPLETE
+
+    def test_parser25(self):
+        inc_parser.reset()
+        partial_code = "def foo(string: str):\n\tif string == 'hello':\n\t\treturn 'world'\n\t"
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        assert AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences
+        assert AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences
+
+        inc_parser.reset()
+        partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\t"
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        assert AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences
+        assert AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences
+
+        inc_parser.reset()
+        partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\t\t"
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        assert not AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences
+        assert not AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences
+
+        inc_parser.reset()
+        partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\telse"
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        assert AcceptSequence(['ELSE', 'COLON']) in r.accept_sequences
+
 
 if __name__ == "__main__":
     unittest.main()
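The four `partial_code` fragments in `test_parser25` differ only in how they end: trailing whitespace that dedents back to an enclosing `if`'s level (cases 1 and 2) must offer `ELSE`/`ELIF` after `_NL`, whitespace that stays at the inner block's depth (case 3) must not, and a literal `else` (case 4) must be followed by a colon. To run just this test from the repository root, a standard unittest filter works: python -m unittest -k test_parser25 tests.test_grammar_python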
