
Commit 2e3e22f

local inference for batch size > 1
1 parent c8d58ac commit 2e3e22f


2 files changed: +126, -53 lines changed


notebooks/example_json.ipynb

Lines changed: 93 additions & 28 deletions
@@ -9,32 +9,8 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.79it/s]\n",
-      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
-      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.84it/s]\n",
-      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Creating DFA mask store for CodeGenTokenizerFast and json, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/CodeGenTokenizerFast/grammar_strict_1003218229_50257.pkl.\n",
-      "Ignore whitespace tokens is True\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "100%|██████████| 58/58 [00:22<00:00, 2.54it/s]\n"
+      "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      " from .autonotebook import tqdm as notebook_tqdm\n"
      ]
     }
    ],
@@ -44,8 +20,15 @@
     "import warnings\n",
     "warnings.filterwarnings('ignore')\n",
     "\n",
-    "model_name = \"microsoft/phi-2\"\n",
-    "\n",
+    "model_name = \"microsoft/phi-2\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "# Load the unconstrained original model\n",
     "llm = Syncode(model=model_name, mode='original', max_new_tokens=50)\n",
     "\n",
@@ -118,6 +101,88 @@
     "output = syn_llm.infer(prompt)[0]\n",
     "print(f\"SynCode output:\\n{output}\")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.30it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "WARNING: Opportunistic mode requires batch_size of 1.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Warning: Exact lookup not found for (LSQB, 1) in the DFA mask store. This could be an error.\n",
+      "Warning: Exact lookup not found for (UNESCAPED_STRING, 3) in the DFA mask store. This could be an error.\n",
+      "SynCode output 1:\n",
+      "{\n",
+      " \"name\" : \"India\",\n",
+      " \"capital\" : \"New Delhi\",\n",
+      " \"population\" : \"1.3 billion\"\n",
+      "}\n",
+      "\n",
+      "SynCode output 2:\n",
+      "{\n",
+      " \"name\" : \"India\"\n",
+      " }\n",
+      "\n",
+      "SynCode output 3:\n",
+      "{\n",
+      " \"name\" : \"India\"\n",
+      " , \"capital\" : \"New Delhi\"\n",
+      " , \"population\" : \"1.3 billion\"\n",
+      "}\n",
+      "\n",
+      "SynCode output 4:\n",
+      "{\n",
+      " \"name\" : \"India\"\n",
+      "\n",
+      "SynCode output 5:\n",
+      "{\n",
+      " \"name\" : \"India\"\n",
+      " , \"capital\" : \"New Delhi\"\n",
+      " , \"population\" : \"1367\n",
+      "}\n",
+      "\n",
+      "A:\n",
+      "\n",
+      "Try this :\n",
+      "\n",
+      "var data = [\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "syn_llm = Syncode(model=model_name, grammar='json', parse_output_only=True, max_new_tokens=50, num_return_sequences=5, do_sample=True, temperature=0.7)\n",
+    "\n",
+    "prompt = \"Please return a json object to represent country India with name, capital and population?\"\n",
+    "output = syn_llm.infer(prompt)\n",
+    "\n",
+    "for i, out in enumerate(output):\n",
+    " out = out.strip()\n",
+    " print(f\"SynCode output {i+1}:\\n{out}\\n\")"
+   ]
  }
 ],
 "metadata": {

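Note on the new cell: it asks for five sampled completions in a single call (num_return_sequences=5, do_sample=True), and some of the printed samples are cut off at max_new_tokens before the JSON object is closed. A quick way to see which returned strings are complete JSON is to push them through json.loads. The helper below is a minimal sketch, not part of the commit; check_json_outputs is a hypothetical name, and it assumes output is the list returned by syn_llm.infer(prompt) as in the cell above.

import json

def check_json_outputs(outputs):
    # Report, for each sampled completion, whether it parses as valid JSON.
    results = []
    for i, out in enumerate(outputs, start=1):
        try:
            json.loads(out.strip())
            results.append((i, "valid"))
        except json.JSONDecodeError as err:
            results.append((i, f"invalid: {err.msg}"))
    return results

# Usage after the cell above: check_json_outputs(output)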
syncode/language_model.py

Lines changed: 33 additions & 25 deletions
@@ -93,13 +93,13 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None,
         stop_criteria = []

         # Generate completions
-        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
+        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH): # Use our own implementation for greedy search and sampling
             generated_ids = self._generate(
                 inputs,
                 gen_config,
                 gen_mode,
                 grammar_decoder=self.grammar_decoder,
-                stop_criteria=stop_criteria
+                stopping_criteria=stop_criteria
             )
         else:
             if self.opp:
@@ -137,20 +137,19 @@ def _generate(
         gen_config:GenerationConfig,
         gen_mode:GenerationMode,
         grammar_decoder:SyncodeLogitsProcessor=None,
-        stop_criteria:StoppingCriteria=[]
+        stopping_criteria:StoppingCriteria=[]
         ):
         """
         We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library.
         """
         token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None
-
-        # This does not include grammar decoder
-        self.model._prepare_special_tokens(gen_config, False, device=self.device)
-        logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])
-
+        logit_warper = self.model._get_logits_warper(gen_config, device=self.device)
         max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
-
-        while True:
+        num_outputs = token_ids.size(0)
+        unfinished_sequences = torch.ones(num_outputs, dtype=torch.long, device=self.device)
+        this_peer_finished = False
+
+        while not this_peer_finished:
             try:
                 if past_key_values: # Get the last token if kv is cached for all previous tokens
                     input_ids = token_ids[..., -1].unsqueeze(-1)
@@ -168,30 +167,39 @@ def _generate(
                 next_token_scores, past_key_values = outputs.logits[:, -1, :], outputs.past_key_values

             if grammar_decoder is not None:
-                next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
-                is_valid = grammar_decoder.is_valid(token_ids, next_token)
-
-                if not is_valid:
-                    # calling grammar decoder is expensive. Hence, in the opportunist mode, we call it only when the standard generation is syntactically incorrect
-                    next_token_scores = grammar_decoder(token_ids, next_token_scores)
-                    next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
+                # batch of next tokens
+                next_tokens = self._get_next_token(gen_mode, token_ids, logit_warper, next_token_scores)
+
+                for idx in range(token_ids.size(0)):
+                    token_ids_i = token_ids[idx:idx+1]
+                    next_token_scores_i = next_token_scores[idx:idx+1]
+                    next_token_i = next_tokens[idx:idx+1]
+
+                    is_valid = grammar_decoder.is_valid(token_ids_i, next_token_i)
+
+                    if not is_valid:
+                        # calling grammar decoder is expensive. Hence, in the opportunist mode, we call it only when the standard generation is syntactically incorrect
+                        next_token_scores_i = grammar_decoder(token_ids_i, next_token_scores_i)
+                        next_tokens[idx] = self._get_next_token(gen_mode, token_ids_i, logit_warper, next_token_scores_i)
             else:
-                next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
+                next_tokens = self._get_next_token(gen_mode, token_ids, logit_warper, next_token_scores)

-            token_ids = torch.cat([token_ids, next_token[:, None]], dim=-1)

-            # Check stopping criteria
-            finish_generation = False
-            for stop_criterion in stop_criteria:
-                if stop_criterion(token_ids, next_token_scores):
-                    finish_generation = True
+            # Update the next token
+            next_tokens = next_tokens * unfinished_sequences + self.tokenizer.eos_token_id * (1 - unfinished_sequences)
+
+            token_ids = torch.cat([token_ids, next_tokens[:, None]], dim=-1)

             # Check if the next token is the end of the sequence or the max tokens is reached
-            if finish_generation or next_token == self.tokenizer.eos_token_id or token_ids.size(1) >= max_tokens:
+            if token_ids.size(1) >= max_tokens:
                 break

             # Update attention mask
             attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1)
+
+            # Update the unfinished sequences
+            unfinished_sequences = unfinished_sequences & ~(stopping_criteria(token_ids, next_token_scores) | (token_ids[:, -1] == self.tokenizer.eos_token_id))
+            this_peer_finished = unfinished_sequences.max() == 0

         return token_ids

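Note on the _generate change: the removed batch_size == 1 guard is replaced by per-row bookkeeping so a single decoding loop can serve batch_size > 1. An unfinished_sequences mask records which rows are still generating, rows that have finished keep receiving the EOS token instead of freshly sampled tokens, and the loop ends once every row is done, mirroring the pattern used by Hugging Face transformers' generate. The sketch below isolates just that mask arithmetic, with random integers standing in for the model's grammar-constrained next-token choices; it is an illustration under those assumptions, not code from the repository.

import torch

eos_token_id = 50256                       # GPT-2 / phi-2 style EOS, also used as pad
batch_size, max_new_tokens = 3, 8
token_ids = torch.zeros(batch_size, 1, dtype=torch.long)   # stand-in prompt tokens

unfinished = torch.ones(batch_size, dtype=torch.long)      # 1 = row still generating
this_peer_finished = False
step = 0

while not this_peer_finished and step < max_new_tokens:
    # Stand-in for the model forward pass + (grammar-constrained) sampling.
    next_tokens = torch.randint(0, 100, (batch_size,))
    if step == 3:
        next_tokens[0] = eos_token_id      # pretend row 0 finishes early

    # Rows that already finished keep receiving EOS instead of new tokens.
    next_tokens = next_tokens * unfinished + eos_token_id * (1 - unfinished)
    token_ids = torch.cat([token_ids, next_tokens[:, None]], dim=-1)

    # A row becomes finished once its latest token is EOS.
    unfinished = unfinished & (token_ids[:, -1] != eos_token_id).long()
    this_peer_finished = bool(unfinished.max() == 0)
    step += 1

print(token_ids)   # finished rows are right-padded with EOS

In the commit itself the mask is additionally ANDed with the negated stopping criteria, so a custom stop word ends a row the same way EOS does.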
0 commit comments
