Skip to content

Commit 23d608e

Browse files
committed
Compress mask store by only storing viable terminal sequences
1 parent e42a0f6 commit 23d608e

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

syncode/dfa_mask_store.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
22
import copy, os, pickle
3+
import time
34
import interegular
45
import torch
56
import regex
@@ -290,7 +291,8 @@ def __init__(self,
290291
special_token_ids: Iterable=[],
291292
indentation: bool=True,
292293
mode='grammar_strict', # 'grammar_strict' or 'grammar_mask'
293-
ignore_terminals: Iterable[str]=[]
294+
ignore_terminals: Iterable[str]=[],
295+
parse_table=None
294296
):
295297
self._vocab = vocab
296298
self.special_token_ids = special_token_ids
@@ -304,7 +306,11 @@ def __init__(self,
304306
# Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
305307
self._lookup_table = LookupTable(vocab, special_token_ids, indentation=indentation, mode=mode)
306308
terminal_names = [terminal.name for terminal in terminals]
307-
self._store_overapproximate_tokens(terminal_names, vocab)
309+
310+
followings_terminas_map = None
311+
if parse_table is not None:
312+
followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table)
313+
self._store_overapproximate_tokens(terminal_names, vocab, followings_terminas_map)
308314

309315
self.indentation = indentation
310316

@@ -321,7 +327,14 @@ def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_termina
321327
return ignore_whitespace
322328

323329
@staticmethod
324-
def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None, mode='grammar_strict'):
330+
def load_dfa_mask_store(
331+
grammar: Grammar,
332+
tokenizer,
333+
use_cache=True,
334+
logger=None,
335+
mode='grammar_strict',
336+
parse_table=None
337+
):
325338
'''
326339
Loads the dfa for the given language and tokenizer. If the dfa is not cached, it is created and cached.
327340
'''
@@ -347,7 +360,17 @@ def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None
347360
simplifications = grammar.simplifications()
348361
os.makedirs(dfa_dir, exist_ok=True)
349362

350-
mask_store = DFAMaskStore(base_parser.terminals, vocab, simplifications=simplifications, special_token_ids=[tokenizer.eos_token_id], mode=mode, ignore_terminals=base_parser.ignore_tokens)
363+
start_time = time.time()
364+
mask_store = DFAMaskStore(
365+
base_parser.terminals,
366+
vocab,
367+
simplifications=simplifications,
368+
special_token_ids=[tokenizer.eos_token_id],
369+
mode=mode,
370+
ignore_terminals=base_parser.ignore_tokens,
371+
parse_table=parse_table
372+
)
373+
print(f"Time taken to create DFA mask store: {time.time() - start_time} seconds", flush=True)
351374

352375
pickle.dump(mask_store, open(dfa_path, 'wb'))
353376
return mask_store
@@ -356,12 +379,42 @@ def _get_default_mask(self) -> torch.Tensor:
356379
mask = torch.zeros(len(self._vocab), dtype=torch.bool)
357380
return mask
358381

359-
def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str]):
382+
def _compute_following_terminals_map(self, terminals: Iterable[str], parse_table) -> defaultdict:
383+
"""
384+
From terminals, filter out terminals that cannot follow the current terminal
385+
according to the grammar.
386+
387+
If in the parsing table Action[cur_terminal, parser_state] = 'shift, new_parser_state' then next terminals
388+
are the terminals that are legal in new_parser_state.
389+
"""
390+
following_terminals_map = defaultdict(set)
391+
terminals_set = set(terminals)
392+
393+
# We iterate through each cur_terminal:
394+
for cur_terminal in terminals:
395+
# We iterate through each parser_state:
396+
for _, row in parse_table.states.items():
397+
if cur_terminal in row:
398+
action = row[cur_terminal]
399+
# -> If we see a shift action to new_parser_state
400+
if str(action[0]) == 'Shift':
401+
new_parser_state = action[1]
402+
for next_terminal in parse_table.states[new_parser_state]:
403+
# Lark parse_table stores non-terminals and terminals together
404+
if next_terminal in terminals_set:
405+
# -> -> we add the terminals that are legal in new_parser_state
406+
following_terminals_map[cur_terminal].add(next_terminal)
407+
408+
return following_terminals_map
409+
410+
411+
def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str], followings_terminas_map: dict=None):
360412
"""
361413
Stores the overapproximate tokens for each dfa state and next terminals
362414
"""
363415
all_dfa_states = self._dfas.states()
364416
pbar = tqdm(total=len(all_dfa_states))
417+
365418
for dfa_state in all_dfa_states:
366419
for token_idx, token in enumerate(vocab):
367420
is_special_token = token_idx in self.special_token_ids
@@ -371,12 +424,18 @@ def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterabl
371424
self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(
372425
dfa_state, '$END', token_idx)
373426
else:
374-
self._process_regular_tokens(terminals, dfa_state, token_idx, token)
427+
if followings_terminas_map is not None and dfa_state.terminal in followings_terminas_map:
428+
following_terminals = followings_terminas_map[dfa_state.terminal]
429+
else:
430+
following_terminals = terminals
431+
432+
self._process_regular_tokens(following_terminals, dfa_state, token_idx, token)
375433

376434
pbar.update(1)
377435

378436
def _process_regular_tokens(self, terminals, dfa_state, token_idx, token):
379437
remainder = token.replace('\t', ' ')
438+
380439
is_valid, remainder = self._dfas.consume_prefix(dfa_state, remainder)
381440
if is_valid:
382441
if remainder == '':

syncode/grammar_decoder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,19 @@ def __init__(self,
5151
# Ignore whitespace tokens
5252
self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
5353

54+
# Create parser
55+
self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)
56+
5457
# Load dfa mask store
5558
self.dfa_mask_store = DFAMaskStore.load_dfa_mask_store(
5659
grammar=self.grammar,
5760
tokenizer=self.tokenizer,
5861
use_cache=use_cache,
5962
logger=self.logger,
6063
mode=mode,
64+
parse_table=self.inc_parser.base_parser.parser.parser._parse_table,
6165
)
6266

63-
# Create parser
64-
self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)
65-
6667

6768
def _log_current_status(self, partial_code, r: ParseResult):
6869
self.logger.log_code('Partial code', partial_code)

0 commit comments

Comments
 (0)