Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions notebooks/tests/debug_grammar.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 11.27it/s]\n"
]
}
],
"source": [
"from syncode.infer import Syncode\n",
"\n",
"# Load the unconstrained original model\n",
"model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
"\n",
"grammar = r\"\"\" \n",
" start: instruction\n",
" instruction: \"Press the \" button \" button\"\n",
" button: \"power\" | \"volume up\" | \"volume down\" | \"home\" | \"back\" | \"recent apps\" | \"menu\" | \"search\"\n",
" \"\"\"\n",
"\n",
"syn_llm = Syncode(model=model_name, grammar=grammar)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------\n",
"Parsed terminals:\n",
"(type: 'Press\\\\ the\\\\ ' | value: 'Press the ')\n",
"(type: 'back' | value: 'back')\n",
"(type: '\\\\ button' | value: ' button')\n",
"--------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"['Press the back button']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompt = \"How do I go back?\"\n",
"syn_llm.infer(prompt, debug=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "codex",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
13 changes: 12 additions & 1 deletion syncode/grammar_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,15 @@ def _log_greedy_difference(self, greedy_grammar_token, partial_code, r, greedy_t
self.logger.log(f"Greedy token: {repr(greedy_token)}")
self.logger.log(f"Greedy grammar-based token: {repr(greedy_grammar_token)}")
self._log_current_status(partial_code, r)


def print_debug(self):
print('-'*50)
print('Parsed terminals:')

name_to_pattern = {}
for term in self.inc_parser.base_parser.terminals:
name_to_pattern[term.name] = term.pattern

for token in self.inc_parser.parsed_lexer_tokens:
print(f"(type: {name_to_pattern[token.type]} | value: '{token.value}')")
print('-'*50)
18 changes: 12 additions & 6 deletions syncode/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from syncode.evaluation.fol_eval import FOLEval


def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, **kwargs):
def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, debug=False, **kwargs):

syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs)

if dataset == "input":
syncode.infer()
syncode.infer(debug=debug)
else:
# Setup output directory and logger
num_samples = kwargs.get('num_return_sequences', 1)
Expand Down Expand Up @@ -146,8 +146,8 @@ def __init__(
def is_grammar_mode(self):
return self.mode == 'grammar_mask' or self.mode == 'grammar_strict'

def infer(self, prompt=None, stop_words=None):
output = self.user_input(prompt, stop_words=stop_words)
def infer(self, prompt=None, stop_words=None, debug=False):
output = self.user_input(prompt, stop_words=stop_words, debug=debug)
return output

def evaluate(
Expand Down Expand Up @@ -196,17 +196,23 @@ def evaluate(
logger.close()
return output

def user_input(self, prompt:str, stop_words=None):
def user_input(self, prompt:str, stop_words=None, debug=False):
"""
Run user input on the model with grammar mask

Args:
prompt (str): User input prompt
stop_words (list, optional): Stop words to use. Defaults to None.
debug (bool, optional): Debug mode. Defaults to False.
"""
if prompt:
if self.grammar_decoder is not None: # TODO: Remove this check
self.grammar_decoder.reset(prompt)

if self.chat_mode:
return self.model.generate_chat_completion_grammar(prompt)
else:
return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words)
return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words, debug=debug)
else:
while True:
prompt = input('Enter prompt: ')
Expand Down
27 changes: 23 additions & 4 deletions syncode/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,23 @@ def get_grammar_decoder(self):
return None

@torch.inference_mode()
def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None, return_token_ids=False) -> Iterable[str]:
def generate_batch_completion_grammar(
self,
prompt,
batch_size,
stop_words=None,
return_token_ids=False,
debug=False
) -> Iterable[str]:
'''
Generates batch_size completions for the given prompt.

Args:
prompt (str): The prompt for which completions are generated.
batch_size (int): The number of completions to generate.
stop_words (list): A list of stop words. If the completion ends with any of the stop words, the completion is returned.
return_token_ids (bool): If True, returns the token ids of the completions.
debug (bool): If True, prints debug information.
'''
# Reset the grammar decoder
if self.grammar_decoder is not None:
Expand Down Expand Up @@ -99,7 +113,8 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None,
gen_config,
gen_mode,
grammar_decoder=self.grammar_decoder,
stop_criteria=stop_criteria
stop_criteria=stop_criteria,
debug=debug
)
else:
if self.opp:
Expand Down Expand Up @@ -137,7 +152,8 @@ def _generate(
gen_config:GenerationConfig,
gen_mode:GenerationMode,
grammar_decoder:SyncodeLogitsProcessor=None,
stop_criteria:StoppingCriteria=[]
stop_criteria:StoppingCriteria=[],
debug=False
):
"""
We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library.
Expand Down Expand Up @@ -192,7 +208,10 @@ def _generate(

# Update attention mask
attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1)


if debug:
grammar_decoder.print_debug()

return token_ids

def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scores):
Expand Down
23 changes: 23 additions & 0 deletions syncode/parsers/grammars/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@ LR(1) is more powerful in terms of representing certain syntax, however common f
In most cases, it should be possible to fix these errors by rewriting some of the grammar rules.
However, in some rare cases it is possible that it is impossible to represent the grammar as LR(1)

### Debugging Grammars

SynCode provides a flag `--debug` to help debug grammars. This flag will print out the parsed terminals and their corresponding values in generation.
(Refer to [this](../../../notebooks/tests/debug_grammar.ipynb) notebook for code example)

For example, consider the following grammar:
```ebnf
start: "foo" "(" ident ")" ";"
ident: [a-z]+
```

Given the output `foo(abc);`, the debug flag will print:
```
--------------------------------------------------
Parsed terminals:
(type: 'FOO', value: 'foo')
(type: 'LPAR', value: '(')
(type: 'IDENT', value: 'abc')
(type: 'RPAR', value: ')')
(type: 'SEMI', value: ';')
--------------------------------------------------
```

### Lexer Ambiguity
When defining a grammar, be cautious of lexer ambiguities that arise when one terminal is a substring of another.
In some cases, these ambiguities can lead to unexpected behavior in the parser.
Expand Down
Loading