Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 238 additions & 13 deletions notebooks/tests/builtin_grammar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"from syncode import SyncodeLogitsProcessor\n",
Expand Down Expand Up @@ -75,8 +66,6 @@
}
],
"source": [
"# grammar_str = \"python\"\n",
"# grammar_str = \"go\"\n",
"grammar_str = \"java\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
Expand Down Expand Up @@ -105,6 +94,242 @@
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Write a python function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
" \n",
"\n",
"--------------------------------------------------\n",
"Parsing failed! Falling back to unconstrained decoding.\n",
"Exception: Unexpected token Token('NAME', 'simple') at line 3, column 11.\n",
"Expected one of: \n",
"\t* __ANON_9\n",
"\t* __ANON_21\n",
"\t* AMPERSAND\n",
"\t* RPAR\n",
"\t* __ANON_4\n",
"\t* LESSTHAN\n",
"\t* IF\n",
"\t* STAR\n",
"\t* RSQB\n",
"\t* __ANON_5\n",
"\t* __ANON_17\n",
"\t* LSQB\n",
"\t* SLASH\n",
"\t* MINUS\n",
"\t* VBAR\n",
"\t* _NL\n",
"\t* FROM\n",
"\t* __ANON_20\n",
"\t* EQUAL\n",
"\t* __ANON_22\n",
"\t* __ANON_13\n",
"\t* OR\n",
"\t* SEMICOLON\n",
"\t* PLUS\n",
"\t* LPAR\n",
"\t* CIRCUMFLEX\n",
"\t* FOR\n",
"\t* __ANON_2\n",
"\t* NOT\n",
"\t* AT\n",
"\t* __ANON_10\n",
"\t* COMMA\n",
"\t* __ANON_18\n",
"\t* COLON\n",
"\t* MORETHAN\n",
"\t* AS\n",
"\t* __ANON_6\n",
"\t* ELSE\n",
"\t* __ANON_16\n",
"\t* __ANON_11\n",
"\t* DOT\n",
"\t* IN\n",
"\t* __ANON_7\n",
"\t* ASYNC\n",
"\t* IS\n",
"\t* RBRACE\n",
"\t* __ANON_8\n",
"\t* __ANON_3\n",
"\t* AND\n",
"\t* __ANON_15\n",
"\t* __ANON_19\n",
"\t* __ANON_14\n",
"\t* PERCENT\n",
"\t* __ANON_12\n",
"\t* __ANON_1\n",
"\n",
"Partial code: ### Printing 'Hello World' in Reverse\n",
"\n",
"Here is a simple Python\n",
"Parsed lexical tokens: [Token('_NL', \"### Printing 'Hello World' in Reverse\\n\\n\"), Token('NAME', 'Here'), Token('IS', 'is'), Token('NAME', 'a'), Token('NAME', 'simple')]\n",
"--------------------------------------------------\n",
"[OUTPUT] ### Printing 'Hello World' in Reverse\n",
"\n",
"Here is a simple Python function that prints 'Hello World' in reverse:\n",
"\n",
"```python\n",
"def print_hello_world_reverse():\n",
" \"\"\"\n",
" Prints 'Hello World' in reverse.\n",
" \"\"\"\n",
" print(\"Hello World\")\n",
"\n",
"# Example usage:\n",
"print_hello_world_reverse()\n",
"```\n",
"\n",
"When you run this code, it will output:\n",
"```\n",
"olleH dlroW\n",
"```\n",
"\n",
"Alternatively, you can also use slicing to reverse the string:\n",
"\n",
"```python\n",
"def print_hello_world_reverse():\n",
" \"\"\"\n",
" Prints 'Hello World' in reverse.\n",
" \"\"\"\n",
" print(\" \".join([\"H\", \"e\", \"l\", \"l\", \"o\", \" \", \"W\", \"o\", \"r\", \"l\", \"d\"])\n",
"\n",
"# Example usage:\n",
"print_hello_world_reverse()\n",
"```\n",
"\n",
"This will output:\n",
"```\n",
"olleH dlroW\n",
"```\n",
"\n",
"Note: The `join()` method is used to concatenate the elements of a list into a single string, which is then printed.\n"
]
}
],
"source": [
"grammar_str = \"python\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
"syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
"\n",
"prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"prompt = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
" )\n",
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
"\n",
"syncode_logits_processor.reset(prompt)\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n",
"\n",
"attention_mask = torch.ones_like(inputs)\n",
"output = model.generate(\n",
" inputs,\n",
" attention_mask=attention_mask,\n",
" max_length=512, \n",
" num_return_sequences=1, \n",
" pad_token_id=tokenizer.eos_token_id, \n",
" logits_processor=[syncode_logits_processor]\n",
" )\n",
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Write a go function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
" \n",
"\n",
"[OUTPUT] // \n",
"\n",
"package main\n",
"\n",
"import (\n",
" \"fmt\" // Import the fmt package\n",
")\n",
"\n",
"// Function to print 'hello world' in reverse\n",
"func printHelloWorld() {\n",
" // Declare a variable to hold the string 'hello world'\n",
" var s string = \"hello world\" // Define the string\n",
" // Use string reverse() to reverse the string\n",
" var reversed string = strings.Reverses(s) // Reverse the string\n",
" // Print the reversed string\n",
" fmt.Println(reversed) // Print the reversed string\n",
"}\n",
"\n",
"func main() {\n",
" // Call the function to print 'hello world' in reverse\n",
" printHelloWorld() // Call the function\n",
"} \n",
"\n",
"// Note: The string reverse() function in Go returns a string slice, not a string.\n",
"// If you want to convert the string slice to a string, you can use the string slice's string() method.\n",
"// Here, we are using the Reverses() function which returns a string slice.\n"
]
}
],
"source": [
"grammar_str = \"go\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
"syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
"\n",
"prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"prompt = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
" )\n",
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
"\n",
"syncode_logits_processor.reset(prompt)\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n",
"\n",
"attention_mask = torch.ones_like(inputs)\n",
"output = model.generate(\n",
" inputs,\n",
" attention_mask=attention_mask,\n",
" max_length=512, \n",
" num_return_sequences=1, \n",
" pad_token_id=tokenizer.eos_token_id, \n",
" logits_processor=[syncode_logits_processor]\n",
" )\n",
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
}
],
"metadata": {
Expand Down
Loading