diff --git a/README.md b/README.md
index f6d32809..05859401 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ SynCode depends on HuggingFace [transformers](https://github.com/huggingface/tra
 
 | SynCode version | Required transformers version | Python version |
 | -------------- | ----------------------------- | -------------- |
-| `v0.4.13` (latest) | `v4.51.0` | 3.6 - 3.12 |
+| `v0.4.14` (latest) | `v4.51.3` | 3.6 - 3.12 |
 
 **Note:** Python 3.13 is not currently supported due to dependency constraints.
 
diff --git a/pyproject.toml b/pyproject.toml
index 9073ab1c..bb561291 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "syncode"
-version="0.4.13"
+version="0.4.14"
 requires-python = ">=3.6,<3.13"
 description = "Grammar-guided code generation tool"
 readme = "README.md"
diff --git a/setup.py b/setup.py
index 73a7b990..3956a8d6 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
 
 setuptools.setup(
     name="syncode",
-    version="0.4.13",
+    version="0.4.14",
     author="Shubham Ugare",
     author_email="shubhamugare@gmail.com",
     description="This package provides the tool for grammar augmented LLM generation.",
diff --git a/syncode/grammar_mask/grammar_constrainer.py b/syncode/grammar_mask/grammar_constrainer.py
index ffcc95de..08ee4c17 100644
--- a/syncode/grammar_mask/grammar_constrainer.py
+++ b/syncode/grammar_mask/grammar_constrainer.py
@@ -120,11 +120,11 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         self._set_start_from(input_ids)
         input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
-        partial_code, remainder_bytes = self._get_partial_codes(input_ids)[0]
+        partial_output, remainder_bytes = self._get_partial_outputs(input_ids)[0]
 
-        res, skip = self._parse_partial_code(
+        res, skip = self._parse_partial_output(
             idx=0,
-            partial_code=partial_code,
+            partial_output=partial_output,
             remainder_bytes=remainder_bytes,
             accepted_generation=False
         )
@@ -142,7 +142,7 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         is_valid = self.dfa_mask_store.is_valid_prefix(res)
 
         if is_valid:
-            self._update_valid_state(partial_code, 0, res)
+            self._update_valid_state(partial_output, 0, res)
 
         return is_valid
@@ -163,11 +163,11 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
             torch.FloatTensor: The masked scores.
         """
         self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start
-        partial_codes = self._get_partial_codes(input_ids)
+        partial_outputs = self._get_partial_outputs(input_ids)
 
-        for idx, (partial_code, remainder_bytes) in enumerate(partial_codes):
+        for idx, (partial_output, remainder_bytes) in enumerate(partial_outputs):
             # 1. Parsing
-            res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
+            res, skip = self._parse_partial_output(idx, partial_output, remainder_bytes, accepted_generation=True)
             if skip: continue
 
             # 2. Computing the accept mask
@@ -187,7 +187,13 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
         return scores
 
-    def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
+    def _parse_partial_output(
+        self,
+        idx: int,
+        partial_output: str,
+        remainder_bytes: bytes,
+        accepted_generation=True
+    ) -> tuple[ParseResult, bool]:
         """
         Parse the partial code and return the result.
""" @@ -195,7 +201,7 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte res = None try: - res = self.inc_parser.get_acceptable_next_terminals(partial_code) + res = self.inc_parser.get_acceptable_next_terminals(partial_output) if len(remainder_bytes) > 0: res.remainder_state = RemainderState.INCOMPLETE @@ -203,7 +209,7 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte else: res.remainder = res.remainder.encode('utf-8') - self._update_valid_state(partial_code, idx, res) + self._update_valid_state(partial_output, idx, res) except Exception as e: if self.dev_mode == True and accepted_generation: logger.info("-"*50) @@ -213,31 +219,31 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte elif self.parse_failed == False and accepted_generation: self.parse_failed = True logger.info("-"*50) - logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}") + logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_output}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}") logger.info("-"*50) skip = True return res, skip - def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]: + def _get_partial_outputs(self, input_ids: torch.LongTensor) -> list[(str, bytes)]: """ Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string. """ output = [] for idx in range(len(input_ids)): if self.parse_output_only: - partial_code, remainder_bytes = self._bytes_to_string( + partial_output, remainder_bytes = self._bytes_to_string( self.byte_tokenizer.decode( input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True) ) else: - partial_code, remainder_bytes = self._bytes_to_string( + partial_output, remainder_bytes = self._bytes_to_string( self.byte_tokenizer.decode( input_ids[idx].tolist(), skip_special_tokens=True) ) - output.append((partial_code, remainder_bytes)) + output.append((partial_output, remainder_bytes)) return output - def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult): + def _update_valid_state(self, partial_output: str, idx: int, r: ParseResult): """ This a simple heuristic to cut off the generated output at the end of the function. TODO: Put this under a flag to enable/disable this heuristic. 
@@ -245,13 +251,13 @@ def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
         if idx < len(self.function_ends):
             if r.function_end: # If the function end is not None, then the last valid state is the function end
                 if self.function_ends[idx] is None: self.function_ends[idx] = []
-                self.function_ends[idx].append(len(partial_code) - len(r.remainder))
+                self.function_ends[idx].append(len(partial_output) - len(r.remainder))
 
         if idx < len(self.last_valid_state):
             for accept_seq in r.accept_sequences:
                 # 'EOF' is special terminal since $END does not work with python
                 if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
-                    self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
+                    self.last_valid_state[idx] = len(partial_output) - len(r.remainder)
 
     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
diff --git a/syncode/mask_store/byte_tokenizer.py b/syncode/mask_store/byte_tokenizer.py
index 81edd5d2..37ae8b1b 100644
--- a/syncode/mask_store/byte_tokenizer.py
+++ b/syncode/mask_store/byte_tokenizer.py
@@ -188,11 +188,12 @@ def __init__(self, tokenizer, vocab_type=None):
         # Cache special token IDs as a set for faster lookups
         self.special_token_ids = set(getattr(tokenizer, "all_special_ids", []))
 
+        # NOTE: This seems to be problematic in some cases where regular tokens like "\t" are treated as special tokens
         # Added tokens are typically special tokens
         # if added_tokens_decoder is not None self.tokenizer.added_tokens_decoder.keys()
         # to special_token_ids
-        if hasattr(tokenizer, "added_tokens_decoder"):
-            self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
+        # if hasattr(tokenizer, "added_tokens_decoder"):
+        #     self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
 
     @classmethod
diff --git a/syncode/parsers/python_parser.py b/syncode/parsers/python_parser.py
index 4d16ee2b..c1ae16e2 100644
--- a/syncode/parsers/python_parser.py
+++ b/syncode/parsers/python_parser.py
@@ -35,12 +35,12 @@ def _get_indentation(self, partial_code) -> int:
         return tab_len
 
     def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
-        # Stores the sequence of tokens that the parser has seen in the order
-        interactive = self.interactive
         lexer_tokens, lexing_incomplete = self._lex_code(partial_code)
+        self.next_ac_terminals = self._accepts(self.interactive)
 
         # Restore the previous state of the parser
         self._restore_recent_parser_state(lexer_tokens)
+        interactive = self.interactive
 
         next_ac_indents = None
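
For context on the `remainder_bytes` plumbing in `grammar_constrainer.py` above: token-by-token generation can stop in the middle of a multi-byte UTF-8 character, so the renamed `_get_partial_outputs` decodes the token IDs to bytes and relies on `_bytes_to_string` to split off the undecodable tail, which `_parse_partial_output` then marks as `RemainderState.INCOMPLETE`. Below is a minimal sketch of that split, assuming only the `tuple[str, bytes]` signature visible in the diff; the helper name `bytes_to_string_sketch` and the backtracking strategy are illustrative, not SynCode's actual implementation.

```python
def bytes_to_string_sketch(byte_sequence: bytes) -> tuple[str, bytes]:
    """Split byte_sequence into its longest decodable UTF-8 prefix and the
    trailing bytes of any incomplete character. Hypothetical stand-in for
    GrammarConstrainer._bytes_to_string, for illustration only."""
    # A UTF-8 character is at most 4 bytes, so backtracking up to 3 bytes
    # from the end is enough to strip a partially generated character.
    for cut in range(len(byte_sequence), max(len(byte_sequence) - 4, -1), -1):
        try:
            return byte_sequence[:cut].decode("utf-8"), byte_sequence[cut:]
        except UnicodeDecodeError:
            continue
    return "", byte_sequence

# "é" encodes as b"\xc3\xa9"; stopping generation after the first byte yields
# the decodable prefix plus a one-byte remainder instead of a decode error.
assert bytes_to_string_sketch(b"x = '\xc3") == ("x = '", b"\xc3")
```

Carrying the remainder forward this way is what lets `_parse_partial_output` treat a token that splits a character as an incomplete lexeme rather than a parse failure.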