# Standard library
import logging

# Project imports
from syncode.parsers import create_parser, create_base_parser
from syncode.mask_store.mask_store import MaskStore
from syncode.parsers.grammars import Grammar

# Module-level logger, keyed to this module's dotted path.
logger = logging.getLogger(__name__)

# Set to True for debugging
DEBUG = True
@@ -18,15 +21,16 @@ class SyncodeLogitsProcessor(LogitsProcessor):
1821 Args:
1922 grammar (str): The grammar to use for parsing e.g. "python".
2023 tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
21- logger (common.Logger): The logger to use for logging.
2224 use_cache (bool, optional): Whether to use the cache. Defaults to True.
2325 parse_output_only (bool, optional): Whether to parse the prompt. Defaults to False.
26+ num_samples (int, optional): The number of sequences to generate. Defaults to 1.
2427 dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
28+ parser (str, optional): The parser to use. Defaults to 'lalr'.
29+ mode (str, optional): The mode to use. Defaults to 'grammar_mask'.
2530 """
2631 def __init__ (self ,
2732 grammar : Grammar ,
2833 tokenizer : PreTrainedTokenizer ,
29- logger : common .Logger = common .EmptyLogger (),
3034 use_cache = True ,
3135 parse_output_only = True ,
3236 num_samples = 1 ,
@@ -38,7 +42,6 @@ def __init__(self,
3842 self .byte_tokenizer = ByteTokenizer (tokenizer )
3943
4044 self .grammar = grammar
41- self .logger = logger
4245 self .dev_mode = dev_mode
4346 self .batch_size = num_samples
4447 self .parse_failed = False
@@ -55,23 +58,17 @@ def __init__(self,
5558 self ._ignore_whitespace = self ._get_ignore_whitespace (self .grammar )
5659
5760 # Create parser
58- self .inc_parser : IncrementalParser = create_parser (self .grammar , logger = self . logger , parser = parser , ignore_whitespace = self ._ignore_whitespace )
61+ self .inc_parser : IncrementalParser = create_parser (self .grammar , parser = parser , ignore_whitespace = self ._ignore_whitespace )
5962
6063 # Load dfa mask store
6164 self .dfa_mask_store = MaskStore .init_mask_store (
6265 grammar = self .grammar ,
6366 tokenizer = self .tokenizer ,
6467 use_cache = use_cache ,
65- logger = self .logger ,
6668 mode = mode ,
67- parse_table = self .inc_parser .base_parser .parser .parser ._parse_table ,
6869 )
69-
70+
7071
71- def _log_current_status (self , partial_code , r : ParseResult ):
72- self .logger .log_code ('Partial code' , partial_code )
73- self .logger .log (repr (r ))
74-
7572 def _get_ignore_whitespace (self , grammar ):
7673 """
7774 Check if the grammar allows whitespace tokens to be ignored.
@@ -158,11 +155,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
158155 res , skip = self ._parse_partial_code (idx , partial_code , remainder_bytes , accepted_generation = True )
159156 if skip : continue
160157
161- accept_mask = self .dfa_mask_store .get_accept_mask (res , logger = self .logger )
162-
163- if DEBUG :
164- self ._log_current_status (partial_code , res )
165- greedy_token = self .tokenizer .decode (scores [idx ].argmax (dim = - 1 ))
158+ accept_mask = self .dfa_mask_store .get_accept_mask (res )
166159
167160 if torch .sum (accept_mask ) != 0 : # If there are acceptable tokens for the current partial code
168161 if len (scores [idx ]) > len (accept_mask ):
@@ -172,11 +165,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
172165 accept_mask = accept_mask [: len (scores [idx ])]
173166 scores [idx ] = scores [idx ].masked_fill (~ accept_mask .to (scores .device ), - float ("inf" ))
174167 else : # Otherwise, report the error and mask no tokens
175- self .logger .log ('No acceptable tokens for the current partial code!' )
176- self ._log_current_status (partial_code , res )
177-
178- # For debugging - remove later
179- if DEBUG : self ._debug_greedy (scores , idx , partial_code , res , greedy_token )
168+ logger .debug ('No acceptable tokens for the current partial code!' )
169+ logger .debug (repr (res ))
180170
181171 return scores
182172
@@ -239,28 +229,6 @@ def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
239229 if accept_seq [0 ] == '$END' or accept_seq [0 ] == 'EOF' :
240230 self .last_valid_state [idx ] = len (partial_code ) - len (r .remainder )
241231
242- def _debug_greedy (self , scores , idx , partial_code , r , greedy_token ):
243- greedy_grammar_token = self .tokenizer .decode (scores [idx ].argmax (dim = - 1 ))
244- if greedy_token != greedy_grammar_token :
245- self ._log_greedy_difference (greedy_grammar_token , partial_code , r , greedy_token )
246-
247- def _log_greedy_difference (self , greedy_grammar_token , partial_code , r , greedy_token ):
248- self .logger .log_check (f"Greedy token and greedy grammar-based token do not match!" )
249- self .logger .log (f"Greedy token: { repr (greedy_token )} " )
250- self .logger .log (f"Greedy grammar-based token: { repr (greedy_grammar_token )} " )
251- self ._log_current_status (partial_code , r )
252-
253- def print_debug (self ):
254- print ('-' * 50 )
255- print ('Parsed terminals:' )
256-
257- name_to_pattern = {}
258- for term in self .inc_parser .base_parser .terminals :
259- name_to_pattern [term .name ] = term .pattern
260-
261- for token in self .inc_parser .parsed_lexer_tokens :
262- print (f"(type: { name_to_pattern [token .type ]} | value: '{ token .value } ')" )
263- print ('-' * 50 )
264232
265233 @staticmethod
266234 def _bytes_to_string (byte_sequence : bytes ) -> tuple [str , bytes ]:
0 commit comments