Conversation
I'd be curious to see the performance benefit here compared to plain numpy. Usually numba doesn't optimize much when using a lot of numpy operations.
Yes, and the DSS is also not nearly as computationally intensive as e.g. ES or VS, so I guess even if we get a good relative speed improvement the user might not feel it.
I suppose this would look more or less the same for the other backends?
Pretty much, I guess. I don't have much experience with tf/torch though, so I'm not sure yet about the exact implementation.
The det and inv functions should be easy/similar to define in the other backends.
For cov:
- jax should be almost exactly the same as numpy.
- torch would need "transpose" being replaced with "permute".
- tensorflow also uses "transpose" but does not have a "cov" function, so this would need to be adapted.
Although, as mentioned in a separate comment, the transpose function currently assumes we only have 3 dimensions.
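For context on adapting `cov` across backends: `np.cov` treats rows as variables by default, whereas the batched branch in this PR puts variables on the last axis, so any backend port needs to keep the two conventions straight. A quick illustration (shapes chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(2)
x2d = rng.standard_normal((3, 100))  # 3 variables, 100 observations

# default: rows are variables -> (3, 3) covariance matrix
c = np.cov(x2d)
assert c.shape == (3, 3)

# rowvar=False: columns are variables, matching a (obs, vars) layout
c2 = np.cov(x2d.T, rowvar=False)
assert np.allclose(c, c2)
```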
```python
    return np.cov(x)
else:
    centered = x - x.mean(axis=-2, keepdims=True)
    return (centered.transpose(0, 2, 1) @ centered) / (x.shape[-2] - 1)
```
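As a sanity check on the matmul formulation (the standalone `batched_cov` name below is hypothetical, just for illustration): for a 3D array of shape `(batch, obs, vars)`, it agrees with looping `np.cov` over the batch.

```python
import numpy as np

def batched_cov(x):
    # covariance via matrix multiplication, assuming shape (batch, obs, vars)
    centered = x - x.mean(axis=-2, keepdims=True)
    return (centered.transpose(0, 2, 1) @ centered) / (x.shape[-2] - 1)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 10, 3))
# reference: per-batch np.cov with columns as variables
expected = np.stack([np.cov(xi, rowvar=False) for xi in x])
assert np.allclose(batched_cov(x), expected)
```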
centered.transpose(0, 2, 1) assumes the array is 3D. Is there a way to generalise this to calculate the covariance matrix for higher dimensions?
Merry Christmas :)
If we assume that we always have the square matrix in the last two dimensions (like numpy does on the batched version of e.g. np.linalg.inv()), something like centered.swapaxes(-1, -2) should do the trick. Is that convention fine for you @sallen12?
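A minimal sketch of that suggestion, assuming observations on axis -2 and variables on axis -1 (numpy's convention for batched linalg such as `np.linalg.inv`); `swapaxes(-1, -2)` then works for any number of leading batch dimensions:

```python
import numpy as np

def batched_cov(x):
    # generalised covariance: works for any number of leading batch dims,
    # with observations on axis -2 and variables on axis -1
    centered = x - x.mean(axis=-2, keepdims=True)
    return (centered.swapaxes(-1, -2) @ centered) / (x.shape[-2] - 1)

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3, 8, 5))  # two batch dimensions
out = batched_cov(x)
assert out.shape == (2, 3, 5, 5)
# each slice matches np.cov with columns as variables
assert np.allclose(out[0, 0], np.cov(x[0, 0], rowvar=False))
```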
Have now addressed these remaining tasks in #104, so will now close this PR. Thanks a lot Simon for your help!
Hi @frazane as discussed, please have a look.
Compared to #32, I've calculated the batched covariance here as a matrix multiplication, which I feel is the cleaner way, though it is a bit trickier to figure out the correct axes to multiply along.
To-Do: