From 29d7e0f3ba144868ca6e4f782079b6770f7b26ca Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 10 Jul 2025 10:37:39 +0900
Subject: [PATCH 1/9] [test] Use input captured from user input

It captures the target model input from user input (e.g. a prompt).
It helps us to prepare example inputs at a high level.

TICO-DCO-1.0-Signed-off-by: Sanggyu Lee
---
 .../modules/model/LlamaCapturedInput/model.py | 83 +++++++++++++++++++
 .../model/LlamaCapturedInput/requirements.txt |  1 +
 2 files changed, 84 insertions(+)
 create mode 100644 test/modules/model/LlamaCapturedInput/model.py
 create mode 100644 test/modules/model/LlamaCapturedInput/requirements.txt

diff --git a/test/modules/model/LlamaCapturedInput/model.py b/test/modules/model/LlamaCapturedInput/model.py
new file mode 100644
index 00000000..a49b011b
--- /dev/null
+++ b/test/modules/model/LlamaCapturedInput/model.py
@@ -0,0 +1,83 @@
+# User input
+prompt = "Lily picked up a flower."
+model_name = "Maykeye/TinyLLama-v0"
+
+# Capturer
+captured_inputs = {}
+
+import inspect, copy ,types
+from transformers import LlamaForCausalLM
+original_forward = LlamaForCausalLM.forward
+
+def patched_forward(self, *args, **kwargs):
+    global captured_inputs
+
+    # Prepare args tuple for TICO.convert()
+    # Get arg_names in positional args order using inspect
+    sig = inspect.signature(original_forward)
+    args_names = [
+        # signature includes `self`` and `kwargs``.
+        # Just retrieve the ordinary positional inputs only
+        name for name in sig.parameters.keys() if name not in ("self", "kwargs")
+    ]
+
+    args_dict = dict(zip(args_names, args))
+    args_dict.update(kwargs)
+
+    def populate_args(args_dict, filter):
+        for key in filter:
+            args_dict.pop(key, None)
+        args_tuple = tuple(args_dict.get(name, None) for name in args_names)
+        return copy.deepcopy(args_tuple)
+
+    if captured_inputs.get("prefill", None) == None:
+        input_to_remove = [
+            "past_key_values",
+            "use_cache",
+            "attention_mask",
+            "cache_position",
+        ]
+        captured_inputs["prefill"] = populate_args(args_dict, input_to_remove)
+    elif captured_inputs.get("decode", None) == None:
+        input_to_remove = [ "use_cache" ]
+        captured_inputs["decode"] = populate_args(args_dict, input_to_remove)
+
+    return original_forward(self, *args, **kwargs)
+
+
+# Tokenizer
+from transformers import AutoTokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "right"
+inputs = tokenizer(
+    prompt,
+    return_tensors="pt",
+    padding="max_length",
+    max_length=32,
+    truncation=True,
+)
+
+# Generator
+from transformers import AutoModelForCausalLM
+import torch
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+model.forward = types.MethodType(patched_forward, model)
+with torch.no_grad():
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=32,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id
+    )
+generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+print(generated_text)
+
+# Tico
+import tico
+for key in captured_inputs.keys():
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    model.eval()
+    circle_model = tico.convert(model, captured_inputs[key])
+    circle_model.save(f'{key}.circle')
diff --git a/test/modules/model/LlamaCapturedInput/requirements.txt b/test/modules/model/LlamaCapturedInput/requirements.txt
new file mode 100644
index 00000000..5393938f
--- /dev/null
+++ b/test/modules/model/LlamaCapturedInput/requirements.txt
@@ -0,0 +1 @@
+transformers>=4.50.1

From 87b55ebe2b7a83532763f9bdcdf4001153510fc0 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 10 Jul 2025 11:07:30 +0900
Subject: [PATCH 2/9] Make formatter happy

---
 test/modules/model/LlamaCapturedInput/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/modules/model/LlamaCapturedInput/model.py b/test/modules/model/LlamaCapturedInput/model.py
index a49b011b..fb6bed72 100644
--- a/test/modules/model/LlamaCapturedInput/model.py
+++ b/test/modules/model/LlamaCapturedInput/model.py
@@ -39,7 +39,7 @@ def populate_args(args_dict, filter):
         ]
         captured_inputs["prefill"] = populate_args(args_dict, input_to_remove)
     elif captured_inputs.get("decode", None) == None:
-        input_to_remove = [ "use_cache" ]
+        input_to_remove = ["use_cache"]
         captured_inputs["decode"] = populate_args(args_dict, input_to_remove)
 
     return original_forward(self, *args, **kwargs)
@@ -80,4 +80,4 @@ def populate_args(args_dict, filter):
     model = AutoModelForCausalLM.from_pretrained(model_name)
     model.eval()
     circle_model = tico.convert(model, captured_inputs[key])
-    circle_model.save(f'{key}.circle')
+    circle_model.save(f"{key}.circle")

From 98c0fafa1d47990984db765e64a939bfcbd43d28 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 10 Jul 2025 13:41:23 +0900
Subject: [PATCH 3/9] Make pylint happy though it makes it worse

TICO-DCO-1.0-Signed-off-by: Sanggyu Lee
---
 test/modules/model/LlamaCapturedInput/model.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/test/modules/model/LlamaCapturedInput/model.py b/test/modules/model/LlamaCapturedInput/model.py
index fb6bed72..c83a0fdd 100644
--- a/test/modules/model/LlamaCapturedInput/model.py
+++ b/test/modules/model/LlamaCapturedInput/model.py
@@ -2,13 +2,18 @@
 prompt = "Lily picked up a flower."
 model_name = "Maykeye/TinyLLama-v0"
 
+import torch
+
 # Capturer
 captured_inputs = {}
 
-import inspect, copy ,types
+import copy, inspect, types
+
 from transformers import LlamaForCausalLM
+
 original_forward = LlamaForCausalLM.forward
 
+
 def patched_forward(self, *args, **kwargs):
     global captured_inputs
 
@@ -18,7 +23,9 @@ def patched_forward(self, *args, **kwargs):
     args_names = [
         # signature includes `self`` and `kwargs``.
         # Just retrieve the ordinary positional inputs only
-        name for name in sig.parameters.keys() if name not in ("self", "kwargs")
+        name
+        for name in sig.parameters.keys()
+        if name not in ("self", "kwargs")
     ]
 
     args_dict = dict(zip(args_names, args))
@@ -47,6 +54,7 @@ def populate_args(args_dict, filter):
 
 # Tokenizer
 from transformers import AutoTokenizer
+
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.padding_side = "right"
@@ -58,9 +66,10 @@ def populate_args(args_dict, filter):
     truncation=True,
 )
 
+
 # Generator
 from transformers import AutoModelForCausalLM
-import torch
+
 model = AutoModelForCausalLM.from_pretrained(model_name)
 model.eval()
 model.forward = types.MethodType(patched_forward, model)
@@ -69,13 +78,14 @@ def populate_args(args_dict, filter):
         **inputs,
         max_new_tokens=32,
         do_sample=False,
-        pad_token_id=tokenizer.eos_token_id
+        pad_token_id=tokenizer.eos_token_id,
     )
 generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 print(generated_text)
 
 # Tico
 import tico
+
 for key in captured_inputs.keys():
     model = AutoModelForCausalLM.from_pretrained(model_name)
     model.eval()

From 7995511c5d90fab081eb655652cd56d20a1875c2 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 10 Jul 2025 16:03:57 +0900
Subject: [PATCH 4/9] Rename LlamaCapturedInput to LlamaPrefillDecode

---
 .../model/{LlamaCapturedInput => LlamaPrefillDecode}/model.py   | 0
 .../{LlamaCapturedInput => LlamaPrefillDecode}/requirements.txt | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename test/modules/model/{LlamaCapturedInput => LlamaPrefillDecode}/model.py (100%)
 rename test/modules/model/{LlamaCapturedInput => LlamaPrefillDecode}/requirements.txt (100%)

diff --git a/test/modules/model/LlamaCapturedInput/model.py b/test/modules/model/LlamaPrefillDecode/model.py
similarity index 100%
rename from test/modules/model/LlamaCapturedInput/model.py
rename to test/modules/model/LlamaPrefillDecode/model.py
diff --git a/test/modules/model/LlamaCapturedInput/requirements.txt b/test/modules/model/LlamaPrefillDecode/requirements.txt
similarity index 100%
rename from test/modules/model/LlamaCapturedInput/requirements.txt
rename to test/modules/model/LlamaPrefillDecode/requirements.txt

From b3ed5af3ee32520e56c99986e954a88cc54257ba Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 10 Jul 2025 17:31:21 +0900
Subject: [PATCH 5/9] Suppress lint error about type annotation

---
 test/modules/model/LlamaPrefillDecode/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/modules/model/LlamaPrefillDecode/model.py b/test/modules/model/LlamaPrefillDecode/model.py
index c83a0fdd..b747bf9a 100644
--- a/test/modules/model/LlamaPrefillDecode/model.py
+++ b/test/modules/model/LlamaPrefillDecode/model.py
@@ -5,7 +5,7 @@
 import torch
 
 # Capturer
-captured_inputs = {}
+captured_inputs = {}  # type: ignore[var-annotated]
 
 import copy, inspect, types
 

From 8708c72a1a2a42974d294768ed6030d8beafbe0d Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 17 Jul 2025 10:16:13 +0900
Subject: [PATCH 6/9] Separate prefill and decode

---
 .../model.py                                  | 35 ++++++++-----------
 .../requirements.txt                          |  0
 2 files changed, 15 insertions(+), 20 deletions(-)
 rename test/modules/model/{LlamaPrefillDecode => LlamaPrefill}/model.py (66%)
 rename test/modules/model/{LlamaPrefillDecode => LlamaPrefill}/requirements.txt (100%)

diff --git a/test/modules/model/LlamaPrefillDecode/model.py b/test/modules/model/LlamaPrefill/model.py
similarity index 66%
rename from test/modules/model/LlamaPrefillDecode/model.py
rename to test/modules/model/LlamaPrefill/model.py
index b747bf9a..84993556 100644
--- a/test/modules/model/LlamaPrefillDecode/model.py
+++ b/test/modules/model/LlamaPrefill/model.py
@@ -2,24 +2,21 @@
 prompt = "Lily picked up a flower."
 model_name = "Maykeye/TinyLLama-v0"
 
-import torch
-
-# Capturer
-captured_inputs = {}  # type: ignore[var-annotated]
+captured_input = None  # type: ignore[var-annotated]
 
 import copy, inspect, types
 
 from transformers import LlamaForCausalLM
 
-original_forward = LlamaForCausalLM.forward
+forward_old = LlamaForCausalLM.forward
 
 
-def patched_forward(self, *args, **kwargs):
-    global captured_inputs
+def capture_and_forward(self, *args, **kwargs):
+    global captured_input
 
     # Prepare args tuple for TICO.convert()
     # Get arg_names in positional args order using inspect
-    sig = inspect.signature(original_forward)
+    sig = inspect.signature(forward_old)
     args_names = [
         # signature includes `self`` and `kwargs``.
         # Just retrieve the ordinary positional inputs only
@@ -37,19 +34,16 @@ def populate_args(args_dict, filter):
         args_tuple = tuple(args_dict.get(name, None) for name in args_names)
         return copy.deepcopy(args_tuple)
 
-    if captured_inputs.get("prefill", None) == None:
+    if len(args_dict["past_key_values"].key_cache) == 0:
         input_to_remove = [
             "past_key_values",
             "use_cache",
             "attention_mask",
             "cache_position",
         ]
-        captured_inputs["prefill"] = populate_args(args_dict, input_to_remove)
-    elif captured_inputs.get("decode", None) == None:
-        input_to_remove = ["use_cache"]
-        captured_inputs["decode"] = populate_args(args_dict, input_to_remove)
+        captured_input = populate_args(args_dict, input_to_remove)
 
-    return original_forward(self, *args, **kwargs)
+    return forward_old(self, *args, **kwargs)
 
 
 # Tokenizer
@@ -68,11 +62,13 @@ def populate_args(args_dict, filter):
 
 
 # Generator
+import torch
+
 from transformers import AutoModelForCausalLM
 
 model = AutoModelForCausalLM.from_pretrained(model_name)
 model.eval()
-model.forward = types.MethodType(patched_forward, model)
+model.forward = types.MethodType(capture_and_forward, model)
 with torch.no_grad():
     outputs = model.generate(
         **inputs,
@@ -86,8 +82,7 @@ def populate_args(args_dict, filter):
 # Tico
 import tico
 
-for key in captured_inputs.keys():
-    model = AutoModelForCausalLM.from_pretrained(model_name)
-    model.eval()
-    circle_model = tico.convert(model, captured_inputs[key])
-    circle_model.save(f"{key}.circle")
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+circle_model = tico.convert(model, captured_input)
+circle_model.save(f"llama.prefill.circle")
diff --git a/test/modules/model/LlamaPrefillDecode/requirements.txt b/test/modules/model/LlamaPrefill/requirements.txt
similarity index 100%
rename from test/modules/model/LlamaPrefillDecode/requirements.txt
rename to test/modules/model/LlamaPrefill/requirements.txt

From 0ccd7a068737eb9b92454b668d972138cbdd9e97 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 17 Jul 2025 11:57:51 +0900
Subject: [PATCH 7/9] Rename and make lint happy

---
 test/modules/model/LlamaPrefill/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/modules/model/LlamaPrefill/model.py b/test/modules/model/LlamaPrefill/model.py
index 84993556..570472f9 100644
--- a/test/modules/model/LlamaPrefill/model.py
+++ b/test/modules/model/LlamaPrefill/model.py
@@ -2,13 +2,13 @@
 prompt = "Lily picked up a flower."
 model_name = "Maykeye/TinyLLama-v0"
 
-captured_input = None  # type: ignore[var-annotated]
+captured_input = ()
 
 import copy, inspect, types
 
 from transformers import LlamaForCausalLM
 
-forward_old = LlamaForCausalLM.forward
+forward_org = LlamaForCausalLM.forward
 
 
 def capture_and_forward(self, *args, **kwargs):
@@ -16,7 +16,7 @@ def capture_and_forward(self, *args, **kwargs):
 
     # Prepare args tuple for TICO.convert()
     # Get arg_names in positional args order using inspect
-    sig = inspect.signature(forward_old)
+    sig = inspect.signature(forward_org)
     args_names = [
         # signature includes `self`` and `kwargs``.
         # Just retrieve the ordinary positional inputs only
@@ -43,7 +43,7 @@ def populate_args(args_dict, filter):
         ]
         captured_input = populate_args(args_dict, input_to_remove)
 
-    return forward_old(self, *args, **kwargs)
+    return forward_org(self, *args, **kwargs)
 
 
 # Tokenizer

From 037940f14e19539097584e4ca83a265001f09c96 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 17 Jul 2025 12:09:33 +0900
Subject: [PATCH 8/9] Use Cache.get_seq_length() instead of len(Cache.key_cache)

---
 test/modules/model/LlamaPrefill/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/modules/model/LlamaPrefill/model.py b/test/modules/model/LlamaPrefill/model.py
index 570472f9..2228a06e 100644
--- a/test/modules/model/LlamaPrefill/model.py
+++ b/test/modules/model/LlamaPrefill/model.py
@@ -34,7 +34,7 @@ def populate_args(args_dict, filter):
         args_tuple = tuple(args_dict.get(name, None) for name in args_names)
         return copy.deepcopy(args_tuple)
 
-    if len(args_dict["past_key_values"].key_cache) == 0:
+    if args_dict["past_key_values"].get_seq_length() == 0:
         input_to_remove = [
             "past_key_values",
             "use_cache",

From 4f55be8eb846663671b70adee44e997c9b310a11 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 17 Jul 2025 13:56:05 +0900
Subject: [PATCH 9/9] Rename output file

---
 test/modules/model/LlamaPrefill/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/modules/model/LlamaPrefill/model.py b/test/modules/model/LlamaPrefill/model.py
index 2228a06e..cae392ca 100644
--- a/test/modules/model/LlamaPrefill/model.py
+++ b/test/modules/model/LlamaPrefill/model.py
@@ -85,4 +85,4 @@ def populate_args(args_dict, filter):
 model = AutoModelForCausalLM.from_pretrained(model_name)
 model.eval()
 circle_model = tico.convert(model, captured_input)
-circle_model.save(f"llama.prefill.circle")
+circle_model.save(f"tinyllama.prefill.circle")