Skip to content

Commit fb0dc52

Browse files
authored
Merge pull request #139 from uiuc-focal-lab/instruct
Add Go and Python example in the builtin grammar notebook
2 parents 12a5e28 + 2c542d2 commit fb0dc52

File tree

1 file changed

+238
-13
lines changed

1 file changed

+238
-13
lines changed

notebooks/tests/builtin_grammar.ipynb

Lines changed: 238 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,9 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 2,
5+
"execution_count": 4,
66
"metadata": {},
7-
"outputs": [
8-
{
9-
"name": "stderr",
10-
"output_type": "stream",
11-
"text": [
12-
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13-
" from .autonotebook import tqdm as notebook_tqdm\n"
14-
]
15-
}
16-
],
7+
"outputs": [],
178
"source": [
189
"import torch\n",
1910
"from syncode import SyncodeLogitsProcessor\n",
@@ -75,8 +66,6 @@
7566
}
7667
],
7768
"source": [
78-
"# grammar_str = \"python\"\n",
79-
"# grammar_str = \"go\"\n",
8069
"grammar_str = \"java\"\n",
8170
"\n",
8271
"grammar = Grammar(grammar_str)\n",
@@ -105,6 +94,242 @@
10594
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
10695
"print(\"[OUTPUT]\", output_str)"
10796
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": 5,
101+
"metadata": {},
102+
"outputs": [
103+
{
104+
"name": "stdout",
105+
"output_type": "stream",
106+
"text": [
107+
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
108+
"\n",
109+
"Cutting Knowledge Date: December 2023\n",
110+
"Today Date: 26 Jul 2024\n",
111+
"\n",
112+
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
113+
"\n",
114+
"Write a python function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
115+
"\n",
116+
" \n",
117+
"\n",
118+
"--------------------------------------------------\n",
119+
"Parsing failed! Falling back to unconstrained decoding.\n",
120+
"Exception: Unexpected token Token('NAME', 'simple') at line 3, column 11.\n",
121+
"Expected one of: \n",
122+
"\t* __ANON_9\n",
123+
"\t* __ANON_21\n",
124+
"\t* AMPERSAND\n",
125+
"\t* RPAR\n",
126+
"\t* __ANON_4\n",
127+
"\t* LESSTHAN\n",
128+
"\t* IF\n",
129+
"\t* STAR\n",
130+
"\t* RSQB\n",
131+
"\t* __ANON_5\n",
132+
"\t* __ANON_17\n",
133+
"\t* LSQB\n",
134+
"\t* SLASH\n",
135+
"\t* MINUS\n",
136+
"\t* VBAR\n",
137+
"\t* _NL\n",
138+
"\t* FROM\n",
139+
"\t* __ANON_20\n",
140+
"\t* EQUAL\n",
141+
"\t* __ANON_22\n",
142+
"\t* __ANON_13\n",
143+
"\t* OR\n",
144+
"\t* SEMICOLON\n",
145+
"\t* PLUS\n",
146+
"\t* LPAR\n",
147+
"\t* CIRCUMFLEX\n",
148+
"\t* FOR\n",
149+
"\t* __ANON_2\n",
150+
"\t* NOT\n",
151+
"\t* AT\n",
152+
"\t* __ANON_10\n",
153+
"\t* COMMA\n",
154+
"\t* __ANON_18\n",
155+
"\t* COLON\n",
156+
"\t* MORETHAN\n",
157+
"\t* AS\n",
158+
"\t* __ANON_6\n",
159+
"\t* ELSE\n",
160+
"\t* __ANON_16\n",
161+
"\t* __ANON_11\n",
162+
"\t* DOT\n",
163+
"\t* IN\n",
164+
"\t* __ANON_7\n",
165+
"\t* ASYNC\n",
166+
"\t* IS\n",
167+
"\t* RBRACE\n",
168+
"\t* __ANON_8\n",
169+
"\t* __ANON_3\n",
170+
"\t* AND\n",
171+
"\t* __ANON_15\n",
172+
"\t* __ANON_19\n",
173+
"\t* __ANON_14\n",
174+
"\t* PERCENT\n",
175+
"\t* __ANON_12\n",
176+
"\t* __ANON_1\n",
177+
"\n",
178+
"Partial code: ### Printing 'Hello World' in Reverse\n",
179+
"\n",
180+
"Here is a simple Python\n",
181+
"Parsed lexical tokens: [Token('_NL', \"### Printing 'Hello World' in Reverse\\n\\n\"), Token('NAME', 'Here'), Token('IS', 'is'), Token('NAME', 'a'), Token('NAME', 'simple')]\n",
182+
"--------------------------------------------------\n",
183+
"[OUTPUT] ### Printing 'Hello World' in Reverse\n",
184+
"\n",
185+
"Here is a simple Python function that prints 'Hello World' in reverse:\n",
186+
"\n",
187+
"```python\n",
188+
"def print_hello_world_reverse():\n",
189+
" \"\"\"\n",
190+
" Prints 'Hello World' in reverse.\n",
191+
" \"\"\"\n",
192+
" print(\"Hello World\")\n",
193+
"\n",
194+
"# Example usage:\n",
195+
"print_hello_world_reverse()\n",
196+
"```\n",
197+
"\n",
198+
"When you run this code, it will output:\n",
199+
"```\n",
200+
"olleH dlroW\n",
201+
"```\n",
202+
"\n",
203+
"Alternatively, you can also use slicing to reverse the string:\n",
204+
"\n",
205+
"```python\n",
206+
"def print_hello_world_reverse():\n",
207+
" \"\"\"\n",
208+
" Prints 'Hello World' in reverse.\n",
209+
" \"\"\"\n",
210+
" print(\" \".join([\"H\", \"e\", \"l\", \"l\", \"o\", \" \", \"W\", \"o\", \"r\", \"l\", \"d\"])\n",
211+
"\n",
212+
"# Example usage:\n",
213+
"print_hello_world_reverse()\n",
214+
"```\n",
215+
"\n",
216+
"This will output:\n",
217+
"```\n",
218+
"olleH dlroW\n",
219+
"```\n",
220+
"\n",
221+
"Note: The `join()` method is used to concatenate the elements of a list into a single string, which is then printed.\n"
222+
]
223+
}
224+
],
225+
"source": [
226+
"grammar_str = \"python\"\n",
227+
"\n",
228+
"grammar = Grammar(grammar_str)\n",
229+
"syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
230+
"\n",
231+
"prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
232+
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
233+
"prompt = tokenizer.apply_chat_template(\n",
234+
" messages, tokenize=False, add_generation_prompt=True\n",
235+
" )\n",
236+
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
237+
"\n",
238+
"syncode_logits_processor.reset(prompt)\n",
239+
"\n",
240+
"inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n",
241+
"\n",
242+
"attention_mask = torch.ones_like(inputs)\n",
243+
"output = model.generate(\n",
244+
" inputs,\n",
245+
" attention_mask=attention_mask,\n",
246+
" max_length=512, \n",
247+
" num_return_sequences=1, \n",
248+
" pad_token_id=tokenizer.eos_token_id, \n",
249+
" logits_processor=[syncode_logits_processor]\n",
250+
" )\n",
251+
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
252+
"print(\"[OUTPUT]\", output_str)"
253+
]
254+
},
255+
{
256+
"cell_type": "code",
257+
"execution_count": 7,
258+
"metadata": {},
259+
"outputs": [
260+
{
261+
"name": "stdout",
262+
"output_type": "stream",
263+
"text": [
264+
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
265+
"\n",
266+
"Cutting Knowledge Date: December 2023\n",
267+
"Today Date: 26 Jul 2024\n",
268+
"\n",
269+
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
270+
"\n",
271+
"Write a go function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
272+
"\n",
273+
" \n",
274+
"\n",
275+
"[OUTPUT] // \n",
276+
"\n",
277+
"package main\n",
278+
"\n",
279+
"import (\n",
280+
" \"fmt\" // Import the fmt package\n",
281+
")\n",
282+
"\n",
283+
"// Function to print 'hello world' in reverse\n",
284+
"func printHelloWorld() {\n",
285+
" // Declare a variable to hold the string 'hello world'\n",
286+
" var s string = \"hello world\" // Define the string\n",
287+
" // Use string reverse() to reverse the string\n",
288+
" var reversed string = strings.Reverses(s) // Reverse the string\n",
289+
" // Print the reversed string\n",
290+
" fmt.Println(reversed) // Print the reversed string\n",
291+
"}\n",
292+
"\n",
293+
"func main() {\n",
294+
" // Call the function to print 'hello world' in reverse\n",
295+
" printHelloWorld() // Call the function\n",
296+
"} \n",
297+
"\n",
298+
"// Note: The string reverse() function in Go returns a string slice, not a string.\n",
299+
"// If you want to convert the string slice to a string, you can use the string slice's string() method.\n",
300+
"// Here, we are using the Reverses() function which returns a string slice.\n"
301+
]
302+
}
303+
],
304+
"source": [
305+
"grammar_str = \"go\"\n",
306+
"\n",
307+
"grammar = Grammar(grammar_str)\n",
308+
"syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
309+
"\n",
310+
"prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
311+
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
312+
"prompt = tokenizer.apply_chat_template(\n",
313+
" messages, tokenize=False, add_generation_prompt=True\n",
314+
" )\n",
315+
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
316+
"\n",
317+
"syncode_logits_processor.reset(prompt)\n",
318+
"\n",
319+
"inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n",
320+
"\n",
321+
"attention_mask = torch.ones_like(inputs)\n",
322+
"output = model.generate(\n",
323+
" inputs,\n",
324+
" attention_mask=attention_mask,\n",
325+
" max_length=512, \n",
326+
" num_return_sequences=1, \n",
327+
" pad_token_id=tokenizer.eos_token_id, \n",
328+
" logits_processor=[syncode_logits_processor]\n",
329+
" )\n",
330+
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
331+
"print(\"[OUTPUT]\", output_str)"
332+
]
108333
}
109334
],
110335
"metadata": {

0 commit comments

Comments
 (0)