Commit bf20e10
Add CRANE grammar delimiters
1 parent d343567 commit bf20e10

4 files changed: +225 -32 lines changed

syncode/grammar_mask/grammar_constrainer.py

Lines changed: 93 additions & 30 deletions
@@ -1,6 +1,7 @@
+from typing import Optional, Tuple
 import torch
-import syncode.common as common
-from transformers import LogitsProcessor, PreTrainedTokenizer
+import re
+from transformers import PreTrainedTokenizer
 from syncode.mask_store.byte_tokenizer import ByteTokenizer
 from syncode.parse_result import AcceptSequence, RemainderState
 from syncode.parsers.incremental_parser import IncrementalParser, ParseResult
@@ -43,17 +44,32 @@ class GrammarConstrainer:
 
     For more details on the approximation methods, refer to the SynCode paper:
     https://arxiv.org/abs/2403.01632
+
+
+    start_delim (str, optional): Start delimiter marking the beginning of structured
+        (grammar-constrained) content.
+    end_delim (str, optional): End delimiter marking the end of structured content.
+
+    NOTE: These delimiters are used to extract structured regions for parsing and grammar enforcement.
+    See *CRANE: Reasoning with Constrained LLM Generation*
+    ([arXiv:2502.09061](https://arxiv.org/abs/2502.09061)) for more details.
+    Example: `start_delim="```python\n"` and `end_delim="```"` would parse only
+    the content between these markers.
     """
-    def __init__(self,
-                 grammar: Grammar,
-                 tokenizer: PreTrainedTokenizer,
-                 byte_tokenizer: ByteTokenizer,
-                 use_cache=True,
-                 parse_output_only=True,
-                 batch_size=1,
-                 dev_mode=False,
-                 parser='lalr',
-                 mode='grammar_mask'):
+    def __init__(
+        self,
+        grammar: Grammar,
+        tokenizer: PreTrainedTokenizer,
+        byte_tokenizer: ByteTokenizer,
+        use_cache=True,
+        parse_output_only=True,
+        batch_size=1,
+        dev_mode=False,
+        parser='lalr',
+        mode='grammar_mask',
+        start_delim=None,
+        end_delim=None,
+    ):
 
         self.tokenizer = tokenizer
         self.byte_tokenizer = byte_tokenizer
@@ -84,6 +100,10 @@ def __init__(self,
             mode=mode, # Controls approximation strategy for token masking
         )
 
+        # Used for separating the structured content from the rest of the generated text
+        # defaults to None, meaning no delimiters are used
+        self.start_delim = start_delim
+        self.end_delim = end_delim
 
     def reset(self):
         """
@@ -120,11 +140,11 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         self._set_start_from(input_ids)
 
         input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
-        partial_code, remainder_bytes = self._get_partial_codes(input_ids)[0]
+        partial_output, remainder_bytes = self._get_partial_outputs(input_ids)[0]
 
-        res, skip = self._parse_partial_code(
+        res, skip = self._parse_partial_output(
             idx=0,
-            partial_code=partial_code,
+            partial_output=partial_output,
             remainder_bytes=remainder_bytes,
             accepted_generation=False
         )
@@ -142,7 +162,7 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         is_valid = self.dfa_mask_store.is_valid_prefix(res)
 
         if is_valid:
-            self._update_valid_state(partial_code, 0, res)
+            self._update_valid_state(partial_output, 0, res)
 
         return is_valid
 
@@ -163,11 +183,11 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->
             torch.FloatTensor: The masked scores.
         """
         self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start
-        partial_codes = self._get_partial_codes(input_ids)
+        partial_outputs = self._get_partial_outputs(input_ids)
 
-        for idx, (partial_code, remainder_bytes) in enumerate(partial_codes):
+        for idx, (partial_output, remainder_bytes) in enumerate(partial_outputs):
             # 1. Parsing
-            res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
+            res, skip = self._parse_partial_output(idx, partial_output, remainder_bytes, accepted_generation=True)
             if skip: continue
 
             # 2. Computing the accept mask
@@ -187,68 +207,111 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->
 
         return scores
 
-    def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
+    def _parse_partial_output(self, idx: int, partial_output: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
         """
         Parse the partial code and return the result.
         """
         skip = False
         res = None
 
         try:
-            res = self.inc_parser.get_acceptable_next_terminals(partial_code)
+            res = self.inc_parser.get_acceptable_next_terminals(partial_output)
 
             if len(remainder_bytes) > 0:
                 res.remainder_state = RemainderState.INCOMPLETE
                 res.remainder = res.remainder.encode('utf-8') + remainder_bytes
             else:
                 res.remainder = res.remainder.encode('utf-8')
 
-            self._update_valid_state(partial_code, idx, res)
+            self._update_valid_state(partial_output, idx, res)
         except Exception as e:
             if self.dev_mode == True and accepted_generation:
                 raise e
             elif self.parse_failed == False and accepted_generation:
                 self.parse_failed = True
                 logger.info("-"*50)
-                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
+                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_output}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
                 logger.info("-"*50)
             skip = True
         return res, skip
 
-    def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
+    def _get_partial_outputs(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
         """
         Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
         """
         output = []
         for idx in range(len(input_ids)):
             if self.parse_output_only:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True)
                 )
             else:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx].tolist(), skip_special_tokens=True)
                 )
-            output.append((partial_code, remainder_bytes))
+
+            # Use self.start_delim and self.end_delim to extract the structured content
+            # It is possible that there are multiple start_delim and end_delim in the current input
+
+
+            output.append((partial_output, remainder_bytes))
         return output
 
-    def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
+    @staticmethod
+    def extract_last_structured_block(
+        text: str,
+        start_delim: Optional[str],
+        end_delim: Optional[str]
+    ) -> Tuple[str, bool]:
+        """
+        Extracts the last structured block from `text` between `start_delim` and `end_delim`.
+
+        Returns:
+            (extracted_text: str, should_constrain: bool)
+            - extracted_text: The content of the last delimited block.
+            - should_constrain: True if a start delimiter is present without a matching end,
+              meaning structured generation should continue.
+        """
+        if start_delim is None or end_delim is None:
+            return "", False
+
+        # Find all fully enclosed blocks
+        pattern = re.escape(start_delim) + r"(.*?)" + re.escape(end_delim)
+        matches = list(re.finditer(pattern, text, flags=re.DOTALL))
+
+        if matches:
+            last_match = matches[-1]
+            return last_match.group(1).strip(), False  # closed, no need to constrain further
+
+        # If there's a start but no end, check for unclosed start
+        last_start_idx = text.rfind(start_delim)
+        last_end_idx = text.rfind(end_delim)
+
+        # If the start delimiter appears after the last end delimiter, it's an open block
+        if last_start_idx > last_end_idx:
+            # Return the content after the last start delimiter, even if the end delimiter is missing
+            return text[last_start_idx + len(start_delim):].strip(), True
+
+        return "", False  # no open block => no constraint needed
+
+
+    def _update_valid_state(self, partial_output: str, idx: int, r: ParseResult):
         """
         This a simple heuristic to cut off the generated output at the end of the function.
         TODO: Put this under a flag to enable/disable this heuristic.
         """
         if idx < len(self.function_ends):
             if r.function_end: # If the function end is not None, then the last valid state is the function end
                 if self.function_ends[idx] is None: self.function_ends[idx] = []
-                self.function_ends[idx].append(len(partial_code) - len(r.remainder))
+                self.function_ends[idx].append(len(partial_output) - len(r.remainder))
 
         if idx < len(self.last_valid_state):
             for accept_seq in r.accept_sequences:
                 # 'EOF' is special terminal since $END does not work with python
                 if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
-                    self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
+                    self.last_valid_state[idx] = len(partial_output) - len(r.remainder)
 
     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
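The comment left inside `_get_partial_outputs` marks where the delimiters are meant to plug in, but this commit does not yet perform that extraction. Below is a minimal standalone sketch of one possible wiring, assuming the incremental parser should only see the last delimited block and that masking is dropped once a block is closed; `select_text_to_parse` is a hypothetical name and is not part of the diff:

    from typing import Optional, Tuple
    from syncode.grammar_mask.grammar_constrainer import GrammarConstrainer

    def select_text_to_parse(
        decoded_output: str,
        start_delim: Optional[str],
        end_delim: Optional[str],
    ) -> Tuple[str, bool]:
        # Hypothetical helper: pick the text the incremental parser should see and
        # report whether grammar masking should currently be applied.
        if start_delim is None or end_delim is None:
            # No delimiters configured: constrain the whole output, as before.
            return decoded_output, True
        block, should_constrain = GrammarConstrainer.extract_last_structured_block(
            decoded_output, start_delim, end_delim
        )
        # Inside an open block, parse only its contents; otherwise leave decoding unconstrained.
        return block, should_constrain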

syncode/grammar_mask/logits_processor.py

Lines changed: 7 additions & 2 deletions
@@ -29,7 +29,10 @@ def __init__(self,
                  num_samples=1,
                  dev_mode=False,
                  parser='lalr',
-                 mode='grammar_mask'):
+                 mode='grammar_mask',
+                 start_delim=None,
+                 end_delim=None
+                 ):
 
         self.tokenizer = tokenizer
         self.byte_tokenizer = ByteTokenizer(tokenizer)
@@ -44,7 +47,9 @@ def __init__(self,
             batch_size=num_samples,
             dev_mode=dev_mode,
             parser=parser,
-            mode=mode
+            mode=mode,
+            start_delim=start_delim,
+            end_delim=end_delim
         )
 
     def reset(self):

syncode/infer.py

Lines changed: 16 additions & 0 deletions
@@ -35,6 +35,18 @@ class Syncode:
         parser (str, optional): Parser to use. Defaults to "lalr". Options are "lr" and "lalr".
         seed (int, optional): Random seed for reproducibility. Defaults to None.
         opp (bool, optional): Whether to use opportunistic generation. Defaults to True.
+        device_map (str, optional): Device map for model loading. Defaults to None.
+
+        start_delim (str, optional): Start delimiter marking the beginning of structured
+            (grammar-constrained) content.
+        end_delim (str, optional): End delimiter marking the end of structured content.
+
+        NOTE: These delimiters are used to extract structured regions for parsing and grammar enforcement.
+        See *CRANE: Reasoning with Constrained LLM Generation*
+        ([arXiv:2502.09061](https://arxiv.org/abs/2502.09061)) for more details.
+        Example: `start_delim="```python\n"` and `end_delim="```"` would parse only
+        the content between these markers.
+
         **kwargs: Additional arguments passed to the model for generation.
     """
     def __init__(
@@ -51,6 +63,8 @@ def __init__(
        seed: Optional[int] = None,
        opp: bool = True,
        device_map: Optional[str] = None,
+       start_delim: Optional[str] = None,
+       end_delim: Optional[str] = None,
        **kwargs
    ):
        # Check inputs
@@ -102,6 +116,8 @@ def __init__(
            dev_mode=dev_mode,
            parser=parser,
            mode=mode,
+           start_delim=start_delim,
+           end_delim=end_delim,
        )
 
        # Set default max new tokens if not provided
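With the new arguments exposed on `Syncode`, an end-to-end call might look like the sketch below. The model string is a hypothetical choice, and the `infer` call and `max_new_tokens` keyword follow SynCode's usual usage rather than anything shown in this diff:

    from syncode import Syncode

    syn_llm = Syncode(
        model="microsoft/Phi-3-mini-4k-instruct",  # hypothetical model choice
        grammar="python",
        mode="grammar_mask",
        parse_output_only=True,
        start_delim="```python\n",  # constrain only the fenced Python block
        end_delim="```",
        max_new_tokens=256,
    )

    # The model can reason in free text and is grammar-constrained only between the delimiters.
    output = syn_llm.infer("Write a function that returns the square of a number.")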
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+import unittest
+import sys, os
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+from syncode.grammar_mask.grammar_constrainer import GrammarConstrainer
+
+class TestGrammarDelimeter(unittest.TestCase):
+    def test_all_cases(self):
+        test_cases = [
+            {
+                "desc": "Single math block",
+                "text": "Here's a formula: <<x^2 + y^2 = z^2>> in the middle of text.",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "x^2 + y^2 = z^2",
+                "expected_constrain": False,
+            },
+            {
+                "desc": "Multiple math blocks, returns last",
+                "text": "Intro <<a+b>> some more <<c+d>> ending",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "c+d",
+                "expected_constrain": False,
+            },
+            {
+                "desc": "Missing closing delimiter",
+                "text": "This math is broken <<1+1=",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "1+1=",
+                "expected_constrain": True,
+            },
+            {
+                "desc": "Nested-looking delimiters",
+                "text": "Messy: <<1 + <<2>> + 3>>",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "1 + <<2",  # Still not closed due to the second <<
+                "expected_constrain": False,
+            },
+            {
+                "desc": "No delimiters present",
+                "text": "No math here, just plain text.",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "",
+                "expected_constrain": False,
+            },
+            {
+                "desc": "Math with newlines",
+                "text": "Start <<\na = b + c\nf(x) = x^2\n>> end",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "a = b + c\nf(x) = x^2",
+                "expected_constrain": False,
+            },
+            {
+                "desc": "Only start delim exists after closed one",
+                "text": "Closed first: <<1+2>> then opened: <<3+4",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "3+4",
+                "expected_constrain": True,
+            },
+            {
+                "desc": "Edge case where text ends with prefix of the delimiter",
+                "text": "some text << xyz >",
+                "start": "<<",
+                "end": ">>",
+                "expected_text": "",
+                "expected_constrain": False,
+            },
+            # New test cases for Python code block
+            {
+                "desc": "Python code block with code",
+                "text": "Some introductory text before code block ```python\nx = 5\nprint(x)``` more text after.",
+                "start": "```python\n",
+                "end": "```",
+                "expected_text": "x = 5\nprint(x)",
+                "expected_constrain": False,
+            },
+            {
+                "desc": "Python code block with no closing delimiter",
+                "text": "Here is some code: ```python\nx = 5\nprint(x)",
+                "start": "```python\n",
+                "end": "```",
+                "expected_text": "x = 5\nprint(x)",
+                "expected_constrain": True,
+            },
+            {
+                "desc": "Multiple code blocks with a start delimiter present after a closed block",
+                "text": "First block ```python\nx = 10``` and then a second block ```python\ny = 20",
+                "start": "```python\n",
+                "end": "```",
+                "expected_text": "y = 20",
+                "expected_constrain": True,
+            },
+        ]
+
+        for case in test_cases:
+            with self.subTest(msg=case["desc"]):
+                result_text, should_constrain = GrammarConstrainer.extract_last_structured_block(
+                    case["text"], case["start"], case["end"]
+                )
+                self.assertEqual(result_text, case["expected_text"])
+                self.assertEqual(should_constrain, case["expected_constrain"])
+
+if __name__ == "__main__":
+    unittest.main()
