From 9c2a2442c7cb16ac69577894b2aa9b9aa70c462b Mon Sep 17 00:00:00 2001 From: Shubham Ugare Date: Mon, 2 Dec 2024 22:10:15 -0600 Subject: [PATCH] Add a debug flag --- notebooks/tests/debug_grammar.ipynb | 88 +++++++++++++++++++++++++++++ syncode/grammar_decoder.py | 13 ++++- syncode/infer.py | 18 ++++-- syncode/language_model.py | 27 +++++++-- syncode/parsers/grammars/README.md | 23 ++++++++ 5 files changed, 158 insertions(+), 11 deletions(-) create mode 100644 notebooks/tests/debug_grammar.ipynb diff --git a/notebooks/tests/debug_grammar.ipynb b/notebooks/tests/debug_grammar.ipynb new file mode 100644 index 00000000..37e4d372 --- /dev/null +++ b/notebooks/tests/debug_grammar.ipynb @@ -0,0 +1,88 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 11.27it/s]\n" + ] + } + ], + "source": [ + "from syncode.infer import Syncode\n", + "\n", + "# Load the unconstrained original model\n", + "model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n", + "\n", + "grammar = r\"\"\" \n", + " start: instruction\n", + " instruction: \"Press the \" button \" button\"\n", + " button: \"power\" | \"volume up\" | \"volume down\" | \"home\" | \"back\" | \"recent apps\" | \"menu\" | \"search\"\n", + " \"\"\"\n", + "\n", + "syn_llm = Syncode(model=model_name, grammar=grammar)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------\n", + "Parsed terminals:\n", + "(type: 'Press\\\\ the\\\\ ' | value: 'Press the ')\n", + "(type: 'back' | value: 'back')\n", + "(type: '\\\\ button' | value: ' button')\n", + "--------------------------------------------------\n" + ] + }, + { + "data": { + "text/plain": [ + "['Press the back button']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = \"How do I go back?\"\n", + "syn_llm.infer(prompt, debug=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "codex", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/syncode/grammar_decoder.py b/syncode/grammar_decoder.py index a0993f0d..4aed54b4 100644 --- a/syncode/grammar_decoder.py +++ b/syncode/grammar_decoder.py @@ -204,4 +204,15 @@ def _log_greedy_difference(self, greedy_grammar_token, partial_code, r, greedy_t self.logger.log(f"Greedy token: {repr(greedy_token)}") self.logger.log(f"Greedy grammar-based token: {repr(greedy_grammar_token)}") self._log_current_status(partial_code, r) - \ No newline at end of file + + def print_debug(self): + print('-'*50) + print('Parsed terminals:') + + name_to_pattern = {} + for term in self.inc_parser.base_parser.terminals: + name_to_pattern[term.name] = term.pattern + + for token in self.inc_parser.parsed_lexer_tokens: + print(f"(type: {name_to_pattern[token.type]} | value: '{token.value}')") + print('-'*50) \ No newline at end of file diff --git a/syncode/infer.py b/syncode/infer.py index 88245231..d2090974 100644 --- a/syncode/infer.py +++ b/syncode/infer.py @@ -15,12 +15,12 @@ from syncode.evaluation.fol_eval import FOLEval -def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, **kwargs): +def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, debug=False, **kwargs): syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs) if dataset == "input": - syncode.infer() + syncode.infer(debug=debug) else: # Setup output directory and logger num_samples = kwargs.get('num_return_sequences', 1) @@ -146,8 +146,8 @@ def __init__( def is_grammar_mode(self): return self.mode == 'grammar_mask' or self.mode == 'grammar_strict' - def infer(self, prompt=None, stop_words=None): - output = self.user_input(prompt, stop_words=stop_words) + def infer(self, prompt=None, stop_words=None, debug=False): + output = self.user_input(prompt, stop_words=stop_words, debug=debug) return output def evaluate( @@ -196,17 +196,23 @@ def evaluate( logger.close() return output - def user_input(self, prompt:str, stop_words=None): + def user_input(self, prompt:str, stop_words=None, debug=False): """ Run user input on the model with grammar mask + + Args: + prompt (str): User input prompt + stop_words (list, optional): Stop words to use. Defaults to None. + debug (bool, optional): Debug mode. Defaults to False. """ if prompt: if self.grammar_decoder is not None: # TODO: Remove this check self.grammar_decoder.reset(prompt) + if self.chat_mode: return self.model.generate_chat_completion_grammar(prompt) else: - return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words) + return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words, debug=debug) else: while True: prompt = input('Enter prompt: ') diff --git a/syncode/language_model.py b/syncode/language_model.py index 5632b88b..ec3be393 100644 --- a/syncode/language_model.py +++ b/syncode/language_model.py @@ -67,9 +67,23 @@ def get_grammar_decoder(self): return None @torch.inference_mode() - def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None, return_token_ids=False) -> Iterable[str]: + def generate_batch_completion_grammar( + self, + prompt, + batch_size, + stop_words=None, + return_token_ids=False, + debug=False + ) -> Iterable[str]: ''' Generates batch_size completions for the given prompt. + + Args: + prompt (str): The prompt for which completions are generated. + batch_size (int): The number of completions to generate. + stop_words (list): A list of stop words. If the completion ends with any of the stop words, the completion is returned. + return_token_ids (bool): If True, returns the token ids of the completions. + debug (bool): If True, prints debug information. ''' # Reset the grammar decoder if self.grammar_decoder is not None: @@ -99,7 +113,8 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None, gen_config, gen_mode, grammar_decoder=self.grammar_decoder, - stop_criteria=stop_criteria + stop_criteria=stop_criteria, + debug=debug ) else: if self.opp: @@ -137,7 +152,8 @@ def _generate( gen_config:GenerationConfig, gen_mode:GenerationMode, grammar_decoder:SyncodeLogitsProcessor=None, - stop_criteria:StoppingCriteria=[] + stop_criteria:StoppingCriteria=[], + debug=False ): """ We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library. @@ -192,7 +208,10 @@ def _generate( # Update attention mask attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1) - + + if debug: + grammar_decoder.print_debug() + return token_ids def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scores): diff --git a/syncode/parsers/grammars/README.md b/syncode/parsers/grammars/README.md index d8477b45..4c085765 100644 --- a/syncode/parsers/grammars/README.md +++ b/syncode/parsers/grammars/README.md @@ -19,6 +19,29 @@ LR(1) is more powerful in terms of representing certain syntax, however common f In most cases, it should be possible to fix these errors by rewriting some of the grammar rules. However, in some rare cases it is possible that it is impossible to represent the grammar as LR(1) +### Debugging Grammars + +SynCode provides a flag `--debug` to help debug grammars. This flag will print out the parsed terminals and their corresponding values in generation. +(Refer to [this](../../../notebooks/tests/debug_grammar.ipynb) notebook for code example) + +For example, consider the following grammar: +```ebnf +start: "foo" "(" ident ")" ";" +ident: [a-z]+ +``` + +Given the output `foo(abc);`, the debug flag will print: +``` +-------------------------------------------------- +Parsed terminals: +(type: 'FOO', value: 'foo') +(type: 'LPAR', value: '(') +(type: 'IDENT', value: 'abc') +(type: 'RPAR', value: ')') +(type: 'SEMI', value: ';') +-------------------------------------------------- +``` + ### Lexer Ambiguity When defining a grammar, be cautious of lexer ambiguities that arise when one terminal is a substring of another. In some cases, these ambiguities can lead to unexpected behavior in the parser.