
SamplingAndEditing: mask transform should also apply to x? #14

@sentient-codebot


In the `ConsistencySamplingAndEditing` class, the `__mask_transform` method, which I believe corresponds to Algorithm 4 in the paper, effectively applies the matrix A from the paper through `transform_fn` (with `inverse_transform_fn` playing the role of A⁻¹). If that is the case, shouldn't `transform_fn` also be applied to `x`?
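Writing $A$ for `transform_fn`, $A^{-1}$ for `inverse_transform_fn` and $\Omega$ for `mask`, the `return` line of `__mask_transform` computes the first expression below, whereas also applying `transform_fn` to `x` would give the second. The two agree whenever $A$ is the identity, which is the default for both callables.

```math
\begin{aligned}
\text{current:}\quad & A^{-1}\big[(A y) \odot (1 - \Omega) + x \odot \Omega\big] \\
\text{suggested:}\quad & A^{-1}\big[(A y) \odot (1 - \Omega) + (A x) \odot \Omega\big]
\end{aligned}
```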

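```python
from typing import Callable

from torch import Tensor


# Method of ConsistencySamplingAndEditing (excerpt).
def __mask_transform(
    self,
    x: Tensor,
    y: Tensor,
    mask: Tensor,
    transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
    inverse_transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
) -> Tensor:
    # Take transformed y where mask == 0 and x where mask == 1, then map back.
    # Note that x itself is never passed through transform_fn.
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask)
```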
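To check whether the difference matters in practice, here is a small NumPy sketch (not the repository's code). The function names `mask_transform_as_quoted`, `mask_transform_suggested`, `identity` and `reverse` are hypothetical, NumPy arrays stand in for torch tensors, and a reversal permutation stands in for an invertible linear A.

```python
import numpy as np


def mask_transform_as_quoted(x, y, mask, transform_fn, inverse_transform_fn):
    # Mirrors the return line quoted above: only y goes through transform_fn.
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask)


def mask_transform_suggested(x, y, mask, transform_fn, inverse_transform_fn):
    # Suggested variant: x is mapped into the same transformed space before mixing.
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + transform_fn(x) * mask)


def identity(v):
    return v


def reverse(v):
    # Hypothetical stand-in for A: a reversal permutation, which is its own inverse.
    return v[::-1]


y = np.array([1.0, 2.0, 3.0, 4.0])      # reference signal
x = np.array([10.0, 20.0, 30.0, 40.0])  # current sample
mask = np.array([1.0, 1.0, 0.0, 0.0])   # in A-space: 1 = take from x, 0 = take from y

# With the identity transform both versions agree.
print(mask_transform_as_quoted(x, y, mask, identity, identity).tolist())  # [10.0, 20.0, 3.0, 4.0]
print(mask_transform_suggested(x, y, mask, identity, identity).tolist())  # [10.0, 20.0, 3.0, 4.0]

# With a non-identity transform they differ.
print(mask_transform_as_quoted(x, y, mask, reverse, reverse).tolist())  # [1.0, 2.0, 20.0, 10.0]
print(mask_transform_suggested(x, y, mask, reverse, reverse).tolist())  # [1.0, 2.0, 30.0, 40.0]
```

In the reversal case the quoted version writes x's entries back in scrambled order (20, 10), while the suggested one keeps x's own values (30, 40) at the positions it selects. For a binary mask the suggested form is also idempotent (feeding its output back in returns the same array), which the quoted form is not in general.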
