Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions notebooks/example_python.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,40 @@
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"id": "23530ae1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading Lark base parser from cache: cache/parsers/python_lalr_parser.pkl\n"
"[2025-06-03 16:10:30,358-root] - Loading model meta-llama/Llama-3.2-1B with device:cuda, device_map:auto, torch_dtype:torch.bfloat16\n",
"[2025-06-03 16:10:31,670-root] - Loading model meta-llama/Llama-3.2-1B with device:cuda, device_map:auto, torch_dtype:torch.bfloat16\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('..') # Assuming we are in the root directory\n",
"from syncode import Syncode\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"model_name = \"WizardLM/WizardCoder-1B-V1.0\"\n",
"model_name = \"meta-llama/Llama-3.2-1B\"\n",
"\n",
"# Load the unconstrained original model\n",
"llm = Syncode(model = model_name, mode='original', max_new_tokens=200)\n",
"\n",
"# Load the Syncode augmented model\n",
"syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar='python', parse_output_only=False)"
"syn_llm = Syncode(\n",
" model = model_name, \n",
" mode='grammar_mask', \n",
" grammar='python', \n",
" parse_output_only=False,\n",
" indent=True,\n",
" opp=False\n",
" )"
]
},
{
Expand All @@ -39,10 +49,17 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 8,
"id": "490cddb3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand All @@ -51,10 +68,11 @@
" '''Return if prime'''\n",
" if n < 2:\n",
" return False\n",
" for i in range(2, int(n**0.5)+1):\n",
" for i in range(2, n):\n",
" if n % i == 0:\n",
" return False\n",
" return True\n"
" return True\n",
"\n"
]
},
{
Expand All @@ -64,7 +82,7 @@
"traceback": [
"Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n",
"\u001b[0m File \u001b[1;32m~/anaconda3/envs/codex/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3577\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n",
"\u001b[0;36m Cell \u001b[0;32mIn[16], line 4\u001b[0;36m\n\u001b[0;31m exec(output)\u001b[0;36m\n",
"\u001b[0;36m Cell \u001b[0;32mIn[8], line 4\u001b[0;36m\n\u001b[0;31m exec(output)\u001b[0;36m\n",
"\u001b[0;36m File \u001b[0;32m<string>:3\u001b[0;36m\u001b[0m\n\u001b[0;31m if n < 2:\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unindent does not match any outer indentation level\n"
]
}
Expand All @@ -78,22 +96,25 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 13,
"id": "76cd93f5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"def is_prime(n):\n",
" '''Return if prime'''\n",
" if n < 2:\n",
" return False\n",
" for i in range(2, int(n**0.5) + 1):\n",
" if n % i == 0:\n",
" return False\n",
" return True\n"
" return n > 1 and all(n % i!= 0 for i in range(2, int(n**0.5) + 1))\n",
"\n"
]
}
],
Expand Down
8 changes: 5 additions & 3 deletions syncode/grammar_mask/grammar_constrainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import syncode.common as common
from transformers import LogitsProcessor, PreTrainedTokenizer
from transformers import PreTrainedTokenizer
from syncode.mask_store.byte_tokenizer import ByteTokenizer
from syncode.parse_result import AcceptSequence, RemainderState
from syncode.parsers.incremental_parser import IncrementalParser, ParseResult
Expand Down Expand Up @@ -53,7 +52,9 @@ def __init__(self,
batch_size=1,
dev_mode=False,
parser='lalr',
mode='grammar_mask'):
mode='grammar_mask',
indent=False
):

self.tokenizer = tokenizer
self.byte_tokenizer = byte_tokenizer
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(self,
tokenizer=self.tokenizer,
use_cache=use_cache,
mode=mode, # Controls approximation strategy for token masking
indent=indent
)


Expand Down
7 changes: 5 additions & 2 deletions syncode/grammar_mask/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def __init__(self,
num_samples=1,
dev_mode=False,
parser='lalr',
mode='grammar_mask'):
mode='grammar_mask',
indent=False
):

self.tokenizer = tokenizer
self.byte_tokenizer = ByteTokenizer(tokenizer)
Expand All @@ -44,7 +46,8 @@ def __init__(self,
batch_size=num_samples,
dev_mode=dev_mode,
parser=parser,
mode=mode
mode=mode,
indent=indent
)

def reset(self):
Expand Down
2 changes: 2 additions & 0 deletions syncode/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
seed: Optional[int] = None,
opp: bool = True,
device_map: Optional[str] = None,
indent: bool = False,
**kwargs
):
# Check inputs
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
dev_mode=dev_mode,
parser=parser,
mode=mode,
indent=indent
)

# Set default max new tokens if not provided
Expand Down
9 changes: 6 additions & 3 deletions syncode/mask_store/mask_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def __init__(self,
self._vocab,
eos_token_id=self.eos_token_id,
special_token_ids=self.special_token_ids,
indent=indent, mode=mode
indent=indent,
mode=mode
)
terminal_names = [terminal.name for terminal in terminals]

Expand Down Expand Up @@ -106,8 +107,10 @@ def init_mask_store(
if use_cache and os.path.exists(fsm_path):
try:
with open(fsm_path, 'rb') as f:
mask_store = pickle.load(f)
return mask_store
mask_store: MaskStore = pickle.load(f)
if mask_store.indentation == indent:
return mask_store

except Exception as e:
logger.warning(f"Error loading mask store: {e}")

Expand Down
15 changes: 13 additions & 2 deletions syncode/parsers/incremental_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,20 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
self._handle_parsing_error(lexer_tokens, token, e)

# Compute current terminal string
remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
remainder_state, current_term_str, final_terminal = self._get_remainder(
partial_code,
lexing_incomplete=lexing_incomplete,
parse_incomplete=parse_incomplete
)

return ParseResult.from_accept_terminals(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, remainder_state, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)
return ParseResult.from_accept_terminals(
self.cur_ac_terminals,
self.next_ac_terminals,
current_term_str,
remainder_state,
final_terminal=final_terminal,
ignore_terminals=self.base_parser.lexer_conf.ignore)


def _get_remainder(self, code, lexing_incomplete=False, parse_incomplete=False):
final_terminal = None
Expand Down
8 changes: 7 additions & 1 deletion syncode/parsers/python_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,21 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
remainder_state, final_terminal = None, None

# Compute current terminal string
remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
remainder_state, current_term_str, final_terminal = self._get_remainder(
partial_code,
lexing_incomplete=lexing_incomplete,
parse_incomplete=parse_incomplete
)

cur_ac_terminals = self.cur_ac_terminals
next_ac_terminals = self.next_ac_terminals
next_ac_indents = None

if remainder_state == RemainderState.MAYBE_COMPLETE or remainder_state == RemainderState.COMPLETE:
if len(self.parsed_lexer_tokens) > 0 and self.parsed_lexer_tokens[-1].type == '_NL':
                # Calculate the last indentation level
last_indent_str = self.parsed_lexer_tokens[-1].value.split('\n')[-1]

last_indent = last_indent_str.count(' ') + last_indent_str.count('\t') * self.tab_len
next_ac_indents = [indent-last_indent for indent in self.indent_level if indent >= last_indent]

Expand Down