Description
I am working on a variant of a standard masked autoregressive flow in which the base distribution is an N-dimensional uniform, the first variable is left untransformed, and the remaining variables are transformed with a rational quadratic spline. I removed the shuffling in the masked_autoregressive_flow function by removing the call to _add_default_permute, and I modified _flat_params_to_transformer in the MaskedAutoregressive class so that it applies an Identity transformer to the first dimension, as follows:
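import equinox as eqx
import jax.numpy as jnp
from flowjax.bijections import Concatenate, Identity, Vmap
from jaxtyping import Array

# Replacement for MaskedAutoregressive._flat_params_to_transformer
def _flat_params_to_transformer(self, params: Array, y_dim=1):
    """Reshape to dim X params_per_dim, then vmap."""
    dim = self.shape[-1]
    transformer_params = jnp.reshape(params, (dim, -1))
    # Drop the parameters predicted for the first y_dim (untransformed) variables
    transformer_params = transformer_params[y_dim:, :]
    transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
    # Identity on the first y_dim variables, vmapped transformer on the rest
    return Concatenate(
        [Identity((y_dim,)), Vmap(transformer, in_axes=eqx.if_array(0))]
    )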
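To check why discarding the first row of parameters is harmless, here is a minimal NumPy sketch of the same idea. It is not flowjax code, and it makes two simplifications: the hypothetical function identity_then_transform plays the role of the modified _flat_params_to_transformer for a fixed parameter vector (in the real flow that vector comes from the masked MLP), and an elementwise affine map (the hypothetical affine_transformer) stands in for the rational quadratic spline to keep the example short. Because each transformed output depends on its own input only through its own transformer, the Jacobian is triangular and the log-determinant is a sum of per-dimension log-derivatives; the identity block contributes log 1 = 0.

```python
import numpy as np

def affine_transformer(x, params):
    """Elementwise y = x * exp(log_scale) + shift (stand-in for the spline).

    params has shape (n, 2), holding (log_scale, shift) for each dimension.
    """
    log_scale, shift = params[:, 0], params[:, 1]
    y = x * np.exp(log_scale) + shift
    log_det = np.sum(log_scale)  # sum of log |dy_d / dx_d| over dimensions
    return y, log_det

def identity_then_transform(x, flat_params, y_dim=1):
    """Hypothetical analogue of the modified _flat_params_to_transformer."""
    dim = x.shape[-1]
    # Reshape to (dim, params_per_dim), then drop the rows for the first y_dim dims
    params = np.reshape(flat_params, (dim, -1))[y_dim:, :]
    y_rest, log_det = affine_transformer(x[y_dim:], params)
    # Identity block: x[:y_dim] is passed through and adds 0 to the log-determinant
    return np.concatenate([x[:y_dim], y_rest]), log_det

rng = np.random.default_rng(0)
x = rng.uniform(size=4)
flat = rng.normal(size=4 * 2)  # 4 dimensions x 2 parameters each
y, log_det = identity_then_transform(x, flat)

# Perturb only the parameters belonging to the first dimension (row 0)
flat_perturbed = flat.copy()
flat_perturbed[:2] += 10.0
y2, log_det2 = identity_then_transform(x, flat_perturbed)
```

Neither y nor log_det changes when only the first dimension's parameters are perturbed, which is the sense in which those parameters are unused.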
My understanding is that, with this change, masked_autoregressive_mlp still produces a set of spline parameters for the first variable that are never used, and that this should be harmless. My experiments produce the expected results, but I am not sure whether this is the most efficient approach or whether I am overlooking something relevant, so I would love to hear your opinion on how best to use your package for this. Thanks again for all the amazing work!
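On the question of what the unused parameters cost, the sketch below builds MADE-style masks for a single hidden layer in NumPy. It assumes the standard MADE rank scheme (input ranks 0..dim-1, hidden ranks cycling through 0..dim-2, output block d connected only to hidden units of rank below d); whether flowjax's masked_autoregressive_mlp assigns hidden ranks exactly this way is an assumption, not something taken from its source. made_masks is a hypothetical helper. Under this scheme the first variable (rank 0, which stays first because the permutation was removed) has no incoming connections, so the parameters emitted for it are input-independent constants.

```python
import numpy as np

def made_masks(dim, hidden_width, params_per_dim):
    """Hypothetical helper: MADE-style masks for a single hidden layer.

    Input ranks are 0..dim-1, hidden ranks cycle through 0..dim-2, and each
    variable d owns a block of params_per_dim outputs with rank d.
    """
    in_ranks = np.arange(dim)
    hidden_ranks = np.arange(hidden_width) % max(dim - 1, 1)
    out_ranks = np.repeat(np.arange(dim), params_per_dim)
    # mask[i, j] == 1 allows a connection from unit j (previous layer) to unit i
    mask_in = (hidden_ranks[:, None] >= in_ranks[None, :]).astype(float)
    # An output of rank d only sees hidden units of rank < d, so it depends on x_{<d}
    mask_out = (out_ranks[:, None] > hidden_ranks[None, :]).astype(float)
    return mask_in, mask_out

dim, width, k = 4, 12, 5  # 4 variables, 12 hidden units, 5 params per variable
mask_in, mask_out = made_masks(dim, width, k)

# Nonzero entry (i, j): output parameter i can depend on input variable j
conn = mask_out @ mask_in

# The first variable's block has no incoming connections, so in a masked MLP
# those outputs reduce to the final layer's bias terms.
print(np.count_nonzero(conn[:k]))  # 0
```

If the flowjax network follows this scheme, the wasted work is limited to the k rows of the final layer that belong to the first variable, which is small compared with the rest of the network.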