Skip to content

Commit 5ac1833

Browse files
committed
Compress mask store by only storing viable terminal sequences
1 parent 0388999 commit 5ac1833

File tree

2 files changed

+73
-19
lines changed

2 files changed

+73
-19
lines changed

syncode/dfa_mask_store.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
22
import copy, os, pickle
3+
import time
34
import interegular
45
import torch
56
import regex
@@ -155,7 +156,6 @@ def incomplete_case_lookup(self, dfa_state: DFAState) -> Any:
155156
if dfa_state in self._exact_lookup:
156157
return self._exact_lookup[dfa_state]
157158
else:
158-
print(f"Warning: Exact lookup not found for {dfa_state} in the DFA mask store. This could be an error.", flush=True)
159159
return self._overapprox_lookup[dfa_state]
160160
raise ValueError(f"Invalid mode: {self._mode}")
161161

@@ -167,9 +167,6 @@ def store_overapprox_lookup(self, dfa_state: DFAState, mask: torch.Tensor):
167167

168168
def complete_case_lookup(self, dfa_state: DFAState) -> Any:
169169
assert isinstance(dfa_state, DFAState)
170-
if dfa_state not in self._exact_lookup:
171-
# FIXME: This is a bit strange and needs to be checked more carefully
172-
return self._overapprox_lookup[dfa_state]
173170
return self._exact_lookup[dfa_state]
174171

175172
def add_exact_lookup(self, dfa_state: DFAState, token):
@@ -294,7 +291,8 @@ def __init__(self,
294291
special_token_ids: Iterable=[],
295292
indentation: bool=True,
296293
mode='grammar_strict', # 'grammar_strict' or 'grammar_mask'
297-
ignore_terminals: Iterable[str]=[]
294+
ignore_terminals: Iterable[str]=[],
295+
parse_table=None
298296
):
299297
self._vocab = vocab
300298
self.special_token_ids = special_token_ids
@@ -308,7 +306,11 @@ def __init__(self,
308306
# Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
309307
self._lookup_table = LookupTable(vocab, special_token_ids, indentation=indentation, mode=mode)
310308
terminal_names = [terminal.name for terminal in terminals]
311-
self._store_overapproximate_tokens(terminal_names, vocab)
309+
310+
followings_terminas_map = None
311+
if parse_table is not None:
312+
followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table)
313+
self._store_overapproximate_tokens(terminal_names, vocab, followings_terminas_map)
312314

313315
self.indentation = indentation
314316

@@ -325,7 +327,14 @@ def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_termina
325327
return ignore_whitespace
326328

327329
@staticmethod
328-
def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=common.EmptyLogger(), mode='grammar_strict'):
330+
def load_dfa_mask_store(
331+
grammar: Grammar,
332+
tokenizer,
333+
use_cache=True,
334+
logger=None,
335+
mode='grammar_strict',
336+
parse_table=None
337+
):
329338
'''
330339
Loads the dfa for the given language and tokenizer. If the dfa is not cached, it is created and cached.
331340
'''
@@ -351,7 +360,17 @@ def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=comm
351360
simplifications = grammar.simplifications()
352361
os.makedirs(dfa_dir, exist_ok=True)
353362

354-
mask_store = DFAMaskStore(base_parser.terminals, vocab, simplifications=simplifications, special_token_ids=[tokenizer.eos_token_id], mode=mode, ignore_terminals=base_parser.ignore_tokens)
363+
start_time = time.time()
364+
mask_store = DFAMaskStore(
365+
base_parser.terminals,
366+
vocab,
367+
simplifications=simplifications,
368+
special_token_ids=[tokenizer.eos_token_id],
369+
mode=mode,
370+
ignore_terminals=base_parser.ignore_tokens,
371+
parse_table=parse_table
372+
)
373+
print(f"Time taken to create DFA mask store: {time.time() - start_time} seconds", flush=True)
355374

356375
pickle.dump(mask_store, open(dfa_path, 'wb'))
357376
return mask_store
@@ -360,12 +379,42 @@ def _get_default_mask(self) -> torch.Tensor:
360379
mask = torch.zeros(len(self._vocab), dtype=torch.bool)
361380
return mask
362381

363-
def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str]):
382+
def _compute_following_terminals_map(self, terminals: Iterable[str], parse_table) -> defaultdict:
383+
"""
384+
From terminals, filter out terminals that cannot follow the current terminal
385+
according to the grammar.
386+
387+
If in the parsing table Action[cur_terminal, parser_state] = 'shift, new_parser_state' then next terminals
388+
are the terminals that are legal in new_parser_state.
389+
"""
390+
following_terminals_map = defaultdict(set)
391+
terminals_set = set(terminals)
392+
393+
# We iterate through each cur_terminal:
394+
for cur_terminal in terminals:
395+
# We iterate through each parser_state:
396+
for _, row in parse_table.states.items():
397+
if cur_terminal in row:
398+
action = row[cur_terminal]
399+
# -> If we see a shift action to new_parser_state
400+
if str(action[0]) == 'Shift':
401+
new_parser_state = action[1]
402+
for next_terminal in parse_table.states[new_parser_state]:
403+
# Lark parse_table stores non-terminals and terminals together
404+
if next_terminal in terminals_set:
405+
# -> -> we add the terminals that are legal in new_parser_state
406+
following_terminals_map[cur_terminal].add(next_terminal)
407+
408+
return following_terminals_map
409+
410+
411+
def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str], followings_terminas_map: dict=None):
364412
"""
365413
Stores the overapproximate tokens for each dfa state and next terminals
366414
"""
367415
all_dfa_states = self._dfas.states()
368416
pbar = tqdm(total=len(all_dfa_states))
417+
369418
for dfa_state in all_dfa_states:
370419
for token_idx, token in enumerate(vocab):
371420
is_special_token = token_idx in self.special_token_ids
@@ -375,12 +424,18 @@ def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterabl
375424
self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(
376425
dfa_state, '$END', token_idx)
377426
else:
378-
self._process_regular_tokens(terminals, dfa_state, token_idx, token)
427+
if followings_terminas_map is not None and dfa_state.terminal in followings_terminas_map:
428+
following_terminals = followings_terminas_map[dfa_state.terminal]
429+
else:
430+
following_terminals = terminals
431+
432+
self._process_regular_tokens(following_terminals, dfa_state, token_idx, token)
379433

380434
pbar.update(1)
381435

382436
def _process_regular_tokens(self, terminals, dfa_state, token_idx, token):
383437
remainder = token.replace('\t', ' ')
438+
384439
is_valid, remainder = self._dfas.consume_prefix(dfa_state, remainder)
385440
if is_valid:
386441
if remainder == '':
@@ -455,12 +510,10 @@ def _lookup_next_tokens(self, dfa_states: Iterable[DFAState], r: ParseResult) ->
455510
elif len(accept_sequence) == 2:
456511
overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(dfa_state, accept_sequence[1])
457512
elif len(accept_sequence) == 3:
458-
# This is useful in under-approximating `grammar_strict` mode as they help improve the precision of SynCode
459-
if self._mode == 'grammar_strict':
460-
# If the DFA state is a final state we can jump to the start of next terminal
461-
if self._dfas.is_final(dfa_state):
462-
ignore_init_state = self._dfas.initial(accept_sequence[1])
463-
overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(ignore_init_state, accept_sequence[2])
513+
# If the DFA state is a final state we can jump to the start of next terminal
514+
if self._dfas.is_final(dfa_state):
515+
ignore_init_state = self._dfas.initial(accept_sequence[1])
516+
overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(ignore_init_state, accept_sequence[2])
464517
else:
465518
raise ValueError(f"Invalid accept sequence: {accept_sequence}")
466519
return overapprox_token_ids

syncode/grammar_decoder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,19 @@ def __init__(self,
4949
# Ignore whitespace tokens
5050
self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
5151

52+
# Create parsers
53+
self.inc_parser: Iterator[IncrementalParser] = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)
54+
5255
# Load dfa mask store
5356
self.dfa_mask_store = DFAMaskStore.load_dfa_mask_store(
5457
grammar=self.grammar,
5558
tokenizer=self.tokenizer,
5659
use_cache=use_cache,
5760
logger=self.logger,
5861
mode=mode,
62+
parse_table=self.inc_parser.base_parser.parser.parser._parse_table
5963
)
6064

61-
# Create parsers
62-
self.inc_parser: Iterator[IncrementalParser] = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)
63-
6465

6566
def _log_current_status(self, partial_code, r: ParseResult):
6667
self.logger.log_code('Partial code', partial_code)

0 commit comments

Comments
 (0)