diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py
new file mode 100644
index 00000000..7ec412d9
--- /dev/null
+++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py
@@ -0,0 +1,69 @@
+# User input
+prompt = "Lily picked up a flower."
+model_name = "Maykeye/TinyLLama-v0"
+
+# Tokenizer
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "right"
+inputs = tokenizer(
+    prompt,
+    return_tensors="pt",
+    padding="max_length",
+    max_length=30,
+    truncation=True,
+)
+
+# Generator
+import torch
+
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+
+from tico.utils.record_input import RecordingInput
+
+# past_key_values
+# ---------------
+# During prefill, "past_key_values" is not None but an empty Cache instance,
+# and passing None keeps torch.export happy (see prefill.py). Here a decode
+# step is captured, so the populated cache is kept as an input.
+
+
+input_to_remove = [
+    "attention_mask",
+    # For left pad:  [0, ⋯, 0, 1, ⋯, 1]
+    # For right pad: [1, ⋯, 1, 0, ⋯, 0]
+    # (0 marks pad positions)
+    # This script uses right padding and passes an all-1 attention mask
+    # (including pad positions). The NPU computes all positions, pad or not.
+]
+condition_fn = lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0
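+
+# Illustrative sketch (not required for the export): condition_fn above skips
+# the prefill call because the KV cache is still empty there, e.g. with
+# transformers' DynamicCache:
+#
+#   from transformers.cache_utils import DynamicCache
+#   assert DynamicCache().get_seq_length() == 0  # prefill-like state: skipped
+#   # once generate() appends keys/values, get_seq_length() > 0 -> captured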
+model_name = "Maykeye/TinyLLama-v0" + +# Tokenizer +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" +inputs = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=31, + truncation=True, +) + +# Generator +import torch + +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() + +from tico.utils.record_input import RecordingInput + +target_model = model.model.layers[0] +condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0 + +with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec: + outputs = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + captured_input = rec.captured_input + +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(generated_text) + + +# Convert + +import tico +from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter +from transformers.models.llama.modeling_llama import LlamaAttention + +LlamaAttention.forward = llama_attention_forward_adapter + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() +circle_model = tico.convert(model.model.layers[0], captured_input) +circle_model.save(f"tinyllama.layer.attn.circle") diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py new file mode 100644 index 00000000..74e3db4e --- /dev/null +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py @@ -0,0 +1,110 @@ +# User input +prompt = "Lily picked up a flower." +model_name = "Maykeye/TinyLLama-v0" + +# Tokenizer +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" +inputs = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=32, + truncation=True, +) + +# Generator +import torch + +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() + +from tico.utils.record_input import RecordingInput + +target_model = model.model.layers[0] +condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0 + +with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec: + outputs = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + captured_input = rec.captured_input + +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(generated_text) + +from typing import Any, Optional, Tuple + +# Define DecoderLayers + +from torch import nn +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel + + +# DecoderLayers is not nn.Module. Not torch.export-able. +# Let's define decoder layers as nn.Module. + + +class LlamaDecoderLayers(nn.Module): + def __init__(self, model: LlamaModel): + super().__init__() + self.config = model.config + self.layers = model.layers + + # Make sure signature is same to capturing input. 
+from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter
+
+LlamaAttention.forward = llama_attention_forward_adapter
+
+model = AutoModelForCausalLM.from_pretrained(model_name)
+layers = LlamaDecoderLayers(model.model)
+layers.eval()
+circle_model = tico.convert(layers, captured_input)
+circle_model.save("tinyllama.layers.attn.circle")
diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py
new file mode 100644
index 00000000..942223d5
--- /dev/null
+++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py
@@ -0,0 +1,76 @@
+# User input
+prompt = "Lily picked up a flower."
+model_name = "Maykeye/TinyLLama-v0"
+
+# Tokenizer
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "right"
+inputs = tokenizer(
+    prompt,
+    return_tensors="pt",
+    padding="max_length",
+    max_length=32,
+    truncation=True,
+)
+
+# Generator
+import torch
+
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+
+from tico.utils.record_input import RecordingInput
+
+# past_key_values
+# ---------------
+# During prefill, "past_key_values" is not None but an empty Cache instance.
+# Passing None makes torch.export happy (see the note in input_to_remove below).
+
+
+input_to_remove = [
+    "past_key_values",
+    # DynamicCache has been pytree-flattenable since transformers 4.50.
+    # See _pytree.py > tree_flatten:
+    #   SUPPORTED_NODES contains *transformers.DynamicCache*.
+    # After flattening, DynamicCache becomes { "key_cache": [], "value_cache": [] }:
+    # the dict values are returned and the dict keys are stored in the treespec.
+    #
+    # During prefill the DynamicCache is empty, so the dict is empty after
+    # flattening, and PyTorch drops the empty dict entirely. If the number of
+    # args is 4 (including the cache), it silently becomes 3!
+    # To avoid this, don't pass an empty cache; pass None instead.
+    # (A sketch demonstrating this is placed right after this list.)
+    "attention_mask",
+    # For left pad:  [0, ⋯, 0, 1, ⋯, 1]
+    # For right pad: [1, ⋯, 1, 0, ⋯, 0]
+    # (0 marks pad positions)
+    # This script uses right padding and passes an all-1 attention mask
+    # (including pad positions). The NPU computes all positions, pad or not.
+    "cache_position",
+    # The list of cache positions, e.g. [0, 1, ..., 11].
+    # For the NPU, we always store all values (including pad).
+]
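+
+# Sketch of the pytree behaviour described above (illustrative only; assumes
+# transformers >= 4.50, which registers DynamicCache with torch's pytree):
+#
+#   import torch.utils._pytree as pytree
+#   from transformers.cache_utils import DynamicCache
+#
+#   leaves, spec = pytree.tree_flatten(DynamicCache())
+#   assert leaves == []  # empty cache -> no leaves survive flattening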
+
+with torch.no_grad(), RecordingInput(model, input_to_remove=input_to_remove) as rec:
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=32,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    captured_input = rec.captured_input
+
+generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+print(generated_text)
+
+# Tico
+import tico
+
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+circle_model = tico.convert(model, captured_input)
+circle_model.save("tinyllama.prefill.circle")
diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt
new file mode 100644
index 00000000..5393938f
--- /dev/null
+++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt
@@ -0,0 +1 @@
+transformers>=4.50.1
diff --git a/tico/serialize/operators/onert/op_attention.py b/tico/serialize/operators/onert/op_attention.py
new file mode 100644
index 00000000..c52f5ee3
--- /dev/null
+++ b/tico/serialize/operators/onert/op_attention.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+
+from torch.library import Library
+
+lib = Library("circle", "DEF")
+lib.define(
+    """
+attention.llama(
+    Tensor hidden_states,
+    Tensor wq,
+    Tensor wk,
+    Tensor wv,
+    Tensor wo,
+    Tensor position_cos,
+    Tensor position_sin,
+    Tensor attention_mask,
+    Tensor past_key,
+    Tensor past_value,
+    Tensor cache_position
+) -> Tensor
+"""
+)
+
+# ATTENTION FUSER
+
+
+@torch.library.register_fake("circle::attention.llama")
+def attention_llama(*args, **kwargs):
+    (
+        hidden_states,
+        q_proj,
+        k_proj,
+        v_proj,
+        o_proj,
+        position_cos,
+        position_sin,
+        attention_mask,
+        past_key,
+        past_value,
+        cache_position,
+    ) = args
+    return hidden_states
+
+
+from typing import List, Optional
+
+from transformers.cache_utils import DynamicCache
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+
+def llama_attention_forward_adapter(
+    self: LlamaAttention,
+    hidden_states: torch.Tensor,
+    position_embeddings: List[torch.Tensor],
+    attention_mask: torch.Tensor,
+    past_key_value: DynamicCache,
+    cache_position: torch.Tensor,
+    **kwargs,
+):
+    # past_key_value is a DynamicCache holding key_cache and value_cache lists.
+    # It has to be decomposed into plain tensors because tico and circle do not
+    # understand such container types.
+    key_cache = past_key_value.key_cache  # type: ignore[union-attr]
+    value_cache = past_key_value.value_cache  # type: ignore[union-attr]
+    return (
+        torch.ops.circle.attention.llama(
+            hidden_states,
+            self.q_proj.weight,
+            self.k_proj.weight,
+            self.v_proj.weight,
+            self.o_proj.weight,
+            position_embeddings[0],  # cos
+            position_embeddings[1],  # sin
+            attention_mask,
+            # key_cache is a list with one cache tensor per decoder layer.
+            # Assumption: the key cache is laid out contiguously:
+            #
+            #   k_cache[0] | k_cache[1] | ... | k_cache[n]
+            key_cache[self.layer_idx],
+            value_cache[self.layer_idx],  # same assumption for value_cache
+            cache_position,
+        ),
+        None,
+    )
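+
+
+# Usage sketch (illustrative): the adapter above is meant to replace
+# LlamaAttention.forward before export, so that each attention block is
+# captured as a single circle.attention.llama node, e.g.:
+#
+#   LlamaAttention.forward = llama_attention_forward_adapter
+#   circle_model = tico.convert(model, captured_input)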
+
+
+@register_node_visitor
+class AttentionVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle.attention.llama,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        (
+            hidden_states,
+            wq,
+            wk,
+            wv,
+            wo,
+            position_cos,
+            position_sin,
+            attention_mask,
+            past_key,
+            past_value,
+            cache_position,
+        ) = node.args
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
+        )
+
+        # Drop the last arg (cache_position) from the operator inputs;
+        # it is treated as the attention op's parameter, not a graph input.
+        inputs = node.args[:-1]
+        outputs = [node]
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.AttentionOptions
+        )
+        operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()
+
+        return operator
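+
+
+# Sanity-check sketch (optional, illustrative): after patching
+# LlamaAttention.forward with the adapter above, the exported program should
+# contain circle.attention.llama call_function nodes. Assuming `captured_input`
+# is the positional-args tuple recorded for the patched `model`:
+#
+#   ep = torch.export.export(model, captured_input)
+#   assert any(
+#       n.op == "call_function" and n.target == torch.ops.circle.attention.llama
+#       for n in ep.graph.nodes
+#   )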