
Commit 36f4346

Add different logging system
1 parent 760b07f commit 36f4346

File tree

13 files changed (+501 additions, −359 deletions)

syncode/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
 from syncode.infer import Syncode
 from grammar_decoder import SyncodeLogitsProcessor
 from parsers.grammars import Grammar
+import common
+
+common.setup_logging()

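With this change, logging is configured once at import time via common.setup_logging(). A minimal usage sketch of controlling verbosity through the environment (the DEBUG value is only an example; INFO is the default):

    import os
    os.environ['LOG_LEVEL'] = 'DEBUG'   # must be set before the first import of syncode

    import syncode   # the import triggers common.setup_logging()
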
syncode/common.py

Lines changed: 42 additions & 0 deletions
@@ -1,4 +1,6 @@
+import logging
 import os
+import sys
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 

@@ -36,6 +38,46 @@ def get_output_path(model_name, grammar, dataset, num_samples, mode):
     os.makedirs(out_dir, exist_ok=True)
     return out_dir,out_path
 
+# Set up Python logging for the whole package
+def setup_logging(level=None):
+    """
+    Configure the root logger for both application and test usage.
+
+    This function is safe to call multiple times: existing handlers are
+    removed before the new one is added, so handlers never accumulate.
+
+    Args:
+        level: Override the logging level. If None, uses the LOG_LEVEL
+            environment variable or defaults to INFO.
+
+    Returns:
+        The root logger.
+    """
+    # Determine the logging level
+    if level is None:
+        # Get level from environment or default to INFO
+        level_name = os.environ.get('LOG_LEVEL', 'INFO')
+        level = getattr(logging, level_name.upper(), logging.INFO)
+
+    # Get the root logger
+    root_logger = logging.getLogger()
+
+    # Clear any existing handlers to avoid duplicates
+    for handler in root_logger.handlers[:]:
+        root_logger.removeHandler(handler)
+
+    # Set the logging level
+    root_logger.setLevel(level)
+
+    # Create a stdout handler
+    handler = logging.StreamHandler(sys.stdout)
+    formatter = logging.Formatter('[%(asctime)s-%(name)s] - %(message)s')
+    handler.setFormatter(formatter)
+    root_logger.addHandler(handler)
+
+    return root_logger
+
+
 class Logger:
     """
     Logger class for logging the output of the model

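For reference, a short sketch of how the new helper behaves (the logger name, message, and timestamp below are illustrative, not taken from the codebase):

    import logging
    from syncode import common

    common.setup_logging(level=logging.WARNING)   # explicit override beats LOG_LEVEL
    common.setup_logging()                        # calling again replaces the handler, no duplicates

    logger = logging.getLogger('syncode.example')  # hypothetical module logger
    logger.warning('mask store cache miss')
    # output: [2025-01-01 12:00:00,000-syncode.example] - mask store cache miss
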
syncode/grammar_decoder.py

Lines changed: 11 additions & 43 deletions
@@ -7,6 +7,9 @@
 from syncode.parsers import create_parser, create_base_parser
 from syncode.mask_store.mask_store import MaskStore
 from syncode.parsers.grammars import Grammar
+import logging
+logger = logging.getLogger(__name__)
+
 
 # Set to True for debugging
 DEBUG = True
@@ -18,15 +21,16 @@ class SyncodeLogitsProcessor(LogitsProcessor):
     Args:
         grammar (str): The grammar to use for parsing e.g. "python".
         tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
-        logger (common.Logger): The logger to use for logging.
         use_cache (bool, optional): Whether to use the cache. Defaults to True.
         parse_output_only (bool, optional): Whether to parse the prompt. Defaults to False.
+        num_samples (int, optional): The number of sequences to generate. Defaults to 1.
         dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
+        parser (str, optional): The parser to use. Defaults to 'lalr'.
+        mode (str, optional): The mode to use. Defaults to 'grammar_mask'.
     """
     def __init__(self,
                  grammar: Grammar,
                  tokenizer: PreTrainedTokenizer,
-                 logger: common.Logger=common.EmptyLogger(),
                  use_cache=True,
                  parse_output_only=True,
                  num_samples=1,
@@ -38,7 +42,6 @@ def __init__(self,
         self.byte_tokenizer = ByteTokenizer(tokenizer)
 
         self.grammar = grammar
-        self.logger = logger
         self.dev_mode = dev_mode
         self.batch_size = num_samples
         self.parse_failed = False
@@ -55,23 +58,17 @@ def __init__(self,
         self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
 
         # Create parser
-        self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)
+        self.inc_parser: IncrementalParser = create_parser(self.grammar, parser=parser, ignore_whitespace=self._ignore_whitespace)
 
         # Load dfa mask store
         self.dfa_mask_store = MaskStore.init_mask_store(
             grammar=self.grammar,
             tokenizer=self.tokenizer,
             use_cache=use_cache,
-            logger=self.logger,
             mode=mode,
-            parse_table=self.inc_parser.base_parser.parser.parser._parse_table,
         )
-
+
 
-    def _log_current_status(self, partial_code, r: ParseResult):
-        self.logger.log_code('Partial code', partial_code)
-        self.logger.log(repr(r))
-
     def _get_ignore_whitespace(self, grammar):
         """
         Check if the grammar allows whitespace tokens to be ignored.
@@ -158,11 +155,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
             res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
             if skip: continue
 
-            accept_mask = self.dfa_mask_store.get_accept_mask(res, logger=self.logger)
-
-            if DEBUG:
-                self._log_current_status(partial_code, res)
-                greedy_token = self.tokenizer.decode(scores[idx].argmax(dim=-1))
+            accept_mask = self.dfa_mask_store.get_accept_mask(res)
 
             if torch.sum(accept_mask) != 0: # If there are acceptable tokens for the current partial code
                 if len(scores[idx]) > len(accept_mask):
@@ -172,11 +165,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
                     accept_mask = accept_mask[: len(scores[idx])]
                 scores[idx] = scores[idx].masked_fill(~accept_mask.to(scores.device), -float("inf"))
             else: # Otherwise, report the error and mask no tokens
-                self.logger.log('No acceptable tokens for the current partial code!')
-                self._log_current_status(partial_code, res)
-
-                # For debugging - remove later
-                if DEBUG: self._debug_greedy(scores, idx, partial_code, res, greedy_token)
+                logger.debug('No acceptable tokens for the current partial code!')
+                logger.debug(repr(res))
 
         return scores
 
@@ -239,28 +229,6 @@ def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
             if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
                 self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
 
-    def _debug_greedy(self, scores, idx, partial_code, r, greedy_token):
-        greedy_grammar_token = self.tokenizer.decode(scores[idx].argmax(dim=-1))
-        if greedy_token != greedy_grammar_token:
-            self._log_greedy_difference(greedy_grammar_token, partial_code, r, greedy_token)
-
-    def _log_greedy_difference(self, greedy_grammar_token, partial_code, r, greedy_token):
-        self.logger.log_check(f"Greedy token and greedy grammar-based token do not match!")
-        self.logger.log(f"Greedy token: {repr(greedy_token)}")
-        self.logger.log(f"Greedy grammar-based token: {repr(greedy_grammar_token)}")
-        self._log_current_status(partial_code, r)
-
-    def print_debug(self):
-        print('-'*50)
-        print('Parsed terminals:')
-
-        name_to_pattern = {}
-        for term in self.inc_parser.base_parser.terminals:
-            name_to_pattern[term.name] = term.pattern
-
-        for token in self.inc_parser.parsed_lexer_tokens:
-            print(f"(type: {name_to_pattern[token.type]} | value: '{token.value}')")
-        print('-'*50)
 
     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:

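Since the injected common.Logger is gone, diagnostics flow through the standard logging hierarchy instead. A sketch of raising verbosity for this one module (plain standard-library behavior; the logger name assumes the module is imported as syncode.grammar_decoder, so match it to the module's actual __name__):

    import logging
    import syncode   # configures the root logger, INFO by default

    # Child loggers may be more verbose than the root: only this module emits DEBUG
    logging.getLogger('syncode.grammar_decoder').setLevel(logging.DEBUG)
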
syncode/infer.py

Lines changed: 0 additions & 4 deletions
@@ -56,8 +56,6 @@ class Syncode:
         new_mask_store (bool, optional): Use new DFA mask store. Defaults to False.
 
         dev_mode (bool, optional): Development mode. Defaults to False.
-
-        log_level (int, optional): Log level. Defaults to 2. 0 for no logs, 1 for minimal logs, 2 for all logs including time.
 
         opp (bool, optional): Whether to use opportunistic generation. Defaults to True.
     """
@@ -70,7 +68,6 @@ def __init__(
         grammar: Optional[str] = None,
         parse_output_only: bool = True,
         dev_mode: bool = False,
-        log_level: int = 1,
         new_mask_store: bool = False,
         parser: Literal["lr", "lalr"] = "lalr",
         seed: Optional[int] = None,
@@ -91,7 +88,6 @@ def __init__(
         self.num_samples = kwargs.get('num_return_sequences', 1)
         self.new_mask_store = new_mask_store
         self.parser = parser
-        self.log_level = log_level
 
         # Set seed
         if seed is not None:

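Because the log_level constructor argument is removed, callers now control verbosity through the LOG_LEVEL environment variable instead. A hypothetical migration sketch (the model argument shown is illustrative, not taken from this diff):

    import os

    # before: syn = Syncode(model='<model-name>', grammar='python', log_level=1)
    os.environ['LOG_LEVEL'] = 'INFO'   # set before importing syncode
    from syncode import Syncode
    syn = Syncode(model='<model-name>', grammar='python')
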
syncode/mask_store/fsm_set.py

Lines changed: 5 additions & 1 deletion
@@ -1,7 +1,9 @@
+import time
 import interegular
 from typing import Any, Optional, Tuple, Iterable, Dict
 from syncode.mask_store.byte_fsm import ByteFSM
-
+import logging
+logger = logging.getLogger(__name__)
 
 class JointFSMState:
     """
@@ -27,6 +29,7 @@ class FSMSet:
     Uses external ByteFSM for regex matching.
     """
     def __init__(self, terminals: Iterable['MockTerminalDef'], simplifications: Dict[str, str] = {}):
+        start_time = time.time()
         self._terminals_to_byte_fsm: Dict[str, ByteFSM] = {} # Store ByteFSM instances
         self.anything_else = interegular.fsm.anything_else
         self._simplifications: Dict[str, str] = simplifications
@@ -41,6 +44,7 @@ def __init__(self, terminals: Iterable['MockTerminalDef'], simplifications: Dict[str, str] = {}):
             # This handles the regex pattern matching
             byte_fsm = ByteFSM(terminal_regex)
             self._terminals_to_byte_fsm[terminal.name] = byte_fsm
+        logger.info(f"FSMs initialized in {time.time() - start_time:.2f} seconds")
 
     def states(self):
         """Returns all possible DFA states for all terminals."""

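The new timing line follows the common pattern of measuring a wall-clock interval and reporting it at INFO level. A self-contained sketch of the same pattern (the logger name and sleep are stand-ins):

    import logging, sys, time

    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format='[%(asctime)s-%(name)s] - %(message)s')
    logger = logging.getLogger('syncode.mask_store.fsm_set')

    start_time = time.time()
    time.sleep(0.1)   # stand-in for constructing the ByteFSMs
    logger.info(f"FSMs initialized in {time.time() - start_time:.2f} seconds")
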
syncode/mask_store/lookup_table.py

Lines changed: 50 additions & 17 deletions
@@ -4,22 +4,38 @@
 import regex
 from syncode.mask_store.mask_store import JointFSMState
 from syncode.parse_result import IndentationConstraint
-from typing import Any, Tuple, Iterable, Dict
+from typing import Any, Tuple, Iterable, Dict, Union
+import logging
+logger = logging.getLogger(__name__)
+
 
 class LookupTable:
     """
     Stores the overapproximate tokens
     """
-    def __init__(self, vocab: Iterable[str], special_token_ids: Iterable[int], indentation=False, mode='grammar_mask'):
+    def __init__(
+        self,
+        vocab: Iterable[str],
+        eos_token_id: int,
+        special_token_ids: Iterable[int],
+        indent=False,
+        mode='grammar_mask'
+    ):
         self._fsm_state_and_next_terminal_to_tokens: defaultdict = defaultdict(list)
         self._overapprox_lookup: Dict[JointFSMState, Any] = {}
         self._exact_lookup: dict = {}
         self._mode = mode
         self._vocab: Iterable[str] = vocab
-        self.indentation = indentation
+        self.indent = indent
+
+        # In the default mask, add all tokens that are special tokens except the EOS token
+        self._default_mask = torch.zeros(len(vocab), dtype=torch.bool)
+        for token_id in special_token_ids:
+            if token_id != eos_token_id:
+                self._default_mask[token_id] = 1
 
-        self._default_mask = self._get_default_mask(special_token_ids)
-        if indentation:
+        if indent:
+            logger.info("Indentation mode enabled")
             self._whitespace_tokens_map: defaultdict = defaultdict(list)
             self._indentation_to_tokens_map: defaultdict = defaultdict(list)
             self._create_indentation_to_tokens_map()
@@ -83,18 +99,14 @@ def convert_lookups_from_list_to_mask(self):
             self._exact_lookup[key] = self._list_to_mask(val)
 
         # TODO: move this logic to the lookup table
-        if self.indentation:
+        if self.indent:
            for key, val in self._whitespace_tokens_map.items():
                self._whitespace_tokens_map[key] = self._list_to_mask(val)
            for key, val in self._indentation_to_tokens_map.items():
                self._indentation_to_tokens_map[key] = self._list_to_mask(val)
 
-    def _get_default_mask(self, special_token_ids=None) -> torch.Tensor:
-        if special_token_ids is not None:
-            mask = torch.zeros(len(self._vocab), dtype=torch.bool)
-        else:
-            mask = copy.deepcopy(self._default_mask)
-        return mask
+    def _get_default_mask(self) -> torch.Tensor:
+        return copy.deepcopy(self._default_mask)
 
     def _create_indentation_to_tokens_map(self):
         """
@@ -107,15 +119,36 @@ def _create_indentation_to_tokens_map(self):
             else:
                 self._indentation_to_tokens_map[indent].append(token_idx)
 
-    def _get_indent_type(self, s: str) -> Tuple[bool, int]:
-        m = regex.match(r'[\t ]+', s, partial=True)
+    def _get_indent_type(self, s: Union[str, bytes]) -> Tuple[bool, int]:
+        """
+        Determine the indentation type and level from a string or bytes input.
+
+        Args:
+            s (Union[str, bytes]): The input string or bytes to analyze
+
+        Returns:
+            Tuple[bool, int]: A tuple containing:
+                - bool: Whether the input is entirely whitespace
+                - int: The indentation level (spaces + 4*tabs)
+        """
+        # Convert bytes to string if needed
+        if isinstance(s, bytes):
+            try:
+                s_str = s.decode('utf-8')
+            except UnicodeDecodeError:
+                # Handle decode errors by returning default values
+                return False, 0
+        else:
+            s_str = s
+
+        m = regex.match(r'[\t ]+', s_str, partial=True)
         full_match = False
         if m != None:
             start, end = m.start(), m.end()
-            if end == len(s):
+            if end == len(s_str):
                 full_match = True
-            return full_match, s[start: end].count(' ') + 4*s[start: end].count('\t')
-        return False, 0
+            return full_match, s_str[start: end].count(' ') + 4*s_str[start: end].count('\t')
+        return False, 0
 
     def get_indentation_tokens(self, indent_constraint: IndentationConstraint, get_list=False):
         """

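To make the (is_whitespace, level) contract concrete, here is a standalone sketch of the same computation (a rewrite for illustration, not an import of the private method):

    import regex   # third-party 'regex' package, same as used above

    def indent_level(s: str):
        """Mirror of _get_indent_type for str input: (entirely_whitespace, spaces + 4*tabs)."""
        m = regex.match(r'[\t ]+', s, partial=True)
        if m is not None:
            start, end = m.start(), m.end()
            prefix = s[start:end]
            return end == len(s), prefix.count(' ') + 4 * prefix.count('\t')
        return False, 0

    print(indent_level('    '))       # (True, 4): four spaces, nothing else
    print(indent_level('\tx = 1'))    # (False, 4): one tab, then code
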