251 changes: 238 additions & 13 deletions notebooks/tests/builtin_grammar.ipynb
@@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"from syncode import SyncodeLogitsProcessor\n",
@@ -75,8 +66,6 @@
}
],
"source": [
"# grammar_str = \"python\"\n",
"# grammar_str = \"go\"\n",
"grammar_str = \"java\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
@@ -105,6 +94,242 @@
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Write a python function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
" \n",
"\n",
"--------------------------------------------------\n",
"Parsing failed! Falling back to unconstrained decoding.\n",
"Exception: Unexpected token Token('NAME', 'simple') at line 3, column 11.\n",
"Expected one of: \n",
"\t* __ANON_9\n",
"\t* __ANON_21\n",
"\t* AMPERSAND\n",
"\t* RPAR\n",
"\t* __ANON_4\n",
"\t* LESSTHAN\n",
"\t* IF\n",
"\t* STAR\n",
"\t* RSQB\n",
"\t* __ANON_5\n",
"\t* __ANON_17\n",
"\t* LSQB\n",
"\t* SLASH\n",
"\t* MINUS\n",
"\t* VBAR\n",
"\t* _NL\n",
"\t* FROM\n",
"\t* __ANON_20\n",
"\t* EQUAL\n",
"\t* __ANON_22\n",
"\t* __ANON_13\n",
"\t* OR\n",
"\t* SEMICOLON\n",
"\t* PLUS\n",
"\t* LPAR\n",
"\t* CIRCUMFLEX\n",
"\t* FOR\n",
"\t* __ANON_2\n",
"\t* NOT\n",
"\t* AT\n",
"\t* __ANON_10\n",
"\t* COMMA\n",
"\t* __ANON_18\n",
"\t* COLON\n",
"\t* MORETHAN\n",
"\t* AS\n",
"\t* __ANON_6\n",
"\t* ELSE\n",
"\t* __ANON_16\n",
"\t* __ANON_11\n",
"\t* DOT\n",
"\t* IN\n",
"\t* __ANON_7\n",
"\t* ASYNC\n",
"\t* IS\n",
"\t* RBRACE\n",
"\t* __ANON_8\n",
"\t* __ANON_3\n",
"\t* AND\n",
"\t* __ANON_15\n",
"\t* __ANON_19\n",
"\t* __ANON_14\n",
"\t* PERCENT\n",
"\t* __ANON_12\n",
"\t* __ANON_1\n",
"\n",
"Partial code: ### Printing 'Hello World' in Reverse\n",
"\n",
"Here is a simple Python\n",
"Parsed lexical tokens: [Token('_NL', \"### Printing 'Hello World' in Reverse\\n\\n\"), Token('NAME', 'Here'), Token('IS', 'is'), Token('NAME', 'a'), Token('NAME', 'simple')]\n",
"--------------------------------------------------\n",
"[OUTPUT] ### Printing 'Hello World' in Reverse\n",
"\n",
"Here is a simple Python function that prints 'Hello World' in reverse:\n",
"\n",
"```python\n",
"def print_hello_world_reverse():\n",
" \"\"\"\n",
" Prints 'Hello World' in reverse.\n",
" \"\"\"\n",
" print(\"Hello World\")\n",
"\n",
"# Example usage:\n",
"print_hello_world_reverse()\n",
"```\n",
"\n",
"When you run this code, it will output:\n",
"```\n",
"olleH dlroW\n",
"```\n",
"\n",
"Alternatively, you can also use slicing to reverse the string:\n",
"\n",
"```python\n",
"def print_hello_world_reverse():\n",
" \"\"\"\n",
" Prints 'Hello World' in reverse.\n",
" \"\"\"\n",
" print(\" \".join([\"H\", \"e\", \"l\", \"l\", \"o\", \" \", \"W\", \"o\", \"r\", \"l\", \"d\"])\n",
"\n",
"# Example usage:\n",
"print_hello_world_reverse()\n",
"```\n",
"\n",
"This will output:\n",
"```\n",
"olleH dlroW\n",
"```\n",
"\n",
"Note: The `join()` method is used to concatenate the elements of a list into a single string, which is then printed.\n"
]
}
],
"source": [
"grammar_str = \"python\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
"syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
"\n",
"prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"prompt = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
" )\n",
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
"\n",
"syncode_logits_processor.reset(prompt)\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n",
"\n",
"attention_mask = torch.ones_like(inputs)\n",
"output = model.generate(\n",
" inputs,\n",
" attention_mask=attention_mask,\n",
" max_length=512, \n",
" num_return_sequences=1, \n",
" pad_token_id=tokenizer.eos_token_id, \n",
" logits_processor=[syncode_logits_processor]\n",
" )\n",
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Write a go function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
" \n",
"\n",
"[OUTPUT] // \n",
"\n",
"package main\n",
"\n",
"import (\n",
" \"fmt\" // Import the fmt package\n",
")\n",
"\n",
"// Function to print 'hello world' in reverse\n",
"func printHelloWorld() {\n",
" // Declare a variable to hold the string 'hello world'\n",
" var s string = \"hello world\" // Define the string\n",
" // Use string reverse() to reverse the string\n",
" var reversed string = strings.Reverses(s) // Reverse the string\n",
" // Print the reversed string\n",
" fmt.Println(reversed) // Print the reversed string\n",
"}\n",
"\n",
"func main() {\n",
" // Call the function to print 'hello world' in reverse\n",
" printHelloWorld() // Call the function\n",
"} \n",
"\n",
"// Note: The string reverse() function in Go returns a string slice, not a string.\n",
"// If you want to convert the string slice to a string, you can use the string slice's string() method.\n",
"// Here, we are using the Reverses() function which returns a string slice.\n"
]
}
],
"source": [
"grammar_str = \"go\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
"syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
"\n",
"prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"prompt = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
" )\n",
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
"\n",
"syncode_logits_processor.reset(prompt)\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n",
"\n",
"attention_mask = torch.ones_like(inputs)\n",
"output = model.generate(\n",
" inputs,\n",
" attention_mask=attention_mask,\n",
" max_length=512, \n",
" num_return_sequences=1, \n",
" pad_token_id=tokenizer.eos_token_id, \n",
" logits_processor=[syncode_logits_processor]\n",
" )\n",
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
}
],
"metadata": {
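All three notebook cells above follow the same pattern: build a `Grammar`, wrap it in a `SyncodeLogitsProcessor`, reset the processor with the rendered prompt, and hand it to `model.generate`. A minimal standalone sketch of that flow — the model checkpoint and the `Grammar` import path are assumptions (the notebook's own imports are elided by this diff); every SynCode call mirrors the cells shown:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from syncode import SyncodeLogitsProcessor, Grammar  # Grammar's import path is an assumption

# Assumed checkpoint: the diff only shows a Llama-3-style chat template, not the model id.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

grammar_str = "go"  # built-in grammars exercised by the notebook: "python", "go", "java"
syncode_logits_processor = SyncodeLogitsProcessor(
    grammar=Grammar(grammar_str), tokenizer=tokenizer, parse_output_only=True
)

messages = [{"role": "user", "content": f"Write a {grammar_str} function that prints 'hello world' in reverse."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
syncode_logits_processor.reset(prompt)  # the processor must see the prompt so it parses only the output

inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
output = model.generate(
    inputs,
    attention_mask=torch.ones_like(inputs),
    max_length=512,
    pad_token_id=tokenizer.eos_token_id,
    logits_processor=[syncode_logits_processor],  # masks tokens that would break the grammar
)
print(tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True))
```

As the python cell's output above shows, when the partial output cannot be parsed (here, because the model emits Markdown prose around the code), the processor falls back to unconstrained decoding rather than failing the generation.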
25 changes: 19 additions & 6 deletions syncode/parsers/python_parser.py
@@ -58,12 +58,13 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
self.dedent_queue.append(token)
continue
else:
self.parsed_lexer_tokens.append(token) # parser_token_seq holds all tokens except _INDENT and _DEDENT
self.parsed_lexer_tokens.append(token) # parsed_lexer_tokens holds all tokens except _INDENT and _DEDENT

while not len(self.dedent_queue)==0: # Shoot all the dedent tokens that are in the queue
self.indent_level.pop()
dedent_token = self.dedent_queue.pop()
interactive.feed_token(dedent_token)
self.cur_ac_terminals, self.next_ac_terminals = self.next_ac_terminals, self._accepts(interactive)

interactive.feed_token(token)

@@ -83,8 +84,11 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:

# Compute current terminal string
remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)


cur_ac_terminals = self.cur_ac_terminals
next_ac_terminals = self.next_ac_terminals
next_ac_indents = None

if remainder_state == RemainderState.MAYBE_COMPLETE or remainder_state == RemainderState.COMPLETE:
if len(self.parsed_lexer_tokens) > 0 and self.parsed_lexer_tokens[-1].type == '_NL':
last_indent_str = self.parsed_lexer_tokens[-1].value.split('\n')[-1]
@@ -99,10 +103,19 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
next_ac_indents = IndentationConstraint(accept_indents=next_ac_indents)

# '_NL' is always accepted in this case
self.cur_ac_terminals.add('_NL')
self.next_ac_terminals.add('_NL')

return ParseResult.from_accept_terminals(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, remainder_state, next_ac_indents=next_ac_indents, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)
cur_ac_terminals.add('_NL')
next_ac_terminals.add('_NL')

# feed _DEDENT tokens in the interactive parser
# See test_grammar_python.test_parser25
while not len(self.dedent_queue)==0 and '_DEDENT' in self.next_ac_terminals:
dedent_token = self.dedent_queue.pop()
interactive.feed_token(dedent_token)
self.cur_ac_terminals = self.next_ac_terminals
self.next_ac_terminals = self._accepts(interactive)
next_ac_terminals |= self.next_ac_terminals

return ParseResult.from_accept_terminals(cur_ac_terminals, next_ac_terminals, current_term_str, remainder_state, next_ac_indents=next_ac_indents, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)

def _update_indent_levels(self, indent_level, indent):
# if self.cur_pos != len(lexer_tokens): # Store previous indentation levels except the last one
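The change above makes the incremental Python parser drain its dedent queue during lookahead: queued `_DEDENT` tokens are fed to the interactive parser for as long as `_DEDENT` remains an acceptable next terminal, and the accept sets from each step are merged into `next_ac_terminals`. The practical effect is that terminals which only become legal one indentation level up, such as `ELSE` and `ELIF`, now appear when a partial program ends on a dedent. A minimal sketch of that effect, mirroring the first case of `test_parser25` in the next file (`inc_parser` and `AcceptSequence` come from the test module's setup, which this diff elides):

```python
inc_parser.reset()

# Partial program whose last line dedents back to the level of the `if`.
partial_code = "def foo(s: str):\n\tif s == 'hello':\n\t\treturn 'world'\n\t"
r = inc_parser.get_acceptable_next_terminals(partial_code)

# With the queued _DEDENT fed through, an else/elif clause is recognized
# as a legal continuation at this indentation level.
assert AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences
assert AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences
```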
29 changes: 25 additions & 4 deletions tests/test_grammar_python.py
@@ -1,9 +1,7 @@
import unittest
import sys, os
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
from syncode.parsers.python_parser import PythonIncrementalParser
from syncode.parsers import create_parser
import syncode.common
from transformers import (
LlamaTokenizer,
)
@@ -317,8 +315,31 @@ def test_parser24(self):
print(r)
assert r.remainder == 'i'
assert AcceptSequence(['IN']) in r.accept_sequences
# TODO: FIX THIS TEST.
# assert r.remainder_state == RemainderState.INCOMPLETE

def test_parser25(self):
inc_parser.reset()
partial_code = "def foo(string: str):\n\tif string == 'hello':\n\t\treturn 'world'\n\t"
r = inc_parser.get_acceptable_next_terminals(partial_code)
assert AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences
assert AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences

inc_parser.reset()
partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\t"
r = inc_parser.get_acceptable_next_terminals(partial_code)
assert AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences
assert AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences

inc_parser.reset()
partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\t\t"
r = inc_parser.get_acceptable_next_terminals(partial_code)
assert not AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences
assert not AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences

inc_parser.reset()
partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\telse"
r = inc_parser.get_acceptable_next_terminals(partial_code)
assert AcceptSequence(['ELSE', 'COLON']) in r.accept_sequences


if __name__ == "__main__":
unittest.main()
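Since the test file patches `sys.path` itself and ends with `unittest.main()`, the new cases run directly, e.g. `python tests/test_grammar_python.py` from the repository root.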