stable Snake activation #351

@victor-shepardson

Description

I was experimenting with training all-snake models (https://github.com/victor-shepardson/RAVE/blob/vs-exp/rave/configs/allsnake2.gin) and noticed some complete failures to train. In the Snake activation, (self.alpha + 1e-9).reciprocal() doesn't stabilize the reciprocal as it appears intended to, because alpha is not constrained to be positive: whenever alpha drifts close to -1e-9, the sum still passes through zero. Instead, Snake can be expressed without any division by alpha via torch.sinc:

import gin
import torch
from torch import nn

@gin.configurable
class Snake(nn.Module):

    def __init__(self, dim: int, init: float = 1) -> None:
        super().__init__()
        # one learnable alpha per channel, broadcast over time
        self.alpha = nn.Parameter(torch.ones(dim, 1).mul_(init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.sinc(t) = sin(pi*t)/(pi*t), so x*(alpha*x/pi).sinc() = sin(alpha*x)/alpha,
        # making this x + sin(alpha*x)**2 / alpha with no division, finite at alpha == 0
        return x + self.alpha * (x * (self.alpha * x / torch.pi).sinc()).pow(2)

(see https://www.wolframalpha.com/input?i=x+%2B+a*x*x*sinc%28a*x%29**2)
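For concreteness, here's a quick sanity check (my own sketch, not from the repo) showing the two forms agree for a well-behaved alpha, while the naive stabilization produces inf/nan when alpha lands near -1e-9:

import torch

x = torch.linspace(-3.0, 3.0, 7)
for a in (0.5, -1e-9):
    alpha = torch.tensor(a)
    # naive form: the 1e-9 offset cancels when alpha ~= -1e-9
    naive = x + torch.sin(alpha * x).pow(2) * (alpha + 1e-9).reciprocal()
    # sinc form: no division by alpha anywhere
    sinc = x + alpha * (x * (alpha * x / torch.pi).sinc()).pow(2)
    print(f"alpha={a}: naive={naive}, sinc={sinc}")
# alpha=0.5: the two agree to float precision
# alpha=-1e-9: naive is inf (nan at x=0), sinc stays ~= x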

In practice I'm not sure how much this matters. It turned out my dataset needed a lower value for blocks.Snake.init, regardless of the sin or sinc implementation 🤷
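For reference, the knob in gin terms (hypothetical value; the right one is dataset-dependent):

import gin

# hypothetical binding; the class above defaults to init=1
gin.parse_config("blocks.Snake.init = 0.25")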

There's also a more complicated implementation using a custom autograd Function with a manual backward pass and torch.where: https://github.com/EdwardDixon/snake/blob/master/snake/activations.py#L37
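For anyone curious, the idea there is roughly this (a minimal sketch of my own, not the linked code verbatim): compute the plain sin form, and use torch.where to substitute the exact alpha -> 0 limits in both the forward and backward passes:

import torch

class SnakeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        # where alpha == 0, Snake reduces to the identity
        return torch.where(alpha == 0, x, x + torch.sin(alpha * x).pow(2) / alpha)

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        # d/dx [x + sin(a x)^2 / a] = 1 + sin(2 a x), also valid at a == 0
        grad_x = grad_output * (1 + torch.sin(2 * alpha * x))
        # d/da [x + sin(a x)^2 / a] = (a x sin(2 a x) - sin(a x)^2) / a^2 -> x^2 as a -> 0
        grad_alpha = grad_output * torch.where(
            alpha == 0,
            x.pow(2),
            (alpha * x * torch.sin(2 * alpha * x) - torch.sin(alpha * x).pow(2)) / alpha.pow(2),
        )
        # sum out the dimensions alpha was broadcast over (e.g. batch and time,
        # for an alpha of shape (dim, 1) applied to x of shape (batch, dim, time))
        while grad_alpha.dim() > alpha.dim():
            grad_alpha = grad_alpha.sum(0)
        for d, n in enumerate(alpha.shape):
            if n == 1:
                grad_alpha = grad_alpha.sum(d, keepdim=True)
        return grad_x, grad_alpha

Calling it is just SnakeFunction.apply(x, alpha), and the formulas can be verified with torch.autograd.gradcheck on double-precision inputs.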
