From 9e0d45db8de2f730d82b532cf042af663ff3ffa7 Mon Sep 17 00:00:00 2001 From: Saarthak Kapse <44542181+saarthak-kapse@users.noreply.github.com> Date: Fri, 7 Feb 2025 21:07:01 -0500 Subject: [PATCH] Update mamba_simple.py --- mamba-1p1p1/mamba_ssm/modules/mamba_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py b/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py index ff1536c..caf1c30 100644 --- a/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py +++ b/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py @@ -397,7 +397,7 @@ def forward(self, hidden_states, inference_params=None): ) else: out = F.linear( - rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, + (rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2) * F.silu(rearrange(z, "b d l -> b l d")), self.out_proj.weight, self.out_proj.bias, )