replace forget_gate * cell_gate with input_gate * cell_gate #181

Merged
ArmRyan merged 3 commits into ARM-software:main from seh2bp:bugfix/lstm_use_input_gate
Jul 17, 2025

seh2bp (Contributor) commented Jul 3, 2025

def lstm_cell(
    input: Tensor,
    hidden: tuple[Tensor, Tensor],
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> tuple[Tensor, Tensor]:
    hx, cx = hidden
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

-   cy = (forgetgate * cx) + (forgetgate * cellgate)
+   cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy
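
The `+` line matches the standard LSTM cell-state update, c_t = f_t * c_{t-1} + i_t * g_t. As a rough illustration, here is a dependency-free scalar sketch of a single step. It is not the repository's code: `lstm_cell_scalar` and its parameter names are hypothetical, and the four gate pre-activations are passed in directly instead of being computed from weights and biases.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function used for the input, forget, and output gates."""
    return 1.0 / (1.0 + math.exp(-x))


def lstm_cell_scalar(
    pre_i: float, pre_f: float, pre_g: float, pre_o: float, cx: float
) -> tuple[float, float]:
    """One LSTM step on scalars (hypothetical helper, for illustration only)."""
    ingate = sigmoid(pre_i)
    forgetgate = sigmoid(pre_f)
    cellgate = math.tanh(pre_g)
    outgate = sigmoid(pre_o)

    # The forget gate scales the old cell state; the input gate scales the candidate.
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * math.tanh(cy)
    return hy, cy


# With the input gate driven almost fully closed, the candidate contributes
# almost nothing, so cy stays close to forgetgate * cx.
hy_closed, cy_closed = lstm_cell_scalar(-50.0, 0.0, 3.0, 0.0, 2.0)
```

In this form the two sides of the diff give different results whenever `ingate` and `forgetgate` hold different values, which is why the naming matters for readers even when an implementation happens to store both in the same place.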

ArmRyan self-assigned this Jul 8, 2025

ArmRyan (Collaborator) commented Jul 8, 2025

Hi @seh2bp, thank you for the contribution!

I am not very familiar with LSTMs, so I will need to verify the change and run the patch through our CI as well.

ArmRyan (Collaborator) commented Jul 16, 2025

Hi @seh2bp, I ran this through our CI and it's all good. The only thing missing is that the version and the date need to be updated in both files; then I am good to approve it!

I am wondering, though, whether you had a problem with the function before. The forget gate and input gate both point to the same buffer, so the output should be the same. I think this is still clearer, so it is a good change, but if you did have a problem, this shouldn't fix it.
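
To illustrate the aliasing point above: when two names refer to the same buffer, swapping one name for the other cannot change the result. The following is a minimal Python sketch with hypothetical names and made-up values, not the repository's actual code.

```python
# Hypothetical scratch buffer holding one gate's activations (made-up values).
scratch = [0.25, 0.5, 0.75]

# Two names, one underlying buffer (analogous to two pointers with the same address).
forget_gate = scratch
input_gate = scratch

# Made-up candidate values for the cell gate.
cell_gate = [0.1, -0.2, 0.3]


def elementwise_mul(a: list[float], b: list[float]) -> list[float]:
    """Element-wise product of two equally sized lists."""
    return [x * y for x, y in zip(a, b)]


old_term = elementwise_mul(forget_gate, cell_gate)  # expression before the change
new_term = elementwise_mul(input_gate, cell_gate)   # expression after the change

same_buffer = forget_gate is input_gate  # both names alias one object
identical = old_term == new_term         # so both products are equal
```

Under that assumption the rename changes only readability, not the numerical output.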

ArmRyan (Collaborator) left a comment

  • $Date: 26 March 2024
  • $Revision: V.1.0.0

These should be updated to the current date, and the version should be V.1.0.1 in both files.

seh2bp (Contributor, Author) commented Jul 16, 2025

> Hi @seh2bp, I ran this through our CI and it's all good. The only thing missing is that the version and the date need to be updated in both files; then I am good to approve it!
>
> I am wondering, though, whether you had a problem with the function before. The forget gate and input gate both point to the same buffer, so the output should be the same. I think this is still clearer, so it is a good change, but if you did have a problem, this shouldn't fix it.

Updated the revision number and date.
I changed the code for my test before checking whether the original worked.
I agree it was not a bug, but it did confuse me, and I also agree that the change might prevent further confusion.

ArmRyan (Collaborator) commented Jul 17, 2025

> I changed the code for my test before checking whether the original works. I agree it was not a bug, but it did confuse me, and also agree that with the change further confusion might be prevented.

Ok great, just wanted to make sure that there was no other bug to deal with! Thanks for the contribution :)

ArmRyan merged commit 88f1982 into ARM-software:main on Jul 17, 2025
2 of 3 checks passed