It would be nice if we didn't need to hand-write code for example inputs. Today, for example, we need something like this:
`TICO/test/modules/model/LlamaWithKVCache/model.py` (lines 36 to 70 at `a3ed23e`):
```python
def get_example_inputs(self):
    # One decode step: a single new token after a 511-token prefix.
    past_seq_len = 511
    cur_seq_len = 1
    input_ids = torch.tensor([[812]]).to(torch.long)
    attention_mask = torch.ones(1, past_seq_len + cur_seq_len)
    position_ids = torch.tensor([[past_seq_len]]).to(torch.long)
    # Pre-fill the KV cache with random keys/values for every layer.
    past_key_values = DynamicCache()
    for layer_id in range(self.config.num_hidden_layers):
        past_key_values.update(
            torch.randn(
                [
                    1,
                    self.config.num_attention_heads,
                    past_seq_len,
                    self.config.head_dim,
                ]
            ),
            torch.randn(
                [
                    1,
                    self.config.num_attention_heads,
                    past_seq_len,
                    self.config.head_dim,
                ]
            ),
            layer_id,
        )
    return (
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
    )
```
To write a correctly working example input, we need to understand a lot about the target model.
So instead I tried to capture the example inputs from an ordinary user-level call. I succeeded in writing a working version, tested with Maykeye/TinyLLama-v0.
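Roughly, the idea is to run an ordinary `generate()` call and record what actually reaches `forward()`, instead of constructing the tensors by hand. Below is a minimal sketch of the kind of thing I mean, built on a forward pre-hook; the helper name, prompt, and `max_new_tokens` value are illustrative, not the actual implementation:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "Maykeye/TinyLLama-v0"

def capture_decode_step_inputs(model, input_ids, max_new_tokens=4):
    """Record the (args, kwargs) of every forward() call made inside
    generate() and return the last one: a decode step whose
    past_key_values (DynamicCache) is already populated."""
    calls = []

    def hook(module, args, kwargs):
        calls.append((args, kwargs))

    handle = model.register_forward_pre_hook(hook, with_kwargs=True)
    try:
        model.generate(input_ids, max_new_tokens=max_new_tokens)
    finally:
        handle.remove()
    return calls[-1]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).eval()
input_ids = tokenizer("Hello", return_tensors="pt").input_ids
args, kwargs = capture_decode_step_inputs(model, input_ids)
# kwargs now holds whatever the model actually received (input_ids,
# attention_mask, past_key_values, ...) with realistic shapes and a
# real DynamicCache, so it can be reused as example inputs.
```

The captured pair can then be handed to `torch.export` (or whichever conversion entry point consumes the example inputs) in place of a hand-written `get_example_inputs()`.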
I confirmed it works with the transformers versions below:
- 4.49.0 ❌ DynamicCache is not pytree-flattenable
- 4.51.3 ⭕
- 4.52.4 ⭕
The good news is that DynamicCache has been registered as pytree-flattenable since transformers 4.50.0 (though there seems to be a bug on macOS, so any version after 4.50.1 should work).
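For reference, "pytree-flattenable" is easy to probe directly. A small check follows; it uses PyTorch's private `torch.utils._pytree` helper, so treat it as a debugging aid rather than a stable API:

```python
import torch
import torch.utils._pytree as pytree
from transformers import DynamicCache

# Populate one layer of the cache so there is something to flatten.
cache = DynamicCache()
cache.update(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4), 0)

leaves, _ = pytree.tree_flatten(cache)
# transformers >= 4.50: the cache flattens into its key/value tensors.
# transformers <  4.50: the whole cache is a single opaque leaf, which
# is what torch.export trips over.
print(len(leaves), [type(x).__name__ for x in leaves])
```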
I hope, and expect, that it will work for other models just by changing the model name and the user inputs.