- torch==1.4.0
- PyYAML==3.13
We borrow the pretrained embeddings from the deepmind/leo repo.
You can download them here, or run:

```bash
$ wget http://storage.googleapis.com/leo-embeddings/embeddings.zip
$ unzip embeddings.zip
```
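
If you want to sanity-check the download before training, a small exploratory sketch like the one below can list the pickle files in the unzipped directory and peek at one of them. The directory name and file layout are assumptions about the embedding archive, not something this repo guarantees.

```python
# Exploratory sketch: list and peek at the downloaded embedding pickles.
# The "embeddings" directory name is an assumption; adjust it to wherever you unzipped the archive.
import glob
import pickle

EMBEDDING_DIR = "embeddings"  # hypothetical location of the unzipped archive

pkl_files = sorted(glob.glob(f"{EMBEDDING_DIR}/**/*.pkl", recursive=True))
for path in pkl_files:
    print(path)

# Peek at the first pickle to see what it contains.
with open(pkl_files[0], "rb") as f:
    data = pickle.load(f, encoding="latin1")  # tolerant of Python 2 pickles; plain pickle.load(f) may also work
print(type(data))
```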
```bash
python3 main.py -train \
    -verbose \
    -N 5 \
    -K 1 \
    -embedding_dir $(EMBEDDING_DIR) \
    -dataset miniImageNet \
    -exp_name toy-example \
    -save_checkpoint
```
where `-N` and `-K` mean N-way K-shot training, `-exp_name` helps you keep track of your experiment, and `-save_checkpoint` saves the model for later testing.
For the full list of arguments, see `main.py`.
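
The flags above are parsed by `main.py`; the snippet below is a minimal argparse sketch (not the project's actual parser) that reproduces those flags, just to make the intended types explicit.

```python
# Minimal sketch of an argument parser matching the flags shown above.
# Illustrative only; the real definitions live in main.py.
import argparse

parser = argparse.ArgumentParser(description="LEO training/testing")
parser.add_argument("-train", action="store_true", help="run training")
parser.add_argument("-test", action="store_true", help="run testing")
parser.add_argument("-verbose", action="store_true", help="print progress")
parser.add_argument("-N", type=int, default=5, help="N-way classification")
parser.add_argument("-K", type=int, default=1, help="K-shot per class")
parser.add_argument("-embedding_dir", type=str, help="path to the pretrained embeddings")
parser.add_argument("-dataset", type=str, choices=["miniImageNet", "tieredImageNet"])
parser.add_argument("-exp_name", type=str, help="experiment name for bookkeeping")
parser.add_argument("-save_checkpoint", action="store_true", help="save model for later testing")
parser.add_argument("-load", type=str, help="checkpoint path to load when testing")
parser.add_argument("-disable_comet", action="store_true", help="turn off Comet.ml logging")

args = parser.parse_args()
```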
```bash
python3 main.py -test \
    -N 5 \
    -K 1 \
    -embedding_dir $(EMBEDDING_DIR) \
    -dataset miniImageNet \
    -verbose \
    -load $(model_path)
```
The testing result will be printed on the console.
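
A generic sketch of how `-save_checkpoint` and `-load $(model_path)` typically interoperate in PyTorch is shown below; the actual checkpoint contents and keys used by this repo are assumptions.

```python
# Generic PyTorch checkpoint round trip (the repo's actual checkpoint format may differ).
import torch

def save_checkpoint(model, optimizer, path):
    # Save only state dicts so the checkpoint stays portable across code changes.
    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    return model, optimizer
```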
This project comes with Comet.ml support. If you want to disable logging, just add `-disable_comet` as an argument.
You will need to modify `COMET_PROJECT_NAME` and `COMET_WORKSPACE` in `config.yml` to enable monitoring.
*If you do not save your Comet API key in `.comet.config`, you will have to specify the API key on line 147 of `solver.py`.
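
If you do use Comet, the logging is roughly equivalent to the sketch below. The project/workspace values mirror `COMET_PROJECT_NAME` and `COMET_WORKSPACE` mentioned above; everything else is an assumption about how you might call the comet_ml API yourself, not the exact code in `solver.py`.

```python
# Rough sketch of Comet.ml experiment logging (not the repo's exact implementation).
from comet_ml import Experiment

experiment = Experiment(
    api_key="YOUR_API_KEY",             # or omit and rely on ~/.comet.config
    project_name="COMET_PROJECT_NAME",  # value taken from config.yml
    workspace="COMET_WORKSPACE",        # value taken from config.yml
)
experiment.log_parameters({"N": 5, "K": 1, "dataset": "miniImageNet"})
experiment.log_metric("val_accuracy", 0.0, step=0)  # log a metric each validation step
```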
You can modify the hyperparameters in `config.yml`, where detailed descriptions are also provided.
The hyperparameters that yield the best results in this code are as follows:
| Hyperparameters | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
|---|---|---|---|---|
| outer_lr | 0.0005 | 0.0006 | 0.0006 | 0.0006 |
| l2_penalty_weight | 0.0001 | 8.5e-6 | 3.6e-10 | 3.6e-10 |
| orthogonality_penalty_weight | 303.0 | 0.00152 | 0.188 | 0.188 |
| dropout | 0.3 | 0.3 | 0.3 | 0.3 |
| kl_weight | 0 | 0.001 | 0.001 | 0.001 |
| encoder_penalty_weight | 1e-9 | 2.66e-7 | 5.7e-6 | 5.7e-6 |
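
Since these values live in `config.yml`, a quick way to inspect or tweak them programmatically is to load the file with PyYAML, as sketched below. The flat key layout assumed here may not match the repo's actual config structure; only the hyperparameter names come from the table above.

```python
# Sketch: read hyperparameters from config.yml with PyYAML.
# Assumes a flat key layout, which may differ from the repo's actual config structure.
import yaml

with open("config.yml") as f:
    config = yaml.safe_load(f)

for name in ["outer_lr", "l2_penalty_weight", "orthogonality_penalty_weight",
             "dropout", "kl_weight", "encoder_penalty_weight"]:
    print(name, config.get(name))
```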
| Implementation | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
|---|---|---|---|---|
| LEO Paper | 61.76 ± 0.08% | 77.59 ± 0.12% | 66.33 ± 0.05% | 81.44 ± 0.09% |
| this code | 59.46 ± 0.08% | 76.01 ± 0.09% | 66.62 ± 0.07% | 81.72 ± 0.09% |
*The results may not be directly comparable: in the paper the model is trained on both the training set and the validation set, while our model is trained only on the training set and validated on the validation set.
Note: This project is licensed under the terms of the MIT license.