diff --git a/README.md b/README.md index 8501f3b..26d7a15 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,8 @@ Neural networks come from different sources. With `thingsvision`, you can extrac - `dino-rn50`, `dino-xcit-{small/medium}-{12/24}-p{8/16}` - `dino-vit-{tiny/small/base}-p{8/16}` - `dinov2-vit-{small/base/large/giant}-p14` - - `mae-vit-{base/large}-p16`, `mae-vit-huge-p14`
+ - `mae-vit-{base/large}-p16`, `mae-vit-huge-p14` + - `capi-vitl14-{p205/lvd/in22k/in1k}` (trained on different datasets)
- [OpenCLIP](https://github.com/mlfoundations/open_clip) models (CLIP trained on LAION-{400M/2B/5B}) - [CLIP](https://github.com/openai/CLIP) models (CLIP trained on WiT) - a few custom models (Alexnet, VGG-16, Resnet50, and Inception_v3) trained on [Ecoset](https://www.pnas.org/doi/10.1073/pnas.2011417118)
diff --git a/requirements.txt b/requirements.txt index 8bcb4c0..9e68788 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ scikit-learn scipy tensorflow<2.16 timm -torch>=2.0.0 +torch>=2.4.0 torchvision==0.15.2 torchtyping tqdm @@ -24,3 +24,7 @@ keras-cv-attention-models>=1.3.5 vit-keras==0.1.2 git+https://github.com/serre-lab/Harmonization.git dreamsim==0.1.3 +jaxtyping +omegaconf +einops +rich \ No newline at end of file diff --git a/tests/helper.py b/tests/helper.py index 2733a98..4271f83 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -219,6 +219,13 @@ "source": "ssl", "kwargs": {"token_extraction": "avg_pool"}, }, + "capi-vitl14-in1k": { + "model_name": "capi-vitl14-in1k", + "modules": ["norm", "fc_norm"], + "pretrained": True, + "source": "ssl", + "kwargs": {"token_extraction": "cls_token"}, + }, # Additional models "Harmonization_visual_ResNet50": { "model_name": "Harmonization", diff --git a/thingsvision/core/extraction/extractors.py b/thingsvision/core/extraction/extractors.py index ec6728a..6ba4df3 100644 --- a/thingsvision/core/extraction/extractors.py +++ b/thingsvision/core/extraction/extractors.py @@ -383,6 +383,26 @@ class SSLExtractor(PyTorchExtractor): "type": "hub", "checkpoint_url": "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth", }, + "capi-vitl14-p205": { + "repository": "facebookresearch/capi:main", + "arch": "capi_vitl14_p205", + "type": "hub", + }, + "capi-vitl14-lvd": { + "repository": "facebookresearch/capi:main", + "arch": "capi_vitl14_lvd", + "type": "hub", + }, + "capi-vitl14-in22k": { + "repository": "facebookresearch/capi:main", + "arch": "capi_vitl14_in22k", + "type": "hub", + }, + "capi-vitl14-in1k": { + "repository": "facebookresearch/capi:main", + "arch": "capi_vitl14_in1k", + "type": "hub", + }, } def __init__(