
cache clearing interval for previous hidden states #2

@ekg

Description


I love this exploration! Thanks for writing and coding this up. Right now, we're working on modifications to the causal conv1d and selective scan CUDA kernels to support defining the input state, so we are reviewing your code carefully.

What is the objective of the exponential fall-off in the cache clearing in train-infinite.py?

        if completed_steps % clear_cache_interval == 0:
            for layer_idx in range(model.config.n_layer):
                conv_state = torch.zeros((1, model.config.d_model*2, 3), dtype=torch.bfloat16, device=accelerator.device).detach()
                ssm_state = torch.zeros((1, model.config.d_model*2, 16), dtype=torch.bfloat16, device=accelerator.device).detach()
                previous_hidden_states.append((conv_state, ssm_state))
            clear_cache_interval *= 2
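For context on what that doubling does in practice: since the reset fires whenever `completed_steps % clear_cache_interval == 0`, the resets become exponentially sparser over training. A quick simulation of just the trigger logic (assuming `completed_steps` counts from 0 and `clear_cache_interval` starts at 1 — both assumptions on my part, the script may initialize them differently):

```python
def clearing_steps(start_interval, total_steps):
    """Return the steps at which the cache would be reset under the doubling schedule."""
    interval = start_interval
    resets = []
    for step in range(total_steps):
        if step % interval == 0:
            resets.append(step)
            interval *= 2  # same doubling as in train-infinite.py
    return resets

print(clearing_steps(1, 100))  # [0, 2, 4, 8, 16, 32, 64]
```

So early in training the hidden states are wiped frequently, and later they are carried across ever-longer spans of steps — which is what prompted my question about the objective.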

Also, a general question: do you have a feeling for why your current implementation isn't working? Could vanishing gradients be an issue when running over longer sequences? I also noticed that you're using bf16 throughout; I found that this caused instability, and using AMP so that sensitive operations run in higher precision seemed to help.
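To make the AMP suggestion concrete, here is a minimal sketch of what I mean (the model and names are illustrative placeholders, not from your repo): keep the master weights in fp32 and let autocast run the forward pass in bf16, rather than casting the whole model to bf16.

```python
import torch

# Stand-in module for illustration; in your case this would be the Mamba model.
model = torch.nn.Linear(16, 16)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(4, 16)
# Forward pass under autocast: eligible ops run in bf16, while the
# parameters and optimizer state stay in fp32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(x).pow(2).mean()

loss.backward()  # gradients accumulate into the fp32 master weights
opt.step()
```

If you're using accelerate, I believe the equivalent is the `mixed_precision` setting on the `Accelerator` rather than calling `.to(torch.bfloat16)` on the model, but I haven't checked your launch config.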
