diff --git a/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py b/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py index ff1536c..caf1c30 100644 --- a/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py +++ b/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py @@ -397,7 +397,7 @@ def forward(self, hidden_states, inference_params=None): ) else: out = F.linear( - rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, + (rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2) * F.silu(rearrange(z, "b d l -> b l d")), self.out_proj.weight, self.out_proj.bias, )