I love this exploration! Thanks for writing and coding this up. Right now, we're working on modifications to the causal conv1d and selective scan CUDA kernels to support defining the input state, so we are reviewing your code carefully.
What is the objective of the exponential fall-off in the cache clearing in train-infinite.py?
```python
if completed_steps % clear_cache_interval == 0:
    for layer_idx in range(model.config.n_layer):
        conv_state = torch.zeros((1, model.config.d_model * 2, 3), dtype=torch.bfloat16, device=accelerator.device).detach()
        ssm_state = torch.zeros((1, model.config.d_model * 2, 16), dtype=torch.bfloat16, device=accelerator.device).detach()
        previous_hidden_states.append((conv_state, ssm_state))
    clear_cache_interval *= 2
```

Also, a general question: do you have a feeling for why your current implementation isn't working? Might vanishing gradients be an issue when running over longer sequences? I noticed that you're using bf16; I found this caused instability, and using AMP for higher precision seemed to help.
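For reference, here is a minimal sketch of what I mean by using AMP instead of casting everything to bf16: the forward pass runs under `torch.autocast` in bf16 while the parameters and optimizer state stay in fp32. The model and optimizer below are stand-ins, not the actual Mamba training code.

```python
import torch

# Illustrative stand-ins for the real model/optimizer (assumptions, not
# the original train-infinite.py setup).
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters())

x = torch.randn(4, 16)
# Autocast runs the forward pass in bf16 while parameters remain fp32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(x).pow(2).mean()

loss.backward()        # gradients flow back into the fp32 parameters
optimizer.step()
optimizer.zero_grad()
```

On GPU you would use `device_type="cuda"` instead; the point is that the master weights keep full precision, which is what seemed to reduce the instability I saw with pure bf16.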