DRESS: Disentangled Representation-based Self-Supervised Meta-Learning for Diverse Tasks [arXiv]
Authors: Wei Cui, Tongzi Wu, Jesse C. Cresswell, Yi Sui, Keyvan Golestan
This repository contains the official implementation of the paper DRESS: Disentangled Representation-based Self-Supervised Meta-Learning for Diverse Tasks. It includes both training and evaluation code.
The code files within the repository are organized as follows:
- main.py: the main entry point of the program.
- partition_generators.py: generates supervised and self-supervised partitions on each dataset.
- task_generator.py: generates few-shot learning tasks from any given partition.
- utils.py: helper functions.
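A few-shot task is generally built from a partition by picking N of its classes (or clusters) and splitting a handful of examples from each into a support set and a query set. The sketch below illustrates this generic N-way K-shot sampling with NumPy, assuming the partition is given as one integer label per example. The function `sample_task` and its parameters are hypothetical and are not the interface of task_generator.py.

```python
import numpy as np


def sample_task(labels, n_way, k_shot, q_queries, rng=None):
    """Hypothetical sketch: sample one N-way K-shot task from a partition.

    labels: 1-D array giving the partition class (cluster) of each example.
    Returns (support_idx, support_y, query_idx, query_y); task labels are
    re-indexed to 0..n_way-1 so they are independent of the partition's ids.
    """
    rng = np.random.default_rng(rng)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    # Only classes with enough examples for both support and query are eligible.
    eligible = classes[counts >= k_shot + q_queries]
    if len(eligible) < n_way:
        raise ValueError("partition has too few classes with enough examples")
    chosen = rng.choice(eligible, size=n_way, replace=False)

    support_idx, support_y, query_idx, query_y = [], [], [], []
    for task_label, c in enumerate(chosen):
        # Shuffle this class's example indices and take k_shot + q_queries of them.
        idx = rng.permutation(np.flatnonzero(labels == c))[: k_shot + q_queries]
        support_idx.extend(idx[:k_shot])
        support_y.extend([task_label] * k_shot)
        query_idx.extend(idx[k_shot:])
        query_y.extend([task_label] * q_queries)
    return (np.array(support_idx, dtype=int), np.array(support_y, dtype=int),
            np.array(query_idx, dtype=int), np.array(query_y, dtype=int))
```

For example, `sample_task(np.repeat(np.arange(10), 20), n_way=5, k_shot=1, q_queries=15)` draws a 5-way 1-shot task with 15 query examples per class from a partition of 10 equally sized classes. Support and query indices never overlap because both are sliced from the same shuffled index list.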
The sub-folders within the repository are as follows:
- scripts/: scripts to train, evaluate, and produce visualizations.
- encoders/: encoder classes for obtaining the latent spaces.
- dataset_loaders/: scripts for loading each dataset used in the experiments.
- baselines/: implementations of the baseline methods.
- analyze_results/: scripts for post-processing results.
- visualization_results/: visualizations of tasks constructed by DRESS.
Create a folder named data/ in the repository root to hold the raw data.
Each dataset used in the experiments is loaded by its own loader script under dataset_loaders/. Prepare the source data as follows:
- smallNORB: downloaded automatically by our script via the tensorflow_datasets package.
- shapes3D: download 3dshapes.h5 from Google Cloud Storage and place it under data/shapes3d/.
- causal3D: download trainset.tar.gz and testset.tar.gz from the dataset homepage and extract them under data/causal3d/train/ and data/causal3d/test/, respectively.
- MPI3D: download mpi3d_toy.npz from this link and place it under data/mpi3d/.
- CelebA: downloaded automatically by our script via the torchvision package.
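The manual steps above can be verified before training with a small script. The following is a minimal sketch that is not part of this repository: the helper `check_data_dir` and the `REQUIRED` mapping are hypothetical, the paths are taken from the list above, and the script only checks that the expected files and folders exist, not that their contents are valid. smallNORB and CelebA are omitted because they are downloaded automatically.

```python
from pathlib import Path

# Hypothetical mapping: files/folders that must be placed by hand under data/.
REQUIRED = {
    "shapes3D": ["shapes3d/3dshapes.h5"],
    "causal3D": ["causal3d/train", "causal3d/test"],
    "MPI3D": ["mpi3d/mpi3d_toy.npz"],
}


def check_data_dir(root="data"):
    """Create the data root if needed and report which manual downloads are missing.

    Returns a dict mapping dataset name -> list of missing paths (empty if all present).
    """
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)  # the data/ folder described above
    missing = {}
    for name, rel_paths in REQUIRED.items():
        # Path.exists() is true for both files and directories.
        absent = [str(root / p) for p in rel_paths if not (root / p).exists()]
        if absent:
            missing[name] = absent
    return missing
```

Calling `check_data_dir()` from the repository root returns an empty dict once every manually downloaded dataset is in place.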
Create an Anaconda environment from the environment.yml file in this repository, e.g. with conda env create -f environment.yml.
