This is intended to be a flax version of the pytorch-cifar repo that I've often used as a reference.
The ResNet models have 'imagenet' and 'cifar' variants, which are intended for 224x224 and 32x32 images respectively.
The WideResNet and ResNetV2 implementations were based on the objax.zoo package.
The ResNetV1 implementation was based on the flax imagenet example.
The DenseNet and VGG implementations were based on the pytorch-cifar repo.
All results using default config (see configs/default.py).
Refer to the wandb project.
| Model | Params | Acc. |
|---|---|---|
| VGG-11 backbone | 9,228,362 | 90.6% |
| VGG-13 backbone | 9,413,066 | 92.7% |
| VGG-16 backbone | 14,724,042 | 92.1% |
| VGG-19 backbone | 20,035,018 | 92.3% |
| VGG-11 | 28,154,954 | 89.7% |
| VGG-13 | 28,339,658 | 91.9% |
| VGG-16 | 33,650,634 | 92.2% |
| VGG-19 | 38,961,610 | 90.3% |
| ResNetV1-18 | 11,173,962 | 94.1% |
| ResNetV1-50 | 23,520,842 | 94.0% |
| ResNetV2-18 | 11,172,170 | 94.7% |
| ResNetV2-50 | 23,513,162 | 93.8% |
| WideResNet-28-2 | 1,467,610 | 93.3% |
| WideResNet-28-8 | 23,354,842 | 95.1% |
| DenseNet121-12 | 1,000,618 | 93.6% |
| DenseNet121-32 | 6,956,298 | 94.8% |
| DenseNet169-32 | 12,493,322 | 94.2% |
python main.py --dataset_root=path/to/cifar10 --config=configs/default.py --config.arch=resnet_v1_18absl-py
flax
jax[cuda]
jaxopt
ml_collections
numpy
torch
torchvision
tqdm
wandbTo install jax with CUDA, refer to the installation instructions.