From 0f7c50167652c25d847f39cbd398047ec7926aed Mon Sep 17 00:00:00 2001 From: Shubham Ugare Date: Tue, 3 Jun 2025 16:12:14 -0500 Subject: [PATCH] Update the model for the Python notebook --- notebooks/example_python.ipynb | 51 +++++++++++++++------ syncode/grammar_mask/grammar_constrainer.py | 8 ++-- syncode/grammar_mask/logits_processor.py | 7 ++- syncode/infer.py | 2 + syncode/mask_store/mask_store.py | 9 ++-- syncode/parsers/incremental_parser.py | 15 +++++- syncode/parsers/python_parser.py | 8 +++- 7 files changed, 74 insertions(+), 26 deletions(-) diff --git a/notebooks/example_python.ipynb b/notebooks/example_python.ipynb index 197a26d4..dd9d9bc5 100644 --- a/notebooks/example_python.ipynb +++ b/notebooks/example_python.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "23530ae1", "metadata": {}, "outputs": [ @@ -10,22 +10,32 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loading Lark base parser from cache: cache/parsers/python_lalr_parser.pkl\n" + "[2025-06-03 16:10:30,358-root] - Loading model meta-llama/Llama-3.2-1B with device:cuda, device_map:auto, torch_dtype:torch.bfloat16\n", + "[2025-06-03 16:10:31,670-root] - Loading model meta-llama/Llama-3.2-1B with device:cuda, device_map:auto, torch_dtype:torch.bfloat16\n" ] } ], "source": [ + "import sys\n", + "sys.path.append('..') # Assuming we are in the root directory\n", "from syncode import Syncode\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", - "model_name = \"WizardLM/WizardCoder-1B-V1.0\"\n", + "model_name = \"meta-llama/Llama-3.2-1B\"\n", "\n", "# Load the unconstrained original model\n", "llm = Syncode(model = model_name, mode='original', max_new_tokens=200)\n", "\n", "# Load the Syncode augmented model\n", - "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar='python', parse_output_only=False)" + "syn_llm = Syncode(\n", + " model = model_name, \n", + " mode='grammar_mask', \n", + " grammar='python', \n", + " parse_output_only=False,\n", + " indent=True,\n", + " opp=False\n", + " )" ] }, { @@ -39,10 +49,17 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "id": "490cddb3", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -51,10 +68,11 @@ " '''Return if prime'''\n", " if n < 2:\n", " return False\n", - " for i in range(2, int(n**0.5)+1):\n", + " for i in range(2, n):\n", " if n % i == 0:\n", " return False\n", - " return True\n" + " return True\n", + "\n" ] }, { @@ -64,7 +82,7 @@ "traceback": [ "Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n", "\u001b[0m File \u001b[1;32m~/anaconda3/envs/codex/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3577\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n", - "\u001b[0;36m Cell \u001b[0;32mIn[16], line 4\u001b[0;36m\n\u001b[0;31m exec(output)\u001b[0;36m\n", + "\u001b[0;36m Cell \u001b[0;32mIn[8], line 4\u001b[0;36m\n\u001b[0;31m exec(output)\u001b[0;36m\n", "\u001b[0;36m File \u001b[0;32m:3\u001b[0;36m\u001b[0m\n\u001b[0;31m if n < 2:\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unindent does not match any outer indentation level\n" ] } @@ -78,22 +96,25 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "id": 
"76cd93f5", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "def is_prime(n):\n", " '''Return if prime'''\n", - " if n < 2:\n", - " return False\n", - " for i in range(2, int(n**0.5) + 1):\n", - " if n % i == 0:\n", - " return False\n", - " return True\n" + " return n > 1 and all(n % i!= 0 for i in range(2, int(n**0.5) + 1))\n", + "\n" ] } ], diff --git a/syncode/grammar_mask/grammar_constrainer.py b/syncode/grammar_mask/grammar_constrainer.py index 08ee4c17..461cb9c3 100644 --- a/syncode/grammar_mask/grammar_constrainer.py +++ b/syncode/grammar_mask/grammar_constrainer.py @@ -1,6 +1,5 @@ import torch -import syncode.common as common -from transformers import LogitsProcessor, PreTrainedTokenizer +from transformers import PreTrainedTokenizer from syncode.mask_store.byte_tokenizer import ByteTokenizer from syncode.parse_result import AcceptSequence, RemainderState from syncode.parsers.incremental_parser import IncrementalParser, ParseResult @@ -53,7 +52,9 @@ def __init__(self, batch_size=1, dev_mode=False, parser='lalr', - mode='grammar_mask'): + mode='grammar_mask', + indent=False + ): self.tokenizer = tokenizer self.byte_tokenizer = byte_tokenizer @@ -82,6 +83,7 @@ def __init__(self, tokenizer=self.tokenizer, use_cache=use_cache, mode=mode, # Controls approximation strategy for token masking + indent=indent ) diff --git a/syncode/grammar_mask/logits_processor.py b/syncode/grammar_mask/logits_processor.py index cfb1c370..0e29d8e8 100644 --- a/syncode/grammar_mask/logits_processor.py +++ b/syncode/grammar_mask/logits_processor.py @@ -29,7 +29,9 @@ def __init__(self, num_samples=1, dev_mode=False, parser='lalr', - mode='grammar_mask'): + mode='grammar_mask', + indent=False + ): self.tokenizer = tokenizer self.byte_tokenizer = ByteTokenizer(tokenizer) @@ -44,7 +46,8 @@ def __init__(self, batch_size=num_samples, dev_mode=dev_mode, parser=parser, - mode=mode + mode=mode, + indent=indent ) def reset(self): diff --git a/syncode/infer.py b/syncode/infer.py index deaa0bac..9a03ef26 100644 --- a/syncode/infer.py +++ b/syncode/infer.py @@ -51,6 +51,7 @@ def __init__( seed: Optional[int] = None, opp: bool = True, device_map: Optional[str] = None, + indent: bool = False, **kwargs ): # Check inputs @@ -102,6 +103,7 @@ def __init__( dev_mode=dev_mode, parser=parser, mode=mode, + indent=indent ) # Set default max new tokens if not provided diff --git a/syncode/mask_store/mask_store.py b/syncode/mask_store/mask_store.py index 65cfdbc4..4ebea22c 100644 --- a/syncode/mask_store/mask_store.py +++ b/syncode/mask_store/mask_store.py @@ -59,7 +59,8 @@ def __init__(self, self._vocab, eos_token_id=self.eos_token_id, special_token_ids=self.special_token_ids, - indent=indent, mode=mode + indent=indent, + mode=mode ) terminal_names = [terminal.name for terminal in terminals] @@ -106,8 +107,10 @@ def init_mask_store( if use_cache and os.path.exists(fsm_path): try: with open(fsm_path, 'rb') as f: - mask_store = pickle.load(f) - return mask_store + mask_store: MaskStore = pickle.load(f) + if mask_store.indentation == indent: + return mask_store + except Exception as e: logger.warning(f"Error loading mask store: {e}") diff --git a/syncode/parsers/incremental_parser.py b/syncode/parsers/incremental_parser.py index 57d238a0..6ab51cf8 100644 --- a/syncode/parsers/incremental_parser.py +++ b/syncode/parsers/incremental_parser.py @@ -166,9 
+166,20 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
                 self._handle_parsing_error(lexer_tokens, token, e)

         # Compute current terminal string
-        remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
+        remainder_state, current_term_str, final_terminal = self._get_remainder(
+            partial_code,
+            lexing_incomplete=lexing_incomplete,
+            parse_incomplete=parse_incomplete
+        )

-        return ParseResult.from_accept_terminals(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, remainder_state, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)
+        return ParseResult.from_accept_terminals(
+            self.cur_ac_terminals,
+            self.next_ac_terminals,
+            current_term_str,
+            remainder_state,
+            final_terminal=final_terminal,
+            ignore_terminals=self.base_parser.lexer_conf.ignore)
+

     def _get_remainder(self, code, lexing_incomplete=False, parse_incomplete=False):
         final_terminal = None
diff --git a/syncode/parsers/python_parser.py b/syncode/parsers/python_parser.py
index c1ae16e2..285275e7 100644
--- a/syncode/parsers/python_parser.py
+++ b/syncode/parsers/python_parser.py
@@ -85,7 +85,11 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
         remainder_state, final_terminal = None, None

         # Compute current terminal string
-        remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
+        remainder_state, current_term_str, final_terminal = self._get_remainder(
+            partial_code,
+            lexing_incomplete=lexing_incomplete,
+            parse_incomplete=parse_incomplete
+        )

         cur_ac_terminals = self.cur_ac_terminals
         next_ac_terminals = self.next_ac_terminals
@@ -93,7 +97,9 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:

         if remainder_state == RemainderState.MAYBE_COMPLETE or remainder_state == RemainderState.COMPLETE:
             if len(self.parsed_lexer_tokens) > 0 and self.parsed_lexer_tokens[-1].type == '_NL':
+                # Calculate the last indentation level
                 last_indent_str = self.parsed_lexer_tokens[-1].value.split('\n')[-1]
+                last_indent = last_indent_str.count(' ') + last_indent_str.count('\t') * self.tab_len
                 next_ac_indents = [indent-last_indent for indent in self.indent_level if indent >= last_indent]