diff --git a/syncode/language_model.py b/syncode/language_model.py
index 5632b88b..8968422e 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -93,13 +93,13 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None,
             stop_criteria = []
 
         # Generate completions
-        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
+        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH): # Use our own implementation for greedy search and sampling
             generated_ids = self._generate(
                 inputs,
                 gen_config,
                 gen_mode,
                 grammar_decoder=self.grammar_decoder,
-                stop_criteria=stop_criteria
+                stopping_criteria=stop_criteria
                 )
         else:
             if self.opp:
@@ -137,20 +137,19 @@ def _generate(
         gen_config:GenerationConfig,
         gen_mode:GenerationMode,
         grammar_decoder:SyncodeLogitsProcessor=None,
-        stop_criteria:StoppingCriteria=[]
+        stopping_criteria:StoppingCriteria=[]
         ):
         """
         We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library.
         """
         token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None
-
-        # This does not include grammar decoder
-        self.model._prepare_special_tokens(gen_config, False, device=self.device)
-        logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])
-
+        logit_warper = self.model._get_logits_warper(gen_config, device=self.device)
         max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
-
-        while True:
+        num_outputs = token_ids.size(0)
+        unfinished_sequences = torch.ones(num_outputs, dtype=torch.long, device=self.device)
+        this_peer_finished = False
+
+        while not this_peer_finished:
             try:
                 if past_key_values: # Get the last token if kv is cached for all previous tokens
                     input_ids = token_ids[..., -1].unsqueeze(-1)
@@ -168,30 +167,39 @@
                 next_token_scores, past_key_values = outputs.logits[:, -1, :], outputs.past_key_values
 
                 if grammar_decoder is not None:
-                    next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
-                    is_valid = grammar_decoder.is_valid(token_ids, next_token)
-
-                    if not is_valid:
-                        # calling grammar decoder is expensive. Hence, in the opportunist mode, we call it only when the standard generation is syntactically incorrect
-                        next_token_scores = grammar_decoder(token_ids, next_token_scores)
-                        next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
+                    # batch of next tokens
+                    next_tokens = self._get_next_token(gen_mode, token_ids, logit_warper, next_token_scores)
+
+                    for idx in range(token_ids.size(0)):
+                        token_ids_i = token_ids[idx:idx+1]
+                        next_token_scores_i = next_token_scores[idx:idx+1]
+                        next_token_i = next_tokens[idx:idx+1]
+
+                        is_valid = grammar_decoder.is_valid(token_ids_i, next_token_i)
+
+                        if not is_valid:
+                            # calling grammar decoder is expensive. Hence, in the opportunist mode, we call it only when the standard generation is syntactically incorrect
+                            next_token_scores_i = grammar_decoder(token_ids_i, next_token_scores_i)
+                            next_tokens[idx] = self._get_next_token(gen_mode, token_ids_i, logit_warper, next_token_scores_i)
                 else:
-                    next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
+                    next_tokens = self._get_next_token(gen_mode, token_ids, logit_warper, next_token_scores)
 
-                token_ids = torch.cat([token_ids, next_token[:, None]], dim=-1)
-                # Check stopping criteria
-                finish_generation = False
-                for stop_criterion in stop_criteria:
-                    if stop_criterion(token_ids, next_token_scores):
-                        finish_generation = True
+                # Update the next token
+                next_tokens = next_tokens * unfinished_sequences + self.tokenizer.eos_token_id * (1 - unfinished_sequences)
+
+                token_ids = torch.cat([token_ids, next_tokens[:, None]], dim=-1)
 
                 # Check if the next token is the end of the sequence or the max tokens is reached
-                if finish_generation or next_token == self.tokenizer.eos_token_id or token_ids.size(1) >= max_tokens:
+                if token_ids.size(1) >= max_tokens:
                     break
 
                 # Update attention mask
                 attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1)
+
+                # Update the unfinished sequences
+                unfinished_sequences = unfinished_sequences & ~(stopping_criteria(token_ids, next_token_scores) | (token_ids[:, -1] == self.tokenizer.eos_token_id))
+                this_peer_finished = unfinished_sequences.max() == 0
 
         return token_ids
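The core behavioral change above is that the custom loop no longer requires `batch_size == 1`: next tokens are proposed for the whole batch at once, and the grammar decoder is consulted row by row, only for rows whose proposed token is syntactically invalid. Below is a minimal sketch of that "opportunistic" control flow; `ToyGrammarDecoder` and `greedy` are made-up stand-ins for `SyncodeLogitsProcessor` and `_get_next_token`, not the real SynCode classes.

```python
import torch

class ToyGrammarDecoder:
    """Hypothetical stand-in: accepts a fixed set of token ids as 'grammar-valid'."""
    def __init__(self, allowed):
        self.allowed = allowed

    def is_valid(self, token_ids, next_token):
        # Cheap check: is the proposed next token acceptable after this prefix?
        return int(next_token.item()) in self.allowed

    def __call__(self, token_ids, scores):
        # Expensive path: mask out every token the grammar does not allow
        masked = torch.full_like(scores, float("-inf"))
        cols = torch.tensor(sorted(self.allowed))
        masked[:, cols] = scores[:, cols]
        return masked

def greedy(scores):
    return torch.argmax(scores, dim=-1)

grammar_decoder = ToyGrammarDecoder(allowed={1, 3})
token_ids = torch.tensor([[0], [0]])
next_token_scores = torch.tensor([[0.1, 0.2, 0.9, 0.3],   # argmax = 2, grammar-invalid
                                  [0.1, 0.8, 0.2, 0.3]])  # argmax = 1, already valid

# Unconstrained proposal for the whole batch
next_tokens = greedy(next_token_scores)

# Per-row validation; only invalid rows pay for the grammar mask
for idx in range(token_ids.size(0)):
    token_ids_i = token_ids[idx:idx + 1]
    scores_i = next_token_scores[idx:idx + 1]
    if not grammar_decoder.is_valid(token_ids_i, next_tokens[idx:idx + 1]):
        next_tokens[idx] = greedy(grammar_decoder(token_ids_i, scores_i))

print(next_tokens)  # tensor([3, 1]) -- only row 0 was corrected
```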
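The `unfinished_sequences` / `this_peer_finished` bookkeeping mirrors the pattern used by the batched `generate` loop in `transformers`: rows that have already produced EOS are forced to keep emitting EOS, and decoding stops once every row is finished. A small sketch of just that masking arithmetic, with made-up token ids and `eos_token_id = 2`:

```python
import torch

eos_token_id = 2
batch_size = 3
unfinished_sequences = torch.ones(batch_size, dtype=torch.long)

token_ids = torch.tensor([[5], [6], [7]])       # one running sequence per row
fake_next_tokens = [torch.tensor([9, 2, 4]),    # step 1: row 1 hits EOS
                    torch.tensor([3, 8, 2]),    # step 2: row 2 hits EOS
                    torch.tensor([2, 5, 6])]    # step 3: row 0 hits EOS

for next_tokens in fake_next_tokens:
    # Finished rows are forced to EOS instead of whatever was sampled
    next_tokens = next_tokens * unfinished_sequences + eos_token_id * (1 - unfinished_sequences)
    token_ids = torch.cat([token_ids, next_tokens[:, None]], dim=-1)

    # A row stays unfinished only while its latest token is not EOS
    unfinished_sequences = unfinished_sequences & ~(token_ids[:, -1] == eos_token_id)
    if unfinished_sequences.max() == 0:         # this_peer_finished in the diff
        break

print(token_ids)
# tensor([[5, 9, 3, 2],
#         [6, 2, 2, 2],
#         [7, 4, 2, 2]])
```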