+from typing import Optional, Tuple
 import torch
-import syncode.common as common
-from transformers import LogitsProcessor, PreTrainedTokenizer
+import re
+from transformers import PreTrainedTokenizer
 from syncode.mask_store.byte_tokenizer import ByteTokenizer
 from syncode.parse_result import AcceptSequence, RemainderState
 from syncode.parsers.incremental_parser import IncrementalParser, ParseResult
@@ -43,17 +44,32 @@ class GrammarConstrainer:
 
     For more details on the approximation methods, refer to the SynCode paper:
     https://arxiv.org/abs/2403.01632
+
+    start_delim (str, optional): Start delimiter marking the beginning of structured
+        (grammar-constrained) content.
+    end_delim (str, optional): End delimiter marking the end of structured content.
+
+    NOTE: These delimiters are used to extract structured regions for parsing and
+    grammar enforcement. See *CRANE: Reasoning with Constrained LLM Generation*
+    ([arXiv:2502.09061](https://arxiv.org/abs/2502.09061)) for more details.
+    Example: `start_delim="```python\n"` and `end_delim="```"` would parse only
+    the content between these markers.
     """
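+    # Illustrative usage (a sketch, not part of the original commit): constructing a
+    # GrammarConstrainer that only constrains fenced Python blocks. The `grammar`,
+    # `tokenizer`, and `byte_tokenizer` objects are assumed to exist in the caller:
+    #
+    #   constrainer = GrammarConstrainer(
+    #       grammar,
+    #       tokenizer,
+    #       byte_tokenizer,
+    #       start_delim="```python\n",
+    #       end_delim="```",
+    #   )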
-    def __init__(self,
-            grammar: Grammar,
-            tokenizer: PreTrainedTokenizer,
-            byte_tokenizer: ByteTokenizer,
-            use_cache=True,
-            parse_output_only=True,
-            batch_size=1,
-            dev_mode=False,
-            parser='lalr',
-            mode='grammar_mask'):
+    def __init__(
+        self,
+        grammar: Grammar,
+        tokenizer: PreTrainedTokenizer,
+        byte_tokenizer: ByteTokenizer,
+        use_cache=True,
+        parse_output_only=True,
+        batch_size=1,
+        dev_mode=False,
+        parser='lalr',
+        mode='grammar_mask',
+        start_delim: Optional[str] = None,
+        end_delim: Optional[str] = None,
+    ):
 
         self.tokenizer = tokenizer
         self.byte_tokenizer = byte_tokenizer
@@ -84,6 +100,10 @@ def __init__(self,
             mode=mode,  # Controls approximation strategy for token masking
         )
 
+        # Delimiters separating structured content from the rest of the generated text;
+        # both default to None, meaning no delimiters are used
+        self.start_delim = start_delim
+        self.end_delim = end_delim
 
     def reset(self):
         """
@@ -120,11 +140,11 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         self._set_start_from(input_ids)
 
         input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
-        partial_code, remainder_bytes = self._get_partial_codes(input_ids)[0]
+        partial_output, remainder_bytes = self._get_partial_outputs(input_ids)[0]
 
-        res, skip = self._parse_partial_code(
+        res, skip = self._parse_partial_output(
             idx=0,
-            partial_code=partial_code,
+            partial_output=partial_output,
             remainder_bytes=remainder_bytes,
             accepted_generation=False
         )
@@ -142,7 +162,7 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         is_valid = self.dfa_mask_store.is_valid_prefix(res)
 
         if is_valid:
-            self._update_valid_state(partial_code, 0, res)
+            self._update_valid_state(partial_output, 0, res)
 
         return is_valid
 
@@ -163,11 +183,11 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->
             torch.FloatTensor: The masked scores.
         """
         self._set_start_from(input_ids)  # start_from is used for choosing where the parsing should start
-        partial_codes = self._get_partial_codes(input_ids)
+        partial_outputs = self._get_partial_outputs(input_ids)
 
-        for idx, (partial_code, remainder_bytes) in enumerate(partial_codes):
+        for idx, (partial_output, remainder_bytes) in enumerate(partial_outputs):
             # 1. Parsing
-            res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
+            res, skip = self._parse_partial_output(idx, partial_output, remainder_bytes, accepted_generation=True)
             if skip: continue
 
             # 2. Computing the accept mask
@@ -187,68 +207,111 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->
 
         return scores
 
-    def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
+    def _parse_partial_output(self, idx: int, partial_output: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
         """
-        Parse the partial code and return the result.
+        Parse the partial output and return the result.
         """
         skip = False
         res = None
 
         try:
-            res = self.inc_parser.get_acceptable_next_terminals(partial_code)
+            res = self.inc_parser.get_acceptable_next_terminals(partial_output)
 
             if len(remainder_bytes) > 0:
                 res.remainder_state = RemainderState.INCOMPLETE
                 res.remainder = res.remainder.encode('utf-8') + remainder_bytes
             else:
                 res.remainder = res.remainder.encode('utf-8')
 
-            self._update_valid_state(partial_code, idx, res)
+            self._update_valid_state(partial_output, idx, res)
         except Exception as e:
             if self.dev_mode == True and accepted_generation:
                 raise e
             elif self.parse_failed == False and accepted_generation:
                 self.parse_failed = True
                 logger.info("-" * 50)
-                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
+                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial output: {partial_output}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
                 logger.info("-" * 50)
             skip = True
         return res, skip
 
-    def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
+    def _get_partial_outputs(self, input_ids: torch.LongTensor) -> list[tuple[str, bytes]]:
         """
-        Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
+        Get the partial outputs for the input_ids and return the remainder bytes if the partial output is not a valid UTF-8 string.
         """
         output = []
         for idx in range(len(input_ids)):
             if self.parse_output_only:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True)
                 )
             else:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx].tolist(), skip_special_tokens=True)
                 )
-            output.append((partial_code, remainder_bytes))
+
+            # Use self.start_delim and self.end_delim to extract the structured content.
+            # Note that the current input may contain multiple start_delim/end_delim
+            # pairs; only the most recent block matters for constraining.
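+            # A minimal sketch of that extraction (an assumption, not part of the
+            # original commit): when both delimiters are configured and the latest
+            # block is still open, parse only the content of that block; otherwise
+            # leave partial_output untouched so decoding stays unconstrained.
+            if self.start_delim is not None and self.end_delim is not None:
+                extracted, is_open = self.extract_last_structured_block(
+                    partial_output, self.start_delim, self.end_delim
+                )
+                if is_open:
+                    partial_output = extracted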
+            output.append((partial_output, remainder_bytes))
         return output
 
-    def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
+    @staticmethod
+    def extract_last_structured_block(
+        text: str,
+        start_delim: Optional[str],
+        end_delim: Optional[str]
+    ) -> Tuple[str, bool]:
+        """
+        Extracts the last structured block from `text` between `start_delim` and `end_delim`.
+
+        Returns:
+            (extracted_text: str, should_constrain: bool)
+            - extracted_text: The content of the last delimited block.
+            - should_constrain: True if a start delimiter is present without a matching end,
+              meaning structured generation should continue.
+        """
+        if start_delim is None or end_delim is None:
+            return "", False
+
+        # Check for an open (unclosed) block first: if no end delimiter follows the
+        # last start delimiter, generation is still inside a structured block. Doing
+        # this before matching closed blocks also handles overlapping delimiters
+        # (e.g. end_delim "```" being a prefix of start_delim "```python\n").
+        last_start_idx = text.rfind(start_delim)
+        if last_start_idx != -1:
+            after_start = text[last_start_idx + len(start_delim):]
+            if end_delim not in after_start:
+                # Return the content after the last start delimiter, even though the
+                # end delimiter is missing, and signal that constraining should continue
+                return after_start.strip(), True
+
+        # Otherwise, return the content of the last fully enclosed block
+        pattern = re.escape(start_delim) + r"(.*?)" + re.escape(end_delim)
+        matches = list(re.finditer(pattern, text, flags=re.DOTALL))
+        if matches:
+            return matches[-1].group(1).strip(), False  # closed block, no need to constrain further
+
+        return "", False  # no block found => no constraint needed
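+        # Illustrative behavior of this helper (hypothetical inputs, for clarity):
+        #   extract_last_structured_block("hi ```python\nx = 1\n``` bye", "```python\n", "```")
+        #       -> ("x = 1", False)   # last block is closed
+        #   extract_last_structured_block("hi ```python\nx = 1", "```python\n", "```")
+        #       -> ("x = 1", True)    # block still open, keep constraining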
+
+    def _update_valid_state(self, partial_output: str, idx: int, r: ParseResult):
         """
-        This a simple heuristic to cut off the generated output at the end of the function.
+        This is a simple heuristic to cut off the generated output at the end of the function.
         TODO: Put this under a flag to enable/disable this heuristic.
         """
         if idx < len(self.function_ends):
             if r.function_end:  # If the function end is not None, then the last valid state is the function end
                 if self.function_ends[idx] is None: self.function_ends[idx] = []
-                self.function_ends[idx].append(len(partial_code) - len(r.remainder))
+                self.function_ends[idx].append(len(partial_output) - len(r.remainder))
 
         if idx < len(self.last_valid_state):
             for accept_seq in r.accept_sequences:
                 # 'EOF' is a special terminal since $END does not work with python
                 if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
-                    self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
+                    self.last_valid_state[idx] = len(partial_output) - len(r.remainder)
 
     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]: