From cfb1e6b74a4761b0f10033743cc72358b5e91997 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 5 Feb 2026 16:13:05 +0900 Subject: [PATCH 1/3] Revisit pytree_utils --- .../__init__.py | 0 .../model.py | 0 .../requirements.txt | 0 .../__init__.py | 1 + .../model.py | 62 +++ .../requirements.txt | 2 + tico/utils/pytree_utils.py | 462 ++++++++++++++---- tico/utils/pytree_utils.save.md | 311 ++++++++++++ 8 files changed, 744 insertions(+), 94 deletions(-) rename test/modules/model/{LlamaAttentionWithKVCache => LlamaAttentionWithDynamicCache}/__init__.py (100%) rename test/modules/model/{LlamaAttentionWithKVCache => LlamaAttentionWithDynamicCache}/model.py (100%) rename test/modules/model/{LlamaAttentionWithKVCache => LlamaAttentionWithDynamicCache}/requirements.txt (100%) create mode 100644 test/modules/model/LlamaAttentionWithDynamicCache_transformers500/__init__.py create mode 100644 test/modules/model/LlamaAttentionWithDynamicCache_transformers500/model.py create mode 100644 test/modules/model/LlamaAttentionWithDynamicCache_transformers500/requirements.txt create mode 100644 tico/utils/pytree_utils.save.md diff --git a/test/modules/model/LlamaAttentionWithKVCache/__init__.py b/test/modules/model/LlamaAttentionWithDynamicCache/__init__.py similarity index 100% rename from test/modules/model/LlamaAttentionWithKVCache/__init__.py rename to test/modules/model/LlamaAttentionWithDynamicCache/__init__.py diff --git a/test/modules/model/LlamaAttentionWithKVCache/model.py b/test/modules/model/LlamaAttentionWithDynamicCache/model.py similarity index 100% rename from test/modules/model/LlamaAttentionWithKVCache/model.py rename to test/modules/model/LlamaAttentionWithDynamicCache/model.py diff --git a/test/modules/model/LlamaAttentionWithKVCache/requirements.txt b/test/modules/model/LlamaAttentionWithDynamicCache/requirements.txt similarity index 100% rename from test/modules/model/LlamaAttentionWithKVCache/requirements.txt rename to test/modules/model/LlamaAttentionWithDynamicCache/requirements.txt diff --git a/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/__init__.py b/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/__init__.py new file mode 100644 index 00000000..0c29109f --- /dev/null +++ b/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/__init__.py @@ -0,0 +1 @@ +# DO NOT REMOVE THIS FILE diff --git a/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/model.py b/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/model.py new file mode 100644 index 00000000..04ae94ff --- /dev/null +++ b/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/model.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from tico.utils.pytree_utils import register_dynamic_cache, register_dynamic_layer +from transformers.cache_utils import DynamicCache +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig + +from test.modules.base import TestModuleBase + + +class LlamaAttentionWithKVCache_transformers500(TestModuleBase): + def __init__(self): + super().__init__() + + self.config = LlamaConfig(use_cache=True, attn_implementation="sdpa") + self.model = LlamaAttention(config=self.config, layer_idx=0).to("cpu") + self.rtol = 1e-4 + self.atol = 1e-4 + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def get_example_inputs(self): + seq_len = 1 # Assume token generation + hidden_size = self.config.hidden_size + head_dim = self.config.head_dim + num_heads = self.config.num_attention_heads + + hidden_states = torch.randn(1, seq_len, hidden_size) + position_embeddings = ( + torch.randn(1, seq_len, head_dim), + torch.randn(1, seq_len, head_dim), + ) + attention_mask = torch.Tensor([[[[0.0]] * seq_len]]) # shape: 1, 1, seq_len, 1 + # This attention_mask will become a causal_mask of shape: (batch_size, 1, query_length, key_value_length) + prev_seq_len = 4 + past_key_values = DynamicCache() + register_dynamic_cache() + register_dynamic_layer() + + past_key_values.update( + torch.randn(1, num_heads, prev_seq_len, head_dim), + torch.randn(1, num_heads, prev_seq_len, head_dim), + 0, + ) + return ( + hidden_states, + position_embeddings, + attention_mask, + ), {"past_key_values": past_key_values} diff --git a/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/requirements.txt b/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/requirements.txt new file mode 100644 index 00000000..17ed30af --- /dev/null +++ b/test/modules/model/LlamaAttentionWithDynamicCache_transformers500/requirements.txt @@ -0,0 +1,2 @@ +numpy==1.24.1 +transformers==5.0.0 diff --git a/tico/utils/pytree_utils.py b/tico/utils/pytree_utils.py index 7645a675..78f919c0 100644 --- a/tico/utils/pytree_utils.py +++ b/tico/utils/pytree_utils.py @@ -6,121 +6,347 @@ from tico.utils import logging from tico.utils.installed_packages import is_transformers_installed -__all__ = ["register_dynamic_cache"] +import torch +from packaging.version import Version +from transformers.cache_utils import StaticCache, StaticLayer, DynamicCache, DynamicLayer, EncoderDecoderCache +__all__ = [ + "register_dynamic_cache", + "register_static_cache", + "register_dynamic_layer", + "register_static_layer", + "register_encoder_decoder_cache", -def register_dynamic_cache(): - PyTreeRegistryHelper().register_dynamic_cache() + # "register_cache_utils_to_pytree" # Covers most of torch cache-utils + ] +# def register_encoder_decoder_cache(): +# PyTreeRegistryHelper().register(EncoderDecoderCache) -class PyTreeRegistryHelper: - """ - Thread-safe singleton helper class for registering custom PyTree nodes. +# def register_dynamic_cache(): +# PyTreeRegistryHelper().register(DynamicCache) - This class provides functionality to register DynamicCache as a PyTree node - for torch.export compatibility. This registration is only needed for - transformers versions below 4.50.0. 
+# def register_static_cache(): +# PyTreeRegistryHelper().register(StaticCache) - Thread Safety: - - Uses a class-level threading.Lock() to ensure thread-safe singleton instantiation - - Uses the same lock to protect the registration process from concurrent calls - """ +# def register_static_layer(): +# PyTreeRegistryHelper().register(StaticLayer) - _instance = None # Class variable to hold the singleton instance - _has_called = False # Flag to track if registration has been performed - _lock = threading.Lock() # Class-level lock for thread-safe operations +# def register_dynamic_layer(): +# PyTreeRegistryHelper().register(DynamicLayer) - def __init__(self): - """Private constructor to prevent direct instantiation""" - pass - def __new__(cls, *args, **kwargs): - """ - Thread-safe singleton instance creation using double-checked locking pattern. +# class PyTreeRegistryHelper: +# """ +# Thread-safe singleton helper class for registering custom PyTree nodes. + +# Thread Safety: +# - Uses a class-level threading.Lock() to ensure thread-safe singleton instantiation +# - Uses the same lock to protect the registration process from concurrent calls +# """ + +# _instance = None # Class variable to hold the singleton instance +# _lock = threading.Lock() # Class-level lock for thread-safe operations + +# def __init__(self): +# """Private constructor to prevent direct instantiation""" +# pass + +# def __new__(cls, *args, **kwargs): +# """ +# Thread-safe singleton instance creation using double-checked locking pattern. + +# Returns: +# PyTreeRegistryHelper: The singleton instance of this class +# """ +# if not cls._instance: +# with cls._lock: # Acquire lock for thread-safe instantiation +# if not cls._instance: # Double-check after acquiring lock +# cls._instance = super().__new__(cls) +# return cls._instance + +# def register(self, cache_cls): +# """ +# Registers torch cache utility classes as a PyTree node for torch.export compatibility. + +# Raises: +# ImportError: If transformers package is not installed +# """ +# with self._lock: # Acquire lock for thread-safe registration +# if not is_transformers_installed: +# raise ImportError("transformers package is not installed") + +# import transformers +# if Version( +# "4.50.0" +# ) < Version(transformers.__version__) < Version("4.56.0"): +# logger = logging.getLogger(__name__) +# logger.warn("{} is be already registered as pytree-flattenable in transformers version 4.50.0 - 4.56.0. (Your transformers version: {transformers.__version__})") + +# try: +# torch.utils._pytree.register_pytree_node( +# cache_cls, +# _flatten_static_cache, +# _unflatten_static_cache, +# serialized_type_name=f"{cache_cls.__module__}.{cache_cls.__name__}", +# flatten_with_keys_fn=_flatten_with_keys_static_cache, +# ) +# torch.fx._pytree.register_pytree_flatten_spec( +# cache_cls, _flatten_static_cache_for_fx +# ) +# except ValueError as e: +# logger = logging.getLogger(__name__) +# logger.warning(f"{cache_cls} is already registered as pytree flattenable. {e}") + + +################################################################################## +# These _flatten_*/_unflatten_* function must be located **outside** - on module scope, not inside function, +# to be registered in pytree clearly. 
+################################################################################## + +def _flatten_static_cache(cache): + children = (cache.layers,) + aux_data = { + "layer_class_to_replicate": getattr(cache, "layer_class_to_replicate", None), + "offloading": getattr(cache, "offloading", False), + } + return children, aux_data - Returns: - PyTreeRegistryHelper: The singleton instance of this class +def _unflatten_static_cache(children, aux_data): + instance = StaticCache.__new__(StaticCache) + layers, = children + instance.layers = layers + + for key, value in aux_data.items(): + setattr(instance, key, value) + + return instance + +def _flatten_with_keys_static_cache(cache: StaticCache): + return torch.utils._pytree._dict_flatten_with_keys(cache.__dict__) + +def _flatten_static_cache_for_fx(cache, spec): + return torch.fx._pytree._dict_flatten_spec(cache.__dict__, spec) + +def register_static_cache(): + try: + torch.utils._pytree.register_pytree_node( + StaticCache, + _flatten_static_cache, + _unflatten_static_cache, + serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_static_cache, + ) + torch.fx._pytree.register_pytree_flatten_spec( + StaticCache, _flatten_static_cache_for_fx + ) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.warning(f"StaticCache is already registered as pytree flattenable. {e}") + +# def _flatten_static_layer(cache: StaticLayer): +# nodes = { +# "keys": cache.keys, +# "values": cache.values, +# } +# return torch.utils._pytree._dict_flatten(nodes) + +# def _unflatten_static_layer(values, context: torch.utils._pytree.Context): +# data = torch.utils._pytree._dict_unflatten(values, context) + +# instance = StaticLayer.__new__(StaticLayer) +# for k, v in data.items(): +# setattr(instance, k, v) + +# return instance +from typing import Tuple, Any, Dict +def _flatten_static_layer(layer) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """ + 객체를 (변경 가능한 텐서들, 정적인 메타데이터)로 분리합니다. """ - if not cls._instance: - with cls._lock: # Acquire lock for thread-safe instantiation - if not cls._instance: # Double-check after acquiring lock - cls._instance = super().__new__(cls) - return cls._instance + # 1. Children: 모델 최적화나 자동 미분 시 추적해야 할 텐서 데이터 + children = (layer.keys, layer.values) + + # 2. Aux Data: 객체를 재구성할 때 필요한 정적 정보 (해시 가능해야 함) + aux_data = { + "max_cache_len": layer.max_cache_len, + "is_initialized": layer.is_initialized, + } + # 초기화된 경우라면 메타데이터(dtype, device 등)를 추가로 저장할 수 있습니다. + if layer.is_initialized: + aux_data.update({ + "dtype": layer.keys.dtype, + "device": layer.keys.device, + "max_batch_size": layer.max_batch_size, + "num_heads": layer.num_heads, + "k_head_dim": layer.k_head_dim, + "v_head_dim": layer.v_head_dim, + }) + + return children, aux_data - def register_dynamic_cache(self): + +def _unflatten_static_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any]) -> "StaticLayer": + """ + flatten된 데이터로부터 새로운 객체를 복구합니다. """ - Registers DynamicCache as a PyTree node for torch.export compatibility. + keys, values = children + # 1. 새 인스턴스 생성 + obj = StaticLayer(max_cache_len=aux_data["max_cache_len"]) + + # 2. 상태 복구 + obj.is_initialized = aux_data["is_initialized"] + obj.keys = keys + obj.values = values + + # 3. 
초기화되었던 상태라면 나머지 속성들도 복구 + if obj.is_initialized: + obj.dtype = aux_data["dtype"] + obj.device = aux_data["device"] + obj.max_batch_size = aux_data["max_batch_size"] + obj.num_heads = aux_data["num_heads"] + obj.k_head_dim = aux_data["k_head_dim"] + obj.v_head_dim = aux_data["v_head_dim"] + + return obj - This method is thread-safe and idempotent - it will only perform the - registration once, even if called multiple times from different threads. +def _flatten_with_keys_static_layer(cache: StaticLayer): + return torch.utils._pytree._dict_flatten_with_keys(cache.__dict__) - Note: - This registration is only needed for transformers versions below 4.50.0. +def _flatten_static_cache_layer(cache, spec): + return torch.fx._pytree._dict_flatten_spec(cache.__dict__, spec) - Raises: - ImportError: If transformers package is not installed - """ - with self._lock: # Acquire lock for thread-safe registration - if self.__class__._has_called: - logger = logging.getLogger(__name__) - logger.debug("register_dynamic_cache already called, skipping") - return +def register_static_layer(): + try: + torch.utils._pytree.register_pytree_node( + StaticLayer, + _flatten_static_layer, + _unflatten_static_layer, + serialized_type_name=f"{StaticLayer.__module__}.{StaticLayer.__name__}", + flatten_with_keys_fn=_flatten_with_keys_static_layer, + ) + torch.fx._pytree.register_pytree_flatten_spec( + StaticLayer, _flatten_static_cache_layer + ) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.warning(f"StaticLayer is already registered as pytree flattenable. {e}") + + + +from typing import Tuple, Any, Dict +def _flatten_dynamic_layer(layer) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + if not layer.is_initialized: + raise ValueError(f"{layer} cannot be flattened. DynamicLayer must be initialized with tensor with specific shape to be used as input for torch.export") + + children = (layer.keys, layer.values) + + aux_data = {} + aux_data.update({ + "is_initialized": layer.is_initialized, + "dtype": layer.keys.dtype, + "device": layer.keys.device, + # "get_mask_sizes": layer.get_mask_sizes, + # "get_max_layer_shape": layer.get_max_layer_shape, + # "get_seq_length": layer.get_seq_length, + # "is_sliding": layer.is_sliding, + }) + + return children, aux_data + +def _unflatten_dynamic_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any]) -> "DynamicLayer": + keys, values = children + obj = DynamicLayer() + + obj.keys = keys + obj.values = values + + obj.is_initialized = aux_data["is_initialized"] + obj.dtype = aux_data["dtype"] + obj.device = aux_data["device"] + # ADD OTHERS? + + return obj - self.__class__._has_called = True - logger = logging.getLogger(__name__) - logger.info("Registering DynamicCache PyTree node") +import torch.utils._pytree as pytree - if not is_transformers_installed: # type: ignore[truthy-function] - raise ImportError("transformers package is not installed") +def _flatten_with_keys_dynamic_layer(layer: DynamicLayer): + breakpoint() + children = [ + (pytree.MappingKeyPath("keys"), layer.keys), + (pytree.MappingKeyPath("values"), layer.values), + ] + + # 2. 
텐서가 아닌 고정 정보(metadata) + aux_data = { + "is_initialized": layer.is_initialized, + "dtype": layer.keys.dtype, + "device": layer.keys.device, + } + + return children, aux_data - import transformers +import torch.fx._pytree as fx_pytree - HAS_TRANSFORMERS_LESS_4_50_0 = Version(transformers.__version__) < Version( - "4.50.0" +def _flatten_with_keys_dynamic_layer(layer: DynamicLayer, spec): + # DynamicLayer의 핵심 데이터를 딕셔너리 형태로 추출 + breakpoint() + # spec에 정의된 필드들과 일치해야 합니다. + layer_dict = { + "keys": layer.keys, + "values": layer.values, + "is_initialized": layer.is_initialized, + "dtype": layer.keys.dtype, + "device": layer.keys.device, + } + # FX의 dict_flatten_spec을 사용하여 spec 구조에 맞게 flatten + return fx_pytree._dict_flatten_spec(layer_dict, spec) + + +# def _flatten_with_keys_dynamic_layer(layer: DynamicLayer): +# return torch.utils._pytree._dict_flatten_with_keys(layer.__dict__) + +def _flatten_dynamic_layer_for_fx(layer, spec): + breakpoint() + return torch.fx._pytree._dict_flatten_spec(layer.__dict__, spec) + + +def register_dynamic_layer(): + try: + torch.utils._pytree.register_pytree_node( + DynamicLayer, + _flatten_dynamic_layer, + _unflatten_dynamic_layer, + serialized_type_name=f"{DynamicLayer.__module__}.{DynamicLayer.__name__}", + flatten_with_keys_fn=_flatten_with_keys_dynamic_layer, + ) + torch.fx._pytree.register_pytree_flatten_spec( + DynamicLayer, _flatten_dynamic_layer_for_fx ) - if not HAS_TRANSFORMERS_LESS_4_50_0: - return - - from transformers.cache_utils import DynamicCache - - def _flatten_dynamic_cache(dynamic_cache: DynamicCache): - if not isinstance(dynamic_cache, DynamicCache): - raise RuntimeError( - "This pytree flattening function should only be applied to DynamicCache" - ) - HAS_TORCH_2_6_0 = Version(torch.__version__) >= Version("2.6.0") - if not HAS_TORCH_2_6_0: - logger = logging.getLogger(__name__) - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." - ) - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten(dictionary) - - def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - for k, v in dictionary.items(): - setattr(cache, k, v) - return cache - - def _flatten_dynamic_cache_for_fx(cache, spec): - dictionary = { - "key_cache": getattr(cache, "key_cache"), - "value_cache": getattr(cache, "value_cache"), - } - return torch.fx._pytree._dict_flatten_spec(dictionary, spec) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.warning(f"DynamicLayer is already registered as pytree flattenable. 
{e}")
+
+def _flatten_dynamic_cache(cache: DynamicCache):
+    return torch.utils._pytree._dict_flatten(cache.__dict__)
+
+def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
+    data = torch.utils._pytree._dict_unflatten(values, context)
+
+    instance = DynamicCache.__new__(DynamicCache)
+
+    instance.__dict__.update(data)
+    return instance
+
+def _flatten_with_keys_dynamic_cache(cache: DynamicCache):
+    return torch.utils._pytree._dict_flatten_with_keys(cache.__dict__)
+
+def _flatten_dynamic_cache_for_fx(cache, spec):
+    return torch.fx._pytree._dict_flatten_spec(cache.__dict__, spec)
+
+def register_dynamic_cache():
+    try:
             torch.utils._pytree.register_pytree_node(
                 DynamicCache,
                 _flatten_dynamic_cache,
@@ -128,7 +354,55 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
                 serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
                 flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
             )
-            # TODO: This won't be needed in torch 2.7+.
             torch.fx._pytree.register_pytree_flatten_spec(
                 DynamicCache, _flatten_dynamic_cache_for_fx
             )
+    except ValueError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(f"DynamicCache is already registered as pytree flattenable. {e}")
+
+def _flatten_encoder_decoder_cache(cache: EncoderDecoderCache):
+    # EncoderDecoderCache structurally holds a self-attention cache and a cross-attention cache.
+    dictionary = {
+        "self_attention_cache": cache.self_attention_cache,
+        "cross_attention_cache": cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+def _unflatten_encoder_decoder_cache(values, context: torch.utils._pytree.Context):
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    # Rebuild following the __init__(*caches) signature
+    return EncoderDecoderCache(
+        dictionary["self_attention_cache"],
+        dictionary["cross_attention_cache"]
+    )
+
+def _flatten_with_keys_encoder_decoder_cache(cache: EncoderDecoderCache):
+    dictionary = {
+        "self_attention_cache": cache.self_attention_cache,
+        "cross_attention_cache": cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+def _flatten_encoder_decoder_cache_for_fx(cache, spec):
+    dictionary = {
+        "self_attention_cache": cache.self_attention_cache,
+        "cross_attention_cache": cache.cross_attention_cache,
+    }
+    return torch.fx._pytree._dict_flatten_spec(dictionary, spec)
+
+def register_encoder_decoder_cache():
+    try:
+        torch.utils._pytree.register_pytree_node(
+            EncoderDecoderCache,
+            _flatten_encoder_decoder_cache,
+            _unflatten_encoder_decoder_cache,
+            serialized_type_name=f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}",
+            flatten_with_keys_fn=_flatten_with_keys_encoder_decoder_cache,
+        )
+        torch.fx._pytree.register_pytree_flatten_spec(
+            EncoderDecoderCache, _flatten_encoder_decoder_cache_for_fx
+        )
+    except ValueError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(f"EncoderDecoderCache is already registered as pytree flattenable. 
{e}") diff --git a/tico/utils/pytree_utils.save.md b/tico/utils/pytree_utils.save.md new file mode 100644 index 00000000..6681e575 --- /dev/null +++ b/tico/utils/pytree_utils.save.md @@ -0,0 +1,311 @@ +import threading + +import torch +from packaging.version import Version + +from tico.utils import logging +from tico.utils.installed_packages import is_transformers_installed + +import torch +from packaging.version import Version +from transformers.cache_utils import StaticCache, StaticLayer, DynamicCache, DynamicLayer, EncoderDecoderCache + +__all__ = [ + "register_dynamic_cache", + "register_static_cache", + # "register_dynamic_layer", + "register_static_layer", + "register_encoder_decoder_cache", + + # "register_cache_utils_to_pytree" # Covers most of torch cache-utils + ] + +# def register_encoder_decoder_cache(): +# PyTreeRegistryHelper().register(EncoderDecoderCache) + +# def register_dynamic_cache(): +# PyTreeRegistryHelper().register(DynamicCache) + +# def register_static_cache(): +# PyTreeRegistryHelper().register(StaticCache) + +# def register_static_layer(): +# PyTreeRegistryHelper().register(StaticLayer) + +# def register_dynamic_layer(): +# PyTreeRegistryHelper().register(DynamicLayer) + + +# class PyTreeRegistryHelper: +# """ +# Thread-safe singleton helper class for registering custom PyTree nodes. + +# Thread Safety: +# - Uses a class-level threading.Lock() to ensure thread-safe singleton instantiation +# - Uses the same lock to protect the registration process from concurrent calls +# """ + +# _instance = None # Class variable to hold the singleton instance +# _lock = threading.Lock() # Class-level lock for thread-safe operations + +# def __init__(self): +# """Private constructor to prevent direct instantiation""" +# pass + +# def __new__(cls, *args, **kwargs): +# """ +# Thread-safe singleton instance creation using double-checked locking pattern. + +# Returns: +# PyTreeRegistryHelper: The singleton instance of this class +# """ +# if not cls._instance: +# with cls._lock: # Acquire lock for thread-safe instantiation +# if not cls._instance: # Double-check after acquiring lock +# cls._instance = super().__new__(cls) +# return cls._instance + +# def register(self, cache_cls): +# """ +# Registers torch cache utility classes as a PyTree node for torch.export compatibility. + +# Raises: +# ImportError: If transformers package is not installed +# """ +# with self._lock: # Acquire lock for thread-safe registration +# if not is_transformers_installed: +# raise ImportError("transformers package is not installed") + +# import transformers +# if Version( +# "4.50.0" +# ) < Version(transformers.__version__) < Version("4.56.0"): +# logger = logging.getLogger(__name__) +# logger.warn("{} is be already registered as pytree-flattenable in transformers version 4.50.0 - 4.56.0. (Your transformers version: {transformers.__version__})") + +# try: +# torch.utils._pytree.register_pytree_node( +# cache_cls, +# _flatten_static_cache, +# _unflatten_static_cache, +# serialized_type_name=f"{cache_cls.__module__}.{cache_cls.__name__}", +# flatten_with_keys_fn=_flatten_with_keys_static_cache, +# ) +# torch.fx._pytree.register_pytree_flatten_spec( +# cache_cls, _flatten_static_cache_for_fx +# ) +# except ValueError as e: +# logger = logging.getLogger(__name__) +# logger.warning(f"{cache_cls} is already registered as pytree flattenable. 
{e}")
+
+
+##################################################################################
+# These _flatten_*/_unflatten_* functions must be defined at module scope (not inside a function)
+# so that they can be registered with pytree reliably.
+##################################################################################
+
+def _flatten_static_cache(cache):
+    children = (cache.layers,)
+    aux_data = {
+        "layer_class_to_replicate": getattr(cache, "layer_class_to_replicate", None),
+        "offloading": getattr(cache, "offloading", False),
+    }
+    return children, aux_data
+
+def _unflatten_static_cache(children, aux_data):
+    instance = StaticCache.__new__(StaticCache)
+    layers, = children
+    instance.layers = layers
+
+    for key, value in aux_data.items():
+        setattr(instance, key, value)
+
+    return instance
+
+def _flatten_with_keys_static_cache(cache: StaticCache):
+    return torch.utils._pytree._dict_flatten_with_keys(cache.__dict__)
+
+def _flatten_static_cache_for_fx(cache, spec):
+    return torch.fx._pytree._dict_flatten_spec(cache.__dict__, spec)
+
+def register_static_cache():
+    try:
+        torch.utils._pytree.register_pytree_node(
+            StaticCache,
+            _flatten_static_cache,
+            _unflatten_static_cache,
+            serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
+            flatten_with_keys_fn=_flatten_with_keys_static_cache,
+        )
+        torch.fx._pytree.register_pytree_flatten_spec(
+            StaticCache, _flatten_static_cache_for_fx
+        )
+    except ValueError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(f"StaticCache is already registered as pytree flattenable. {e}")
+
+# def _flatten_static_layer(cache: StaticLayer):
+#     nodes = {
+#         "keys": cache.keys,
+#         "values": cache.values,
+#     }
+#     return torch.utils._pytree._dict_flatten(nodes)
+
+# def _unflatten_static_layer(values, context: torch.utils._pytree.Context):
+#     data = torch.utils._pytree._dict_unflatten(values, context)
+
+#     instance = StaticLayer.__new__(StaticLayer)
+#     for k, v in data.items():
+#         setattr(instance, k, v)
+
+#     return instance
+from typing import Tuple, Any, Dict
+def _flatten_static_layer(cache) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Split the object into (mutable tensors, static metadata).
+    """
+    # 1. Children: tensor data that must be traced during model optimization or autograd
+    children = (cache.keys, cache.values)
+
+    # 2. Aux data: static information needed to reconstruct the object (must be hashable)
+    aux_data = {
+        "max_cache_len": cache.max_cache_len,
+        "is_initialized": cache.is_initialized,
+    }
+    # If the layer is already initialized, extra metadata (dtype, device, etc.) can be stored as well.
+    if cache.is_initialized:
+        aux_data.update({
+            "dtype": cache.keys.dtype,
+            "device": cache.keys.device,
+            "max_batch_size": cache.max_batch_size,
+            "num_heads": cache.num_heads,
+            "k_head_dim": cache.k_head_dim,
+            "v_head_dim": cache.v_head_dim,
+        })
+
+    return children, aux_data
+
+def _unflatten_static_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any]) -> "StaticLayer":
+    """
+    Restore a new object from the flattened data.
+    """
+    keys, values = children
+    # 1. Create a new instance
+    obj = StaticLayer(max_cache_len=aux_data["max_cache_len"])
+
+    # 2. Restore state
+    obj.is_initialized = aux_data["is_initialized"]
+    obj.keys = keys
+    obj.values = values
+
+    # 3. If it was initialized, restore the remaining attributes as well
+    if obj.is_initialized:
+        obj.dtype = aux_data["dtype"]
+        obj.device = aux_data["device"]
+        obj.max_batch_size = aux_data["max_batch_size"]
+        obj.num_heads = aux_data["num_heads"]
+        obj.k_head_dim = aux_data["k_head_dim"]
+        obj.v_head_dim = aux_data["v_head_dim"]
+
+    return obj
+
+def _flatten_with_keys_static_layer(cache: StaticLayer):
+    return torch.utils._pytree._dict_flatten_with_keys(cache.__dict__)
+
+def _flatten_static_cache_layer(cache, spec):
+    return torch.fx._pytree._dict_flatten_spec(cache.__dict__, spec)
+
+def register_static_layer():
+    try:
+        torch.utils._pytree.register_pytree_node(
+            StaticLayer,
+            _flatten_static_layer,
+            _unflatten_static_layer,
+            serialized_type_name=f"{StaticLayer.__module__}.{StaticLayer.__name__}",
+            flatten_with_keys_fn=_flatten_with_keys_static_layer,
+        )
+        torch.fx._pytree.register_pytree_flatten_spec(
+            StaticLayer, _flatten_static_cache_layer
+        )
+    except ValueError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(f"StaticLayer is already registered as pytree flattenable. {e}")
+
+def _flatten_dynamic_cache(cache: DynamicCache):
+    return torch.utils._pytree._dict_flatten(cache.__dict__)
+
+def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
+    data = torch.utils._pytree._dict_unflatten(values, context)
+
+    instance = DynamicCache.__new__(DynamicCache)
+
+    instance.__dict__.update(data)
+    return instance
+
+def _flatten_with_keys_dynamic_cache(cache: DynamicCache):
+    return torch.utils._pytree._dict_flatten_with_keys(cache.__dict__)
+
+def _flatten_dynamic_cache_for_fx(cache, spec):
+    return torch.fx._pytree._dict_flatten_spec(cache.__dict__, spec)
+
+def register_dynamic_cache():
+    try:
+        torch.utils._pytree.register_pytree_node(
+            DynamicCache,
+            _flatten_dynamic_cache,
+            _unflatten_dynamic_cache,
+            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+            flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
+        )
+        torch.fx._pytree.register_pytree_flatten_spec(
+            DynamicCache, _flatten_dynamic_cache_for_fx
+        )
+    except ValueError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(f"DynamicCache is already registered as pytree flattenable. {e}")
+
+def _flatten_encoder_decoder_cache(cache: EncoderDecoderCache):
+    # EncoderDecoderCache structurally holds a self-attention cache and a cross-attention cache.
+ dictionary = { + "self_attention_cache": cache.self_attention_cache, + "cross_attention_cache": cache.cross_attention_cache, + } + return torch.utils._pytree._dict_flatten(dictionary) + +def _unflatten_encoder_decoder_cache(values, context: torch.utils._pytree.Context): + dictionary = torch.utils._pytree._dict_unflatten(values, context) + # __init__(*caches) 시그니처에 맞춰 복원 + return EncoderDecoderCache( + dictionary["self_attention_cache"], + dictionary["cross_attention_cache"] + ) + +def _flatten_with_keys_encoder_decoder_cache(cache: EncoderDecoderCache): + dictionary = { + "self_attention_cache": cache.self_attention_cache, + "cross_attention_cache": cache.cross_attention_cache, + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) + +def _flatten_encoder_decoder_cache_for_fx(cache, spec): + dictionary = { + "self_attention_cache": cache.self_attention_cache, + "cross_attention_cache": cache.cross_attention_cache, + } + return torch.fx._pytree._dict_flatten_spec(dictionary, spec) + +def register_encoder_decoder_cache(): + try: + torch.utils._pytree.register_pytree_node( + EncoderDecoderCache, + _flatten_encoder_decoder_cache, + _unflatten_encoder_decoder_cache, + serialized_type_name=f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_encoder_decoder_cache, + ) + torch.fx._pytree.register_pytree_flatten_spec( + EncoderDecoderCache, _flatten_encoder_decoder_cache_for_fx + ) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.warning(f"EncoderDecoderCache is already registered as pytree flattenable. {e}") From 1e3fa382c00b9bdfb9ca1c143bc999c023840a5e Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 5 Feb 2026 19:09:59 +0900 Subject: [PATCH 2/3] temp --- tico/utils/pytree_utils.py | 34 ++-------------------------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/tico/utils/pytree_utils.py b/tico/utils/pytree_utils.py index 78f919c0..6e1d1431 100644 --- a/tico/utils/pytree_utils.py +++ b/tico/utils/pytree_utils.py @@ -144,21 +144,6 @@ def register_static_cache(): logger = logging.getLogger(__name__) logger.warning(f"StaticCache is already registered as pytree flattenable. {e}") -# def _flatten_static_layer(cache: StaticLayer): -# nodes = { -# "keys": cache.keys, -# "values": cache.values, -# } -# return torch.utils._pytree._dict_flatten(nodes) - -# def _unflatten_static_layer(values, context: torch.utils._pytree.Context): -# data = torch.utils._pytree._dict_unflatten(values, context) - -# instance = StaticLayer.__new__(StaticLayer) -# for k, v in data.items(): -# setattr(instance, k, v) - -# return instance from typing import Tuple, Any, Dict def _flatten_static_layer(layer) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """ @@ -191,15 +176,12 @@ def _unflatten_static_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any]) flatten된 데이터로부터 새로운 객체를 복구합니다. """ keys, values = children - # 1. 새 인스턴스 생성 obj = StaticLayer(max_cache_len=aux_data["max_cache_len"]) - # 2. 상태 복구 obj.is_initialized = aux_data["is_initialized"] obj.keys = keys obj.values = values - # 3. 초기화되었던 상태라면 나머지 속성들도 복구 if obj.is_initialized: obj.dtype = aux_data["dtype"] obj.device = aux_data["device"] @@ -233,7 +215,6 @@ def register_static_layer(): logger.warning(f"StaticLayer is already registered as pytree flattenable. 
{e}") - from typing import Tuple, Any, Dict def _flatten_dynamic_layer(layer) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: if not layer.is_initialized: @@ -271,17 +252,13 @@ def _unflatten_dynamic_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any] import torch.utils._pytree as pytree def _flatten_with_keys_dynamic_layer(layer: DynamicLayer): - breakpoint() children = [ (pytree.MappingKeyPath("keys"), layer.keys), (pytree.MappingKeyPath("values"), layer.values), ] - # 2. 텐서가 아닌 고정 정보(metadata) aux_data = { "is_initialized": layer.is_initialized, - "dtype": layer.keys.dtype, - "device": layer.keys.device, } return children, aux_data @@ -289,25 +266,18 @@ def _flatten_with_keys_dynamic_layer(layer: DynamicLayer): import torch.fx._pytree as fx_pytree def _flatten_with_keys_dynamic_layer(layer: DynamicLayer, spec): - # DynamicLayer의 핵심 데이터를 딕셔너리 형태로 추출 - breakpoint() - # spec에 정의된 필드들과 일치해야 합니다. layer_dict = { "keys": layer.keys, "values": layer.values, "is_initialized": layer.is_initialized, - "dtype": layer.keys.dtype, - "device": layer.keys.device, } - # FX의 dict_flatten_spec을 사용하여 spec 구조에 맞게 flatten return fx_pytree._dict_flatten_spec(layer_dict, spec) -# def _flatten_with_keys_dynamic_layer(layer: DynamicLayer): -# return torch.utils._pytree._dict_flatten_with_keys(layer.__dict__) +def _flatten_with_keys_dynamic_layer(layer: DynamicLayer): + return torch.utils._pytree._dict_flatten_with_keys(layer.__dict__) def _flatten_dynamic_layer_for_fx(layer, spec): - breakpoint() return torch.fx._pytree._dict_flatten_spec(layer.__dict__, spec) From 9cade01c62d641e903af5465d4f436f6cd9aa229 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Fri, 6 Feb 2026 16:40:11 +0900 Subject: [PATCH 3/3] fix DynamicCache related codes --- tico/utils/pytree_utils.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/tico/utils/pytree_utils.py b/tico/utils/pytree_utils.py index 6e1d1431..2c3634a1 100644 --- a/tico/utils/pytree_utils.py +++ b/tico/utils/pytree_utils.py @@ -298,17 +298,38 @@ def register_dynamic_layer(): logger.warning(f"DynamicLayer is already registered as pytree flattenable. {e}") -def _flatten_dynamic_cache(cache: DynamicCache): - return torch.utils._pytree._dict_flatten(cache.__dict__) -def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context): - data = torch.utils._pytree._dict_unflatten(values, context) - +def _flatten_dynamic_cache(cache): + children = (cache.layers,) + aux_data = { + "layer_class_to_replicate": getattr(cache, "layer_class_to_replicate", None), + "offloading": getattr(cache, "offloading", False), + } + return children, aux_data + +def _unflatten_dynamic_cache(children, aux_data): instance = DynamicCache.__new__(DynamicCache) + layers, = children + instance.layers = layers - instance.__dict__.update(data) + for key, value in aux_data.items(): + setattr(instance, key, value) + return instance + + +# def _flatten_dynamic_cache(cache: DynamicCache): +# return torch.utils._pytree._dict_flatten(cache.__dict__) + +# def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context): +# data = torch.utils._pytree._dict_unflatten(values, context) + +# instance = DynamicCache.__new__(DynamicCache) + +# instance.__dict__.update(data) +# return instance + def _flatten_with_keys_dynamic_cache(cache: DynamicCache): return torch.utils._pytree._dict_flatten_with_keys(cache.__dict__)