In the `ConsistencySamplingAndEditing` class, the `__mask_transform` method (which I believe corresponds to Algorithm 4 in the paper) essentially applies the transformation A from the paper. If so, shouldn't `transform_fn` also be applied to `x`?
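For reference, here is the masked update as I read it in Algorithm 4, compared with what the method below computes. I'm assuming the paper's notation: reference image $y$, invertible linear transformation $A$, binary mask $\Omega$. In the code, `transform_fn`, `inverse_transform_fn` and `mask` would play the roles of $A$, $A^{-1}$ and $\Omega$:

$$
\begin{aligned}
\text{Algorithm 4:}\quad & x \leftarrow A^{-1}\big[(Ay)\odot(1-\Omega) + (Ax)\odot\Omega\big] \\
\text{code as written:}\quad & x \leftarrow A^{-1}\big[(Ay)\odot(1-\Omega) + x\odot\Omega\big]
\end{aligned}
$$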
```python
def __mask_transform(
    self,
    x: Tensor,
    y: Tensor,
    mask: Tensor,
    transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
    inverse_transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
) -> Tensor:
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask)
```
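A minimal sketch of why this matters. It uses NumPy arrays instead of torch Tensors so it runs standalone. The function names `mask_transform_as_posted` and `mask_transform_proposed` are hypothetical. A vector reversal (a permutation, so it is its own inverse) stands in for A. The mask is assumed to be defined in the transformed space. None of this comes from the repo:

```python
import numpy as np
from typing import Callable

Transform = Callable[[np.ndarray], np.ndarray]


def mask_transform_as_posted(
    x: np.ndarray,
    y: np.ndarray,
    mask: np.ndarray,
    transform_fn: Transform = lambda v: v,
    inverse_transform_fn: Transform = lambda v: v,
) -> np.ndarray:
    # Same expression as the posted __mask_transform: only y goes through A.
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask)


def mask_transform_proposed(
    x: np.ndarray,
    y: np.ndarray,
    mask: np.ndarray,
    transform_fn: Transform = lambda v: v,
    inverse_transform_fn: Transform = lambda v: v,
) -> np.ndarray:
    # Hypothetical fix: apply A to x as well, i.e. A^{-1}[(Ay)(1 - mask) + (Ax) mask].
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + transform_fn(x) * mask)


# Toy invertible linear transform A: reverse the vector (a permutation, so A^{-1} = A).
reverse: Transform = lambda v: v[::-1]

x = np.array([1.0, 2.0, 3.0, 4.0])      # stand-in for the generated sample
y = np.array([10.0, 20.0, 30.0, 40.0])  # stand-in for the reference image
mask = np.array([0.0, 0.0, 1.0, 1.0])   # 1 = take from x, defined in A-space

posted = mask_transform_as_posted(x, y, mask, reverse, reverse)
proposed = mask_transform_proposed(x, y, mask, reverse, reverse)
print(posted)    # values 4, 3, 30, 40: x's entries come back in reversed order
print(proposed)  # values 1, 2, 30, 40: x's entries stay where they were

# With the default identity transform the two versions agree.
assert np.array_equal(mask_transform_as_posted(x, y, mask), mask_transform_proposed(x, y, mask))
```

With the default identity `transform_fn` both versions give the same result, which may be why the difference only shows up when a non-trivial transform is passed in. With a non-identity A, the posted version adds pixel-space `x` to A-space `y` before applying A^{-1}, so the `x` part ends up transformed by A^{-1} alone.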