@@ -1,5 +1,6 @@
 from collections import defaultdict
 import copy, os, pickle
+import time
 import interegular
 import torch
 import regex
@@ -290,7 +291,8 @@ def __init__(self,
             special_token_ids: Iterable = [],
             indentation: bool = True,
             mode='grammar_strict',  # 'grammar_strict' or 'grammar_mask'
-            ignore_terminals: Iterable[str] = []
+            ignore_terminals: Iterable[str] = [],
+            parse_table=None
             ):
         self._vocab = vocab
         self.special_token_ids = special_token_ids
@@ -304,7 +306,11 @@ def __init__(self,
         # Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
         self._lookup_table = LookupTable(vocab, special_token_ids, indentation=indentation, mode=mode)
         terminal_names = [terminal.name for terminal in terminals]
-        self._store_overapproximate_tokens(terminal_names, vocab)
+
+        following_terminals_map = None
+        if parse_table is not None:
+            following_terminals_map = self._compute_following_terminals_map(terminal_names, parse_table)
+        self._store_overapproximate_tokens(terminal_names, vocab, following_terminals_map)

         self.indentation = indentation

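The `parse_table` passed here is expected to be a Lark LALR parse table, i.e. an object whose `.states` attribute maps parser states to `{symbol: action}` rows (see `_compute_following_terminals_map` below). A minimal sketch of how such a table might be obtained, assuming a Lark parser built with `parser='lalr'`; the attribute path into Lark's internals is an assumption and may differ across Lark versions:

```python
from lark import Lark

# Hypothetical toy grammar; any LALR grammar works the same way.
grammar = r"""
    start: "(" NAME ")"
    NAME: /[a-z]+/
"""

lark_parser = Lark(grammar, parser='lalr')

# ASSUMPTION: this path into Lark internals may vary by version; it exposes a
# ParseTable whose `.states` dict maps state ids to {symbol: action} rows.
parse_table = lark_parser.parser.parser._parse_table
```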
@@ -321,7 +327,14 @@ def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_termina
         return ignore_whitespace

     @staticmethod
-    def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None, mode='grammar_strict'):
+    def load_dfa_mask_store(
+            grammar: Grammar,
+            tokenizer,
+            use_cache=True,
+            logger=None,
+            mode='grammar_strict',
+            parse_table=None
+    ):
         '''
         Loads the dfa for the given language and tokenizer. If the dfa is not cached, it is created and cached.
         '''
@@ -347,7 +360,17 @@ def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None
         simplifications = grammar.simplifications()
         os.makedirs(dfa_dir, exist_ok=True)

-        mask_store = DFAMaskStore(base_parser.terminals, vocab, simplifications=simplifications, special_token_ids=[tokenizer.eos_token_id], mode=mode, ignore_terminals=base_parser.ignore_tokens)
+        start_time = time.time()
+        mask_store = DFAMaskStore(
+            base_parser.terminals,
+            vocab,
+            simplifications=simplifications,
+            special_token_ids=[tokenizer.eos_token_id],
+            mode=mode,
+            ignore_terminals=base_parser.ignore_tokens,
+            parse_table=parse_table
+        )
+        print(f"Time taken to create DFA mask store: {time.time() - start_time} seconds", flush=True)

         pickle.dump(mask_store, open(dfa_path, 'wb'))
         return mask_store
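For illustration, a hedged sketch of calling the updated entry point; the tokenizer name and the `Grammar('python')` argument are placeholders, and passing `parse_table=None` preserves the old behaviour in which every terminal is treated as a possible follower:

```python
from transformers import AutoTokenizer

# Hypothetical usage; Grammar and DFAMaskStore come from this module's package.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

mask_store = DFAMaskStore.load_dfa_mask_store(
    grammar=Grammar('python'),   # placeholder grammar name
    tokenizer=tokenizer,
    use_cache=False,             # skip the pickle cache while experimenting
    mode='grammar_strict',
    parse_table=None,            # None -> old, unrestricted behaviour
)
```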
@@ -356,12 +379,42 @@ def _get_default_mask(self) -> torch.Tensor:
         mask = torch.zeros(len(self._vocab), dtype=torch.bool)
         return mask

-    def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str]):
+    def _compute_following_terminals_map(self, terminals: Iterable[str], parse_table) -> defaultdict:
+        """
+        From terminals, filter out the terminals that cannot follow the current terminal
+        according to the grammar.
+
+        If the parse table has Action[parser_state, cur_terminal] = (Shift, new_parser_state),
+        then the terminals that can follow cur_terminal are those that are legal in new_parser_state.
+        """
+        following_terminals_map = defaultdict(set)
+        terminals_set = set(terminals)
+
+        # Iterate through each cur_terminal:
+        for cur_terminal in terminals:
+            # Iterate through each parser_state:
+            for row in parse_table.states.values():
+                if cur_terminal in row:
+                    action = row[cur_terminal]
+                    # A shift action on cur_terminal leads to new_parser_state
+                    if str(action[0]) == 'Shift':
+                        new_parser_state = action[1]
+                        for next_terminal in parse_table.states[new_parser_state]:
+                            # Lark's parse table stores non-terminals and terminals together
+                            if next_terminal in terminals_set:
+                                # Add the terminals that are legal in new_parser_state
+                                following_terminals_map[cur_terminal].add(next_terminal)
+
+        return following_terminals_map
+
+
+    def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str], following_terminals_map: dict = None):
         """
         Stores the overapproximate tokens for each dfa state and next terminals
         """
         all_dfa_states = self._dfas.states()
         pbar = tqdm(total=len(all_dfa_states))
+
         for dfa_state in all_dfa_states:
             for token_idx, token in enumerate(vocab):
                 is_special_token = token_idx in self.special_token_ids
@@ -371,12 +424,18 @@ def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterabl
                     self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(
                         dfa_state, '$END', token_idx)
                 else:
-                    self._process_regular_tokens(terminals, dfa_state, token_idx, token)
+                    # Restrict candidates to the terminals that can follow the
+                    # current terminal; fall back to all terminals otherwise.
+                    if following_terminals_map is not None and dfa_state.terminal in following_terminals_map:
+                        following_terminals = following_terminals_map[dfa_state.terminal]
+                    else:
+                        following_terminals = terminals
+
+                    self._process_regular_tokens(following_terminals, dfa_state, token_idx, token)

             pbar.update(1)

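The gating above degrades gracefully: a DFA state whose `terminal` has no entry in the map simply keeps the full terminal list. A tiny illustration with hypothetical values:

```python
# Hypothetical map and terminals; mirrors the fallback in the diff above.
following_terminals_map = {'NAME': {'RPAR'}}
terminals = ['LPAR', 'RPAR', 'NAME']

for cur in ['NAME', 'LPAR']:
    allowed = following_terminals_map.get(cur, terminals)
    print(cur, '->', allowed)
# NAME -> {'RPAR'}
# LPAR -> ['LPAR', 'RPAR', 'NAME']   (no entry, so all terminals are kept)
```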
     def _process_regular_tokens(self, terminals, dfa_state, token_idx, token):
         remainder = token.replace('\t', ' ')
+
         is_valid, remainder = self._dfas.consume_prefix(dfa_state, remainder)
         if is_valid:
             if remainder == '':