diff --git a/notebooks/tests/builtin_grammar.ipynb b/notebooks/tests/builtin_grammar.ipynb index 0231f7fc..10fe62be 100644 --- a/notebooks/tests/builtin_grammar.ipynb +++ b/notebooks/tests/builtin_grammar.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "from syncode import SyncodeLogitsProcessor\n", @@ -75,8 +66,6 @@ } ], "source": [ - "# grammar_str = \"python\"\n", - "# grammar_str = \"go\"\n", "grammar_str = \"java\"\n", "\n", "grammar = Grammar(grammar_str)\n", @@ -105,6 +94,242 @@ "output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n", "print(\"[OUTPUT]\", output_str)" ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", + "\n", + "Cutting Knowledge Date: December 2023\n", + "Today Date: 26 Jul 2024\n", + "\n", + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n", + "\n", + "Write a python function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", + "\n", + " \n", + "\n", + "--------------------------------------------------\n", + "Parsing failed! Falling back to unconstrained decoding.\n", + "Exception: Unexpected token Token('NAME', 'simple') at line 3, column 11.\n", + "Expected one of: \n", + "\t* __ANON_9\n", + "\t* __ANON_21\n", + "\t* AMPERSAND\n", + "\t* RPAR\n", + "\t* __ANON_4\n", + "\t* LESSTHAN\n", + "\t* IF\n", + "\t* STAR\n", + "\t* RSQB\n", + "\t* __ANON_5\n", + "\t* __ANON_17\n", + "\t* LSQB\n", + "\t* SLASH\n", + "\t* MINUS\n", + "\t* VBAR\n", + "\t* _NL\n", + "\t* FROM\n", + "\t* __ANON_20\n", + "\t* EQUAL\n", + "\t* __ANON_22\n", + "\t* __ANON_13\n", + "\t* OR\n", + "\t* SEMICOLON\n", + "\t* PLUS\n", + "\t* LPAR\n", + "\t* CIRCUMFLEX\n", + "\t* FOR\n", + "\t* __ANON_2\n", + "\t* NOT\n", + "\t* AT\n", + "\t* __ANON_10\n", + "\t* COMMA\n", + "\t* __ANON_18\n", + "\t* COLON\n", + "\t* MORETHAN\n", + "\t* AS\n", + "\t* __ANON_6\n", + "\t* ELSE\n", + "\t* __ANON_16\n", + "\t* __ANON_11\n", + "\t* DOT\n", + "\t* IN\n", + "\t* __ANON_7\n", + "\t* ASYNC\n", + "\t* IS\n", + "\t* RBRACE\n", + "\t* __ANON_8\n", + "\t* __ANON_3\n", + "\t* AND\n", + "\t* __ANON_15\n", + "\t* __ANON_19\n", + "\t* __ANON_14\n", + "\t* PERCENT\n", + "\t* __ANON_12\n", + "\t* __ANON_1\n", + "\n", + "Partial code: ### Printing 'Hello World' in Reverse\n", + "\n", + "Here is a simple Python\n", + "Parsed lexical tokens: [Token('_NL', \"### Printing 'Hello World' in Reverse\\n\\n\"), Token('NAME', 'Here'), Token('IS', 'is'), Token('NAME', 'a'), Token('NAME', 'simple')]\n", + "--------------------------------------------------\n", + "[OUTPUT] ### Printing 'Hello World' in Reverse\n", + "\n", + "Here is a simple Python function that prints 'Hello World' in reverse:\n", + "\n", + "```python\n", + "def print_hello_world_reverse():\n", + " \"\"\"\n", + " Prints 'Hello World' in reverse.\n", + " \"\"\"\n", + " print(\"Hello World\")\n", + "\n", + "# Example usage:\n", + "print_hello_world_reverse()\n", + "```\n", + "\n", + "When you run this code, it will output:\n", + "```\n", + "olleH dlroW\n", + "```\n", + "\n", + "Alternatively, you can also use slicing to reverse the string:\n", + "\n", + "```python\n", + "def print_hello_world_reverse():\n", + " \"\"\"\n", + " Prints 'Hello World' in reverse.\n", + " \"\"\"\n", + " print(\" \".join([\"H\", \"e\", \"l\", \"l\", \"o\", \" \", \"W\", \"o\", \"r\", \"l\", \"d\"])\n", + "\n", + "# Example usage:\n", + "print_hello_world_reverse()\n", + "```\n", + "\n", + "This will output:\n", + "```\n", + "olleH dlroW\n", + "```\n", + "\n", + "Note: The `join()` method is used to concatenate the elements of a list into a single string, which is then printed.\n" + ] + } + ], + "source": [ + "grammar_str = \"python\"\n", + "\n", + "grammar = Grammar(grammar_str)\n", + "syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n", + "\n", + "prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n", + "messages = [{\"role\": \"user\", \"content\": prompt}]\n", + "prompt = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + " )\n", + "print(\"[PROMPT]\", prompt, \"\\n\")\n", + "\n", + "syncode_logits_processor.reset(prompt)\n", + "\n", + "inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n", + "\n", + "attention_mask = torch.ones_like(inputs)\n", + "output = model.generate(\n", + " inputs,\n", + " attention_mask=attention_mask,\n", + " max_length=512, \n", + " num_return_sequences=1, \n", + " pad_token_id=tokenizer.eos_token_id, \n", + " logits_processor=[syncode_logits_processor]\n", + " )\n", + "output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n", + "print(\"[OUTPUT]\", output_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", + "\n", + "Cutting Knowledge Date: December 2023\n", + "Today Date: 26 Jul 2024\n", + "\n", + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n", + "\n", + "Write a go function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", + "\n", + " \n", + "\n", + "[OUTPUT] // \n", + "\n", + "package main\n", + "\n", + "import (\n", + " \"fmt\" // Import the fmt package\n", + ")\n", + "\n", + "// Function to print 'hello world' in reverse\n", + "func printHelloWorld() {\n", + " // Declare a variable to hold the string 'hello world'\n", + " var s string = \"hello world\" // Define the string\n", + " // Use string reverse() to reverse the string\n", + " var reversed string = strings.Reverses(s) // Reverse the string\n", + " // Print the reversed string\n", + " fmt.Println(reversed) // Print the reversed string\n", + "}\n", + "\n", + "func main() {\n", + " // Call the function to print 'hello world' in reverse\n", + " printHelloWorld() // Call the function\n", + "} \n", + "\n", + "// Note: The string reverse() function in Go returns a string slice, not a string.\n", + "// If you want to convert the string slice to a string, you can use the string slice's string() method.\n", + "// Here, we are using the Reverses() function which returns a string slice.\n" + ] + } + ], + "source": [ + "grammar_str = \"go\"\n", + "\n", + "grammar = Grammar(grammar_str)\n", + "syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n", + "\n", + "prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n", + "messages = [{\"role\": \"user\", \"content\": prompt}]\n", + "prompt = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + " )\n", + "print(\"[PROMPT]\", prompt, \"\\n\")\n", + "\n", + "syncode_logits_processor.reset(prompt)\n", + "\n", + "inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n", + "\n", + "attention_mask = torch.ones_like(inputs)\n", + "output = model.generate(\n", + " inputs,\n", + " attention_mask=attention_mask,\n", + " max_length=512, \n", + " num_return_sequences=1, \n", + " pad_token_id=tokenizer.eos_token_id, \n", + " logits_processor=[syncode_logits_processor]\n", + " )\n", + "output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n", + "print(\"[OUTPUT]\", output_str)" + ] } ], "metadata": { diff --git a/syncode/parsers/python_parser.py b/syncode/parsers/python_parser.py index 01fcfa34..f2a7732e 100644 --- a/syncode/parsers/python_parser.py +++ b/syncode/parsers/python_parser.py @@ -58,12 +58,13 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult: self.dedent_queue.append(token) continue else: - self.parsed_lexer_tokens.append(token) # parser_token_seq holds all tokens except _INDENT and _DEDENT + self.parsed_lexer_tokens.append(token) # parsed_token_seq holds all tokens except _INDENT and _DEDENT while not len(self.dedent_queue)==0: # Shoot all the dedent tokens that are in the queue self.indent_level.pop() dedent_token = self.dedent_queue.pop() interactive.feed_token(dedent_token) + self.cur_ac_terminals, self.next_ac_terminals = self.next_ac_terminals, self._accepts(interactive) interactive.feed_token(token) @@ -83,8 +84,11 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult: # Compute current terminal string remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete) - + + cur_ac_terminals = self.cur_ac_terminals + next_ac_terminals = self.next_ac_terminals next_ac_indents = None + if remainder_state == RemainderState.MAYBE_COMPLETE or remainder_state == RemainderState.COMPLETE: if len(self.parsed_lexer_tokens) > 0 and self.parsed_lexer_tokens[-1].type == '_NL': last_indent_str = self.parsed_lexer_tokens[-1].value.split('\n')[-1] @@ -99,10 +103,19 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult: next_ac_indents = IndentationConstraint(accept_indents=next_ac_indents) # '_NL' is always accepted in this case - self.cur_ac_terminals.add('_NL') - self.next_ac_terminals.add('_NL') - - return ParseResult.from_accept_terminals(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, remainder_state, next_ac_indents=next_ac_indents, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore) + cur_ac_terminals.add('_NL') + next_ac_terminals.add('_NL') + + # feed _DEDENT tokens in the interactive parser + # See test_grammar_python.test_parser25 + while not len(self.dedent_queue)==0 and '_DEDENT' in self.next_ac_terminals: + dedent_token = self.dedent_queue.pop() + interactive.feed_token(dedent_token) + self.cur_ac_terminals = self.next_ac_terminals + self.next_ac_terminals = self._accepts(interactive) + next_ac_terminals |= self.next_ac_terminals + + return ParseResult.from_accept_terminals(cur_ac_terminals, next_ac_terminals, current_term_str, remainder_state, next_ac_indents=next_ac_indents, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore) def _update_indent_levels(self, indent_level, indent): # if self.cur_pos != len(lexer_tokens): # Store previous indentation levels except the last one diff --git a/tests/test_grammar_python.py b/tests/test_grammar_python.py index 24b14d75..6316dafa 100644 --- a/tests/test_grammar_python.py +++ b/tests/test_grammar_python.py @@ -1,9 +1,7 @@ import unittest import sys, os sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../') -from syncode.parsers.python_parser import PythonIncrementalParser from syncode.parsers import create_parser -import syncode.common from transformers import ( LlamaTokenizer, ) @@ -317,8 +315,31 @@ def test_parser24(self): print(r) assert r.remainder == 'i' assert AcceptSequence(['IN']) in r.accept_sequences - # TODO: FIX THIS TEST. - # assert r.remainder_state == RemainderState.INCOMPLETE + + def test_parser25(self): + inc_parser.reset() + partial_code = "def foo(string: str):\n\tif string == 'hello':\n\t\treturn 'world'\n\t" + r = inc_parser.get_acceptable_next_terminals(partial_code) + assert AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences + assert AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences + + inc_parser.reset() + partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\t" + r = inc_parser.get_acceptable_next_terminals(partial_code) + assert AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences + assert AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences + + inc_parser.reset() + partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\t\t" + r = inc_parser.get_acceptable_next_terminals(partial_code) + assert not AcceptSequence(['_NL', 'ELSE']) in r.accept_sequences + assert not AcceptSequence(['_NL', 'ELIF']) in r.accept_sequences + + inc_parser.reset() + partial_code = "def foo(string1: str, string2: str):\n\tif string1 == 'hello':\n\t\tif string2 == 'world':\n\t\t\treturn 'world'\n\t\telse" + r = inc_parser.get_acceptable_next_terminals(partial_code) + assert AcceptSequence(['ELSE', 'COLON']) in r.accept_sequences + if __name__ == "__main__": unittest.main()