Description
I was experimenting with training all-snake models (https://github.com/victor-shepardson/RAVE/blob/vs-exp/rave/configs/allsnake2.gin) and hit some complete failures to train. In the Snake activation, `(self.alpha + 1e-9).reciprocal()` doesn't actually stabilize the reciprocal as it appears intended to, because `alpha` is not constrained to be positive, so `alpha + 1e-9` can still be zero or negative. Instead, Snake can be expressed without any division by `alpha` via `torch.sinc`:
```python
import gin
import torch
import torch.nn as nn


@gin.configurable
class Snake(nn.Module):
    def __init__(self, dim: int, init: float = 1) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim, 1).mul_(init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x + sin(alpha*x)^2 / alpha, rewritten via sinc so there is no division by alpha
        return x + self.alpha * (x * (self.alpha * x / torch.pi).sinc()).pow(2)
```

(see https://www.wolframalpha.com/input?i=x+%2B+a*x*x*sinc%28a*x%29**2)
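For what it's worth, here's a quick numerical sanity check of both points (purely illustrative, not code from the repo): the epsilon trick still blows up once `alpha` drifts slightly negative, while the sinc form matches `sin(alpha*x)**2 / alpha` away from zero and degrades gracefully to the identity at `alpha = 0`.

```python
import torch

# epsilon trick: a small *negative* alpha can cancel the 1e-9, so the
# "stabilized" reciprocal still hits inf / huge values
alpha = torch.tensor([1e-9, -1e-9, -2e-9])
print((alpha + 1e-9).reciprocal())  # tensor([5.0000e+08, inf, -1.0000e+09])

# sinc form agrees with sin(alpha*x)**2 / alpha wherever alpha != 0 ...
a = torch.tensor([0.5, -0.5, 2.0])
x = torch.linspace(-3, 3, 7).unsqueeze(1)
naive = x + torch.sin(a * x) ** 2 / a
via_sinc = x + a * (x * torch.sinc(a * x / torch.pi)) ** 2
print(torch.allclose(naive, via_sinc, atol=1e-6))  # True

# ... and reduces to the identity at alpha = 0 instead of dividing by zero
a0 = torch.tensor(0.0)
print(torch.equal(x + a0 * (x * torch.sinc(a0 * x / torch.pi)) ** 2, x))  # True
```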
In practice I'm not sure how much this matters. It turned out my dataset needed a lower value for `blocks.Snake.init`, regardless of the sin or sinc implementation 🤷
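Concretely, that just meant overriding the binding in the gin config along these lines (the 0.1 here is only a placeholder, not a tuned recommendation):

```gin
# hypothetical override; use whatever value works for your data
blocks.Snake.init = 0.1
```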
There's also this more complicated implementation using a custom autograd `Function` with a manual backward pass and `torch.where`: https://github.com/EdwardDixon/snake/blob/master/snake/activations.py#L37