
Auto example input generation #207

@glistening

Description


It would be nice if we didn't need to write code for example inputs by hand.

Today, for example, we need to write something like:

def get_example_inputs(self):
    # Requires: import torch, and DynamicCache from transformers.
    past_seq_len = 511
    cur_seq_len = 1
    input_ids = torch.tensor([[812]]).to(torch.long)
    attention_mask = torch.ones(1, past_seq_len + cur_seq_len)
    position_ids = torch.tensor([[past_seq_len]]).to(torch.long)
    past_key_values = DynamicCache()
    # Pre-fill the cache with random key/value states for every layer.
    for layer_id in range(self.config.num_hidden_layers):
        past_key_values.update(
            torch.randn(
                [
                    1,
                    self.config.num_attention_heads,
                    past_seq_len,
                    self.config.head_dim,
                ]
            ),
            torch.randn(
                [
                    1,
                    self.config.num_attention_heads,
                    past_seq_len,
                    self.config.head_dim,
                ]
            ),
            layer_id,
        )
    return (
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
    )
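
For context, a minimal sketch of how a tuple like this is typically consumed, assuming a torch.export-style consumer (the wrapper module below is hypothetical and only illustrates the call shape; the actual entry point in this project may differ):

import torch

# MyDecoderWrapper is a hypothetical nn.Module whose forward() takes
# (input_ids, attention_mask, position_ids, past_key_values) and which
# defines get_example_inputs() as above.
model = MyDecoderWrapper(config)
example_inputs = model.get_example_inputs()
exported_program = torch.export.export(model, example_inputs)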

To write correctly working example inputs, we need to understand a lot about the target model.

Instead, I tried to capture the example inputs from ordinary user-level inputs.

I succeeded in writing a working version, tested with Maykeye/TinyLLama-v0.
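
One possible way to capture them (a sketch only, not necessarily how the actual implementation does it) is to run the model once from user-level inputs and record the arguments of the first forward call with a forward pre-hook:

# Minimal sketch: capture example inputs from a normal user-level run.
# Model name taken from the issue; the hook-based approach is an assumption.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Maykeye/TinyLLama-v0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

captured = {}

def _capture(module, args, kwargs):
    # Record the arguments of the first forward call only.
    if not captured:
        captured["args"] = args
        captured["kwargs"] = kwargs
    return None

handle = model.register_forward_pre_hook(_capture, with_kwargs=True)
user_inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    model.generate(**user_inputs, max_new_tokens=2)
handle.remove()

# Typically includes input_ids, attention_mask, past_key_values, ...
print(captured["kwargs"].keys())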

I confirmed that it works with the transformers versions below:

  • 4.49.0 ❌ DynamicCache is not pytree-flattenable
  • 4.51.3 ⭕
  • 4.52.4 ⭕

The good news is that transformers' DynamicCache has been pytree-flattenable since 4.50.0 (though there seems to be a bug on macOS, so any version from 4.50.1 onward should work).
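
A quick way to check whether the installed transformers registers DynamicCache with PyTorch's pytree (note that torch.utils._pytree is a private module, so this is only a diagnostic sketch):

# Quick check: is DynamicCache pytree-flattenable with the installed versions?
import torch
from torch.utils._pytree import tree_flatten
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), 0)

leaves, spec = tree_flatten(cache)
# If DynamicCache is registered as a pytree node, the leaves are the cached
# key/value tensors; otherwise the whole cache comes back as a single leaf.
flattenable = not (len(leaves) == 1 and leaves[0] is cache)
print("DynamicCache pytree-flattenable:", flattenable)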

I hope and expect it will work for other models as well, just by changing the model name and the user-level inputs.
