diff --git a/main.py b/main.py
index 3efd6ae..6b75cd3 100644
--- a/main.py
+++ b/main.py
@@ -7,6 +7,7 @@
 from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
 from globals import Decoder
+import time
@@ -95,32 +96,44 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=20, ga
     top_p = 0.9
 
     torch.manual_seed(123)
+    start = time.time()
     output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    end = time.time()
     color_print(f"large (target) model autoregressive_sampling: {generated_text}")
+    color_print(f"Elapsed time for large model autoregressive_sampling: {end - start:.3f}s")
 
     if use_benchmark:
         benchmark(autoregressive_sampling, "AS_large", use_profiling,
                   input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
 
     torch.manual_seed(123)
+    start = time.time()
     output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    end = time.time()
     color_print(f"small (approx) model autoregressive_sampling: {generated_text}")
+    color_print(f"Elapsed time for small model autoregressive_sampling: {end - start:.3f}s")
 
     if use_benchmark:
         benchmark(autoregressive_sampling, "AS_small", use_profiling,
                   input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
 
     torch.manual_seed(123)
+    start = time.time()
     output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    color_print(f"deepmind's speculative_sampling: {generated_text}")
+    end = time.time()
+    color_print(f"deepmind's speculative_sampling: {generated_text}")
+    color_print(f"Elapsed time for deepmind's speculative_sampling: {end - start:.3f}s")
 
     torch.manual_seed(123)
+    start = time.time()
     output = speculative_sampling(input_ids, small_model, large_model, num_tokens, gamma = gamma, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    end = time.time()
     color_print(f"google's speculative_sampling: {generated_text}")
+    color_print(f"Elapsed time for google's speculative_sampling: {end - start:.3f}s")
 
     if use_benchmark:
         benchmark(speculative_sampling, "SP", use_profiling,
diff --git a/sampling/kvcache_model.py b/sampling/kvcache_model.py
index 1337ae9..5d792ce 100644
--- a/sampling/kvcache_model.py
+++ b/sampling/kvcache_model.py
@@ -1,5 +1,6 @@
 import torch
 from typing import Optional
+from transformers.cache_utils import DynamicCache
 
 from sampling.utils import norm_logits, sample
 from transformers.models.bloom.modeling_bloom import BloomForCausalLM
@@ -33,6 +34,18 @@ def _forward_with_kvcache(self, input_ids : torch.Tensor, use_debug = True) -> t
             self._past_key_values = outputs.past_key_values
             last_q = self._prob_history[:, -1, :]
         else:
+            if isinstance(self._past_key_values, DynamicCache):
+                # DynamicCache (transformers >= 4.36) tracks its own length
+                cached_len = self._past_key_values.get_seq_length()
+            else:
+                # legacy tuple cache: read the length off the first layer's key
+                cached_len = 0
+                for kv in self._past_key_values:
+                    k, v = kv
+                    # Bloom keys are (batch * head, hidden_dim, seq);
+                    # standard keys are (batch, head, seq_len, hidden_dim)
+                    cached_len = k.shape[2] if k.dim() == 3 else k.shape[-2]
+                    break
             # return the last token's logits
             cached_len = 0
             for kv in self._past_key_values:
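Note that the hunk above leaves the original legacy-tuple length computation in place after the new branch, so the legacy loop still runs last and the cache-length logic now lives in two spots. A minimal sketch that folds both cache formats into one place, assuming transformers >= 4.36 for `DynamicCache`; the helper name `cached_seq_length` is ours, not part of the patch:

```python
from transformers.cache_utils import DynamicCache

def cached_seq_length(past_key_values) -> int:
    """Number of tokens already stored in a KV cache, for either cache format."""
    if past_key_values is None:
        return 0
    if isinstance(past_key_values, DynamicCache):
        return past_key_values.get_seq_length()
    # Legacy cache: a sequence of per-layer (key, value) tuples.
    for k, _ in past_key_values:
        # Bloom fuses batch and heads: k is (batch * head, hidden_dim, seq);
        # standard models use (batch, head, seq_len, hidden_dim).
        return k.shape[2] if k.dim() == 3 else k.shape[-2]
    return 0
```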
@@ -90,28 +103,30 @@ def generate(self, input : torch.Tensor, gamma : int) -> torch.Tensor:
         return output
 
     @torch.no_grad()
-    def rollback(self, end_pos : int):
-        past_key_values_trimmed = []
-        assert self._past_key_values
-        for kv in self._past_key_values:
-            k, v = kv
-            # NOTE() the indexing is specific for bloom. This won't work for other models
-            # For example llama k, v should be (batch, num_head, seq_len, hidden_dim)
-
-            # Bloom is special one
-            if isinstance(self._model, BloomForCausalLM):
-                # k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)
-                k = k[:, :, :end_pos]
-                v = v[:, :end_pos, :]
-                kv_trimmed = (k, v)
-                past_key_values_trimmed.append(kv_trimmed)
-            else:
-                # k, v (batch, head, seq, hidden_dim)
-                k = k[:, :, :end_pos, :]
-                v = v[:, :, :end_pos, :]
-                kv_trimmed = (k, v)
-                past_key_values_trimmed.append(kv_trimmed)
-
-        self._past_key_values = past_key_values_trimmed
+    def rollback(self, end_pos: int):
+        if isinstance(self._past_key_values, DynamicCache):
+            # truncate every layer of the DynamicCache to the first end_pos tokens
+            new_cache = DynamicCache()
+            for layer_idx in range(len(self._past_key_values.key_cache)):
+                k = self._past_key_values.key_cache[layer_idx][..., :end_pos, :]
+                v = self._past_key_values.value_cache[layer_idx][..., :end_pos, :]
+                new_cache.key_cache.append(k)
+                new_cache.value_cache.append(v)
+            self._past_key_values = new_cache
+        else:
+            # legacy tuple cache
+            past_key_values_trimmed = []
+            for kv in self._past_key_values:
+                k, v = kv
+                if isinstance(self._model, BloomForCausalLM):
+                    # Bloom: k is (batch * head, hidden_dim, seq); v is (batch * head, seq, hidden_dim)
+                    k = k[:, :, :end_pos]
+                    v = v[:, :end_pos, :]
+                else:
+                    # standard layout: k, v are (batch, head, seq_len, hidden_dim)
+                    k = k[..., :end_pos, :]
+                    v = v[..., :end_pos, :]
+                past_key_values_trimmed.append((k, v))
+            self._past_key_values = past_key_values_trimmed
 
         self._prob_history = self._prob_history[:, :end_pos, :]
-
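Recent transformers releases also expose `DynamicCache.crop(max_length)`, which truncates every layer in place and keeps the cache's internal token bookkeeping consistent; where available it can stand in for the manual key_cache/value_cache copy above. A sketch under that assumption (the helper name `rollback_cache` is ours):

```python
from transformers.cache_utils import DynamicCache

def rollback_cache(past_key_values, end_pos: int):
    """Keep only the first end_pos tokens of a KV cache."""
    if isinstance(past_key_values, DynamicCache):
        past_key_values.crop(end_pos)  # in-place truncation of every layer
        return past_key_values
    # Legacy tuple cache with the standard (batch, head, seq_len, hidden_dim)
    # layout; Bloom's fused layout still needs the special-case slicing above.
    return [(k[..., :end_pos, :], v[..., :end_pos, :]) for k, v in past_key_values]
```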