11from collections import defaultdict
22import copy , os , pickle
3+ import time
34import interegular
45import torch
56import regex
@@ -155,7 +156,6 @@ def incomplete_case_lookup(self, dfa_state: DFAState) -> Any:
155156 if dfa_state in self ._exact_lookup :
156157 return self ._exact_lookup [dfa_state ]
157158 else :
158- print (f"Warning: Exact lookup not found for { dfa_state } in the DFA mask store. This could be an error." , flush = True )
159159 return self ._overapprox_lookup [dfa_state ]
160160 raise ValueError (f"Invalid mode: { self ._mode } " )
161161
@@ -167,9 +167,6 @@ def store_overapprox_lookup(self, dfa_state: DFAState, mask: torch.Tensor):
167167
def complete_case_lookup(self, dfa_state: DFAState) -> Any:
    """Return the exact token mask stored for a completed DFA state.

    No over-approximation fallback is used here (unlike the incomplete
    case): a state missing from the exact lookup raises KeyError, which
    surfaces mask-store construction bugs instead of hiding them.
    """
    assert isinstance(dfa_state, DFAState)
    return self._exact_lookup[dfa_state]
174171
175172 def add_exact_lookup (self , dfa_state : DFAState , token ):
@@ -294,7 +291,8 @@ def __init__(self,
294291 special_token_ids : Iterable = [],
295292 indentation : bool = True ,
296293 mode = 'grammar_strict' , # 'grammar_strict' or 'grammar_mask'
297- ignore_terminals : Iterable [str ]= []
294+ ignore_terminals : Iterable [str ]= [],
295+ parse_table = None
298296 ):
299297 self ._vocab = vocab
300298 self .special_token_ids = special_token_ids
@@ -308,7 +306,11 @@ def __init__(self,
308306 # Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
309307 self ._lookup_table = LookupTable (vocab , special_token_ids , indentation = indentation , mode = mode )
310308 terminal_names = [terminal .name for terminal in terminals ]
311- self ._store_overapproximate_tokens (terminal_names , vocab )
309+
310+ followings_terminas_map = None
311+ if parse_table is not None :
312+ followings_terminas_map = self ._compute_following_terminals_map (terminal_names , parse_table )
313+ self ._store_overapproximate_tokens (terminal_names , vocab , followings_terminas_map )
312314
313315 self .indentation = indentation
314316
@@ -325,7 +327,14 @@ def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_termina
325327 return ignore_whitespace
326328
327329 @staticmethod
328- def load_dfa_mask_store (grammar : Grammar , tokenizer , use_cache = True , logger = common .EmptyLogger (), mode = 'grammar_strict' ):
330+ def load_dfa_mask_store (
331+ grammar : Grammar ,
332+ tokenizer ,
333+ use_cache = True ,
334+ logger = None ,
335+ mode = 'grammar_strict' ,
336+ parse_table = None
337+ ):
329338 '''
330339 Loads the dfa for the given language and tokenizer. If the dfa is not cached, it is created and cached.
331340 '''
@@ -351,7 +360,17 @@ def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=comm
351360 simplifications = grammar .simplifications ()
352361 os .makedirs (dfa_dir , exist_ok = True )
353362
354- mask_store = DFAMaskStore (base_parser .terminals , vocab , simplifications = simplifications , special_token_ids = [tokenizer .eos_token_id ], mode = mode , ignore_terminals = base_parser .ignore_tokens )
363+ start_time = time .time ()
364+ mask_store = DFAMaskStore (
365+ base_parser .terminals ,
366+ vocab ,
367+ simplifications = simplifications ,
368+ special_token_ids = [tokenizer .eos_token_id ],
369+ mode = mode ,
370+ ignore_terminals = base_parser .ignore_tokens ,
371+ parse_table = parse_table
372+ )
373+ print (f"Time taken to create DFA mask store: { time .time () - start_time } seconds" , flush = True )
355374
356375 pickle .dump (mask_store , open (dfa_path , 'wb' ))
357376 return mask_store
@@ -360,12 +379,42 @@ def _get_default_mask(self) -> torch.Tensor:
360379 mask = torch .zeros (len (self ._vocab ), dtype = torch .bool )
361380 return mask
362381
363- def _store_overapproximate_tokens (self , terminals : Iterable [str ], vocab : Iterable [str ]):
382+ def _compute_following_terminals_map (self , terminals : Iterable [str ], parse_table ) -> defaultdict :
383+ """
384+ From terminals, filter out terminals that cannot follow the current terminal
385+ according to the grammar.
386+
387+ If in the parsing table Action[cur_terminal, parser_state] = 'shift, new_parser_state' then next terminals
388+ are the terminals that are legal in new_parser_state.
389+ """
390+ following_terminals_map = defaultdict (set )
391+ terminals_set = set (terminals )
392+
393+ # We iterate through each cur_terminal:
394+ for cur_terminal in terminals :
395+ # We iterate through each parser_state:
396+ for _ , row in parse_table .states .items ():
397+ if cur_terminal in row :
398+ action = row [cur_terminal ]
399+ # -> If we see a shift action to new_parser_state
400+ if str (action [0 ]) == 'Shift' :
401+ new_parser_state = action [1 ]
402+ for next_terminal in parse_table .states [new_parser_state ]:
403+ # Lark parse_table stores non-terminals and terminals together
404+ if next_terminal in terminals_set :
405+ # -> -> we add the terminals that are legal in new_parser_state
406+ following_terminals_map [cur_terminal ].add (next_terminal )
407+
408+ return following_terminals_map
409+
410+
411+ def _store_overapproximate_tokens (self , terminals : Iterable [str ], vocab : Iterable [str ], followings_terminas_map : dict = None ):
364412 """
365413 Stores the overapproximate tokens for each dfa state and next terminals
366414 """
367415 all_dfa_states = self ._dfas .states ()
368416 pbar = tqdm (total = len (all_dfa_states ))
417+
369418 for dfa_state in all_dfa_states :
370419 for token_idx , token in enumerate (vocab ):
371420 is_special_token = token_idx in self .special_token_ids
@@ -375,12 +424,18 @@ def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterabl
375424 self ._lookup_table .dfa_state_and_next_terminal_to_tokens_add (
376425 dfa_state , '$END' , token_idx )
377426 else :
378- self ._process_regular_tokens (terminals , dfa_state , token_idx , token )
427+ if followings_terminas_map is not None and dfa_state .terminal in followings_terminas_map :
428+ following_terminals = followings_terminas_map [dfa_state .terminal ]
429+ else :
430+ following_terminals = terminals
431+
432+ self ._process_regular_tokens (following_terminals , dfa_state , token_idx , token )
379433
380434 pbar .update (1 )
381435
382436 def _process_regular_tokens (self , terminals , dfa_state , token_idx , token ):
383437 remainder = token .replace ('\t ' , ' ' )
438+
384439 is_valid , remainder = self ._dfas .consume_prefix (dfa_state , remainder )
385440 if is_valid :
386441 if remainder == '' :
@@ -455,12 +510,10 @@ def _lookup_next_tokens(self, dfa_states: Iterable[DFAState], r: ParseResult) ->
455510 elif len (accept_sequence ) == 2 :
456511 overapprox_token_ids |= self ._lookup_next_tokens_for_dfa_state (dfa_state , accept_sequence [1 ])
457512 elif len (accept_sequence ) == 3 :
458- # This is useful in under-approximating `grammar_strict` mode as they help improve the precision of SynCode
459- if self ._mode == 'grammar_strict' :
460- # If the DFA state is a final state we can jump to the start of next terminal
461- if self ._dfas .is_final (dfa_state ):
462- ignore_init_state = self ._dfas .initial (accept_sequence [1 ])
463- overapprox_token_ids |= self ._lookup_next_tokens_for_dfa_state (ignore_init_state , accept_sequence [2 ])
513+ # If the DFA state is a final state we can jump to the start of next terminal
514+ if self ._dfas .is_final (dfa_state ):
515+ ignore_init_state = self ._dfas .initial (accept_sequence [1 ])
516+ overapprox_token_ids |= self ._lookup_next_tokens_for_dfa_state (ignore_init_state , accept_sequence [2 ])
464517 else :
465518 raise ValueError (f"Invalid accept sequence: { accept_sequence } " )
466519 return overapprox_token_ids
0 commit comments