58 changes: 33 additions & 25 deletions syncode/language_model.py
@@ -93,13 +93,13 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None,
stop_criteria = []

# Generate completions
if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH): # Use our own implementation for greedy search and sampling
generated_ids = self._generate(
inputs,
gen_config,
gen_mode,
grammar_decoder=self.grammar_decoder,
stop_criteria=stop_criteria
stopping_criteria=stop_criteria
)
else:
if self.opp:
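The hunk above drops the `batch_size == 1` restriction, so the opportunistic path (`self.opp`) now handles greedy search and sampling for any batch size, while other generation modes still fall back to transformers' `generate()`. Below is a minimal, self-contained sketch of that dispatch; `GenerationMode`, `custom_generate`, and `hf_generate` are local stand-ins for illustration, not SynCode or transformers APIs.

```python
# Sketch only: dispatch between a custom decoding loop and a generate() fallback.
from enum import Enum

class GenerationMode(Enum):   # stand-in for transformers' GenerationMode
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    BEAM_SEARCH = "beam_search"

def custom_generate(prompt):  # placeholder for the hand-rolled decoding loop
    return f"custom({prompt})"

def hf_generate(prompt):      # placeholder for transformers' model.generate()
    return f"hf({prompt})"

def dispatch(opp: bool, gen_mode: GenerationMode, prompt: str) -> str:
    # After this PR the custom loop covers greedy search and sampling
    # regardless of batch size; other modes still use the fallback.
    if opp and gen_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
        return custom_generate(prompt)
    return hf_generate(prompt)

print(dispatch(True, GenerationMode.GREEDY_SEARCH, "def f():"))  # custom path
print(dispatch(True, GenerationMode.BEAM_SEARCH, "def f():"))    # fallback path
```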
@@ -137,20 +137,19 @@ def _generate(
gen_config:GenerationConfig,
gen_mode:GenerationMode,
grammar_decoder:SyncodeLogitsProcessor=None,
stop_criteria:StoppingCriteria=[]
stopping_criteria:StoppingCriteria=[]
):
"""
We support greedy search and sampling for batch size 1; otherwise, we use the generate function from the transformers library.
"""
token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None

# This does not include grammar decoder
self.model._prepare_special_tokens(gen_config, False, device=self.device)
logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])

logit_warper = self.model._get_logits_warper(gen_config, device=self.device)
max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)

while True:
num_outputs = token_ids.size(0)
unfinished_sequences = torch.ones(num_outputs, dtype=torch.long, device=self.device)
this_peer_finished = False

while not this_peer_finished:
try:
if past_key_values: # Get the last token if kv is cached for all previous tokens
input_ids = token_ids[..., -1].unsqueeze(-1)
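This hunk replaces the single-sequence `while True` loop with the bookkeeping transformers uses for batched decoding: an `unfinished_sequences` mask, a `this_peer_finished` flag, and a `logit_warper` from `_get_logits_warper`. Below is a toy sketch of that bookkeeping only, assuming hand-picked next tokens instead of real model outputs and checking just EOS and a length cap (no stopping criteria).

```python
# Sketch only: batched stopping bookkeeping with fake next tokens.
import torch

eos_token_id = 2
token_ids = torch.tensor([[5, 7], [5, 9]])          # batch of 2 prefixes
unfinished_sequences = torch.ones(token_ids.size(0), dtype=torch.long)
this_peer_finished = False
max_tokens = 6

step = 0
while not this_peer_finished:
    # Pretend the model picked these tokens (sequence 0 hits EOS first).
    next_tokens = torch.tensor([eos_token_id, 11 + step])
    # Sequences that already finished keep emitting EOS.
    next_tokens = next_tokens * unfinished_sequences + eos_token_id * (1 - unfinished_sequences)
    token_ids = torch.cat([token_ids, next_tokens[:, None]], dim=-1)

    if token_ids.size(1) >= max_tokens:
        break

    # A sequence is finished once its last token is EOS.
    unfinished_sequences = unfinished_sequences & ~(token_ids[:, -1] == eos_token_id)
    this_peer_finished = unfinished_sequences.max() == 0
    step += 1

print(token_ids)  # sequence 0 is padded with EOS after it finishes
```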
@@ -168,30 +167,39 @@ def _generate(
next_token_scores, past_key_values = outputs.logits[:, -1, :], outputs.past_key_values

if grammar_decoder is not None:
next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
is_valid = grammar_decoder.is_valid(token_ids, next_token)

if not is_valid:
# Calling the grammar decoder is expensive. Hence, in opportunistic mode, we call it only when the standard generation is syntactically incorrect
next_token_scores = grammar_decoder(token_ids, next_token_scores)
next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
# batch of next tokens
next_tokens = self._get_next_token(gen_mode, token_ids, logit_warper, next_token_scores)

for idx in range(token_ids.size(0)):
token_ids_i = token_ids[idx:idx+1]
next_token_scores_i = next_token_scores[idx:idx+1]
next_token_i = next_tokens[idx:idx+1]

is_valid = grammar_decoder.is_valid(token_ids_i, next_token_i)

if not is_valid:
# Calling the grammar decoder is expensive. Hence, in opportunistic mode, we call it only when the standard generation is syntactically incorrect
next_token_scores_i = grammar_decoder(token_ids_i, next_token_scores_i)
next_tokens[idx] = self._get_next_token(gen_mode, token_ids_i, logit_warper, next_token_scores_i)
else:
next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
next_tokens = self._get_next_token(gen_mode, token_ids, logit_warper, next_token_scores)

token_ids = torch.cat([token_ids, next_token[:, None]], dim=-1)

# Check stopping criteria
finish_generation = False
for stop_criterion in stop_criteria:
if stop_criterion(token_ids, next_token_scores):
finish_generation = True
# Update the next token
next_tokens = next_tokens * unfinished_sequences + self.tokenizer.eos_token_id * (1 - unfinished_sequences)

token_ids = torch.cat([token_ids, next_tokens[:, None]], dim=-1)

# Check if the next token is the end of the sequence or the max tokens is reached
if finish_generation or next_token == self.tokenizer.eos_token_id or token_ids.size(1) >= max_tokens:
if token_ids.size(1) >= max_tokens:
break

# Update attention mask
attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1)

# Update the unfinished sequences
unfinished_sequences = unfinished_sequences & ~(stopping_criteria(token_ids, next_token_scores) | (token_ids[:, -1] == self.tokenizer.eos_token_id))
this_peer_finished = unfinished_sequences.max() == 0

return token_ids
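The core of the change is the per-sequence opportunistic check: the whole batch first gets unconstrained next tokens, and the expensive grammar decoder is invoked only for sequences whose unconstrained token is syntactically invalid; finished sequences are then padded with EOS before concatenation. The sketch below mimics that inner loop with a toy `ToyGrammarDecoder` (accepting only even token ids) standing in for `SyncodeLogitsProcessor`, and a plain argmax standing in for `_get_next_token`; neither is part of SynCode.

```python
# Sketch only: per-sequence opportunistic grammar check over a batch.
import torch

class ToyGrammarDecoder:
    def is_valid(self, token_ids, next_token):
        # Toy grammar: only even token ids are valid continuations.
        return bool(next_token.item() % 2 == 0)

    def __call__(self, token_ids, scores):
        masked = scores.clone()
        masked[:, 1::2] = float("-inf")   # forbid odd token ids
        return masked

def greedy_next(scores):
    return scores.argmax(dim=-1)

decoder = ToyGrammarDecoder()
token_ids = torch.zeros((2, 3), dtype=torch.long)   # batch of 2 prefixes
scores = torch.randn(2, 10)                          # toy vocab of 10

next_tokens = greedy_next(scores)                    # unconstrained picks for the batch
for idx in range(token_ids.size(0)):
    token_ids_i = token_ids[idx:idx + 1]
    scores_i = scores[idx:idx + 1]
    next_token_i = next_tokens[idx:idx + 1]
    # Only pay for the grammar decoder when the unconstrained token is invalid.
    if not decoder.is_valid(token_ids_i, next_token_i):
        next_tokens[idx] = greedy_next(decoder(token_ids_i, scores_i))

print(next_tokens)  # every entry is now even, i.e. grammar-valid
```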
