README.md: 2 changes (1 addition, 1 deletion)

@@ -69,7 +69,7 @@ SynCode depends on HuggingFace [transformers](https://github.com/huggingface/transformers)
 
 | SynCode version | Required transformers version | Python version |
 | -------------- | ----------------------------- | -------------- |
-| `v0.4.13` (latest) | `v4.51.0` | 3.6 - 3.12 |
+| `v0.4.14` (latest) | `v4.51.3` | 3.6 - 3.12 |
 
 **Note:** Python 3.13 is not currently supported due to dependency constraints.
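Each SynCode release pins an exact transformers version, so a quick runtime check can catch a mismatch before decoding starts. A minimal sketch, assuming only that both packages are installed (the version strings come from the table above):

    # Sanity check (sketch): SynCode v0.4.14 expects transformers v4.51.3,
    # per the compatibility table in the README.
    import transformers

    EXPECTED = "4.51.3"
    if transformers.__version__ != EXPECTED:
        raise RuntimeError(
            f"SynCode v0.4.14 expects transformers {EXPECTED}, "
            f"but found {transformers.__version__}"
        )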
pyproject.toml: 2 changes (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "syncode"
-version="0.4.13"
+version="0.4.14"
 requires-python = ">=3.6,<3.13"
 description = "Grammar-guided code generation tool"
 readme = "README.md"
setup.py: 2 changes (1 addition, 1 deletion)

@@ -18,7 +18,7 @@
 
 setuptools.setup(
     name="syncode",
-    version="0.4.13",
+    version="0.4.14",
     author="Shubham Ugare",
     author_email="shubhamugare@gmail.com",
     description="This package provides the tool for grammar augmented LLM generation.",
syncode/grammar_mask/grammar_constrainer.py: 42 changes (24 additions, 18 deletions)

@@ -120,11 +120,11 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> bool:
         self._set_start_from(input_ids)
 
         input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
-        partial_code, remainder_bytes = self._get_partial_codes(input_ids)[0]
+        partial_output, remainder_bytes = self._get_partial_outputs(input_ids)[0]
 
-        res, skip = self._parse_partial_code(
+        res, skip = self._parse_partial_output(
             idx=0,
-            partial_code=partial_code,
+            partial_output=partial_output,
             remainder_bytes=remainder_bytes,
             accepted_generation=False
         )
@@ -142,7 +142,7 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> bool:
         is_valid = self.dfa_mask_store.is_valid_prefix(res)
 
         if is_valid:
-            self._update_valid_state(partial_code, 0, res)
+            self._update_valid_state(partial_output, 0, res)
 
         return is_valid
 
@@ -163,11 +163,11 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
             torch.FloatTensor: The masked scores.
         """
         self._set_start_from(input_ids)  # start_from is used for choosing where the parsing should start
-        partial_codes = self._get_partial_codes(input_ids)
+        partial_outputs = self._get_partial_outputs(input_ids)
 
-        for idx, (partial_code, remainder_bytes) in enumerate(partial_codes):
+        for idx, (partial_output, remainder_bytes) in enumerate(partial_outputs):
             # 1. Parsing
-            res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
+            res, skip = self._parse_partial_output(idx, partial_output, remainder_bytes, accepted_generation=True)
             if skip: continue
 
             # 2. Computing the accept mask
@@ -187,23 +187,29 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
         return scores
 
-    def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
+    def _parse_partial_output(
+            self,
+            idx: int,
+            partial_output: str,
+            remainder_bytes: bytes,
+            accepted_generation=True
+    ) -> tuple[ParseResult, bool]:
         """
         Parse the partial code and return the result.
         """
         skip = False
         res = None
 
         try:
-            res = self.inc_parser.get_acceptable_next_terminals(partial_code)
+            res = self.inc_parser.get_acceptable_next_terminals(partial_output)
 
             if len(remainder_bytes) > 0:
                 res.remainder_state = RemainderState.INCOMPLETE
                 res.remainder = res.remainder.encode('utf-8') + remainder_bytes
             else:
                 res.remainder = res.remainder.encode('utf-8')
 
-            self._update_valid_state(partial_code, idx, res)
+            self._update_valid_state(partial_output, idx, res)
         except Exception as e:
             if self.dev_mode == True and accepted_generation:
                 logger.info("-"*50)
@@ -213,45 +219,45 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
             elif self.parse_failed == False and accepted_generation:
                 self.parse_failed = True
                 logger.info("-"*50)
-                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
+                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_output}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
                 logger.info("-"*50)
             skip = True
         return res, skip
 
-    def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
+    def _get_partial_outputs(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
         """
         Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
         """
         output = []
         for idx in range(len(input_ids)):
             if self.parse_output_only:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True)
                 )
             else:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx].tolist(), skip_special_tokens=True)
                 )
-            output.append((partial_code, remainder_bytes))
+            output.append((partial_output, remainder_bytes))
         return output
 
-    def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
+    def _update_valid_state(self, partial_output: str, idx: int, r: ParseResult):
         """
         This a simple heuristic to cut off the generated output at the end of the function.
         TODO: Put this under a flag to enable/disable this heuristic.
         """
         if idx < len(self.function_ends):
             if r.function_end: # If the function end is not None, then the last valid state is the function end
                 if self.function_ends[idx] is None: self.function_ends[idx] = []
-                self.function_ends[idx].append(len(partial_code) - len(r.remainder))
+                self.function_ends[idx].append(len(partial_output) - len(r.remainder))
 
         if idx < len(self.last_valid_state):
             for accept_seq in r.accept_sequences:
                 # 'EOF' is special terminal since $END does not work with python
                 if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
-                    self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
+                    self.last_valid_state[idx] = len(partial_output) - len(r.remainder)
 
     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
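The remainder_bytes plumbing above exists because a byte-level tokenizer can stop midway through a multi-byte UTF-8 character: the decodable prefix is parsed as text, and the trailing bytes are carried into the next step. A minimal sketch of what a helper like `_bytes_to_string` has to do (an illustration of the idea, not SynCode's exact implementation):

    def bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
        # Decode the longest valid UTF-8 prefix; trailing bytes of an
        # incomplete multi-byte character are returned as the remainder.
        for end in range(len(byte_sequence), max(len(byte_sequence) - 4, -1), -1):
            try:
                return byte_sequence[:end].decode("utf-8"), byte_sequence[end:]
            except UnicodeDecodeError:
                continue
        return "", byte_sequence

    # "é" is 0xC3 0xA9 in UTF-8; a token boundary can split it in half.
    assert bytes_to_string(b"x = \xc3") == ("x = ", b"\xc3")

The partial_code to partial_output rename itself is mechanical and appears intended to generalize the naming, since the constrainer handles any grammar-constrained output, not only code.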
syncode/mask_store/byte_tokenizer.py: 5 changes (3 additions, 2 deletions)

@@ -188,11 +188,12 @@ def __init__(self, tokenizer, vocab_type=None):
         # Cache special token IDs as a set for faster lookups
         self.special_token_ids = set(getattr(tokenizer, "all_special_ids", []))
 
+        # NOTE: This seems to be problematic in some cases where regular tokens like "\t" are treated as special tokens
         # Added tokens are typically special tokens
         # if added_tokens_decoder is not None self.tokenizer.added_tokens_decoder.keys()
         # to special_token_ids
-        if hasattr(tokenizer, "added_tokens_decoder"):
-            self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
+        # if hasattr(tokenizer, "added_tokens_decoder"):
+        #     self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
 
 
     @classmethod
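The deleted branch treated every entry in added_tokens_decoder as a special token, which the new NOTE flags as problematic: some tokenizers register ordinary tokens such as "\t" as added tokens, and marking them special makes skip_special_tokens-style decoding silently drop real output. A self-contained sketch of the failure mode, using hypothetical token IDs:

    # Hypothetical vocabulary: one genuine special token, plus "\t" registered
    # as an added (but not special) token, as some tokenizers do.
    added_tokens_decoder = {50001: "<|endoftext|>", 50002: "\t"}
    all_special_ids = [50001]

    special_token_ids = set(all_special_ids)
    special_token_ids.update(added_tokens_decoder.keys())  # the removed behavior

    # Decoding token 50002 while skipping "special" tokens now drops the tab,
    # corrupting the partial output handed to the incremental parser.
    decoded = "".join(
        tok for tid, tok in added_tokens_decoder.items()
        if tid not in special_token_ids
    )
    assert decoded == ""  # the "\t" was silently removed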
syncode/parsers/python_parser.py: 4 changes (2 additions, 2 deletions)

@@ -35,12 +35,12 @@ def _get_indentation(self, partial_code) -> int:
         return tab_len
 
     def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
+        # Stores the sequence of tokens that the parser has seen in the order
+        interactive = self.interactive
         lexer_tokens, lexing_incomplete = self._lex_code(partial_code)
         self.next_ac_terminals = self._accepts(self.interactive)
 
         # Restore the previous state of the parser
         self._restore_recent_parser_state(lexer_tokens)
-        interactive = self.interactive
-
         next_ac_indents = None
 