A package that uses a convolutional neural network (CNN) to classify the MNIST dataset.
Install Requirements: Run `invoke requirements` to install all required libraries listed in `requirements.txt`. This ensures your environment is set up correctly.
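To confirm that the install actually succeeded, a minimal sketch like the one below can compare `requirements.txt` against the packages visible to the current interpreter. The helpers `requirement_names` and `missing_packages` are hypothetical and not part of this project. The parser only understands plain `name[extras]<specifier>` lines and skips pip options such as `-e .` or `-r other.txt`. Lookups go through `importlib.metadata.version`, so a package whose installed metadata spells its name differently may be misreported.

```python
"""Hypothetical helper: report packages from requirements.txt that are not installed."""
import re
from importlib import metadata
from pathlib import Path


def requirement_names(text: str) -> list[str]:
    """Extract bare distribution names from requirements.txt content."""
    names = []
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop trailing comments
        if not line or line.startswith("-"):  # skip blank lines and pip options
            continue
        # The name ends at the first character that cannot be part of it,
        # e.g. "[" (extras), ">=", "==" or ";" (environment markers).
        match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", line)
        if match:
            names.append(match.group(0))
    return names


def missing_packages(names: list[str]) -> list[str]:
    """Return the names that importlib.metadata cannot find in this environment."""
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    req_file = Path("requirements.txt")
    if req_file.exists():
        print("Missing:", missing_packages(requirement_names(req_file.read_text())) or "none")
```

Run it from the repository root after `invoke requirements`; an output of `Missing: none` means every listed distribution was found.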
Preprocess the Data: Run `invoke preprocess-data` to preprocess the raw MNIST data and save it in the `data/processed` folder. This normalizes the images to zero mean and unit standard deviation (mean = 0, std = 1) and saves the processed data as `.pt` files.
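The normalization step maps every pixel value x to x' = (x - mean) / std. The README does not say whether the statistics are computed per image or over the whole dataset, and the real script presumably operates on PyTorch tensors; the pure-Python sketch below (hypothetical helper `standardize`) only illustrates the arithmetic.

```python
"""Hypothetical sketch of the normalization step: rescale values to mean 0, std 1."""
from statistics import fmean, pstdev


def standardize(pixels: list[float]) -> list[float]:
    """Return (x - mean) / std for every value, using the population std (divide by N)."""
    mean = fmean(pixels)
    std = pstdev(pixels, mu=mean)
    if std == 0:
        raise ValueError("cannot standardize constant data (std is 0)")
    return [(x - mean) / std for x in pixels]


# Example: four raw grayscale intensities in the 0-255 range.
normalized = standardize([0.0, 64.0, 128.0, 255.0])
print(normalized)
```

One design detail to be aware of: `statistics.pstdev` divides by N, whereas PyTorch's `Tensor.std()` applies Bessel's correction (divides by N - 1) by default. For the tens of thousands of pixels in MNIST the difference is negligible.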
Train the Model: Run `invoke train` to train the CNN model on the MNIST dataset. The script loads the processed data from `data/processed` and saves the trained model checkpoint in the `models/` folder.
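The README does not describe the CNN's layers. When writing or changing the model, it helps to track the feature-map size after each convolution and pooling layer, because that size determines the input dimension of the first fully connected layer. The sketch below uses the output-size formula given in the PyTorch `Conv2d` and `MaxPool2d` documentation (with `ceil_mode=False`); the helper `conv_output_size` and the example layer stack are hypothetical and not necessarily this project's model.

```python
"""Hypothetical helper: spatial output size of a 2D convolution or pooling layer."""


def conv_output_size(size: int, kernel: int, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    """floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Assumed example stack on a 28x28 MNIST image:
# conv 3x3 (padding 1) -> max-pool 2x2 -> conv 3x3 (padding 1) -> max-pool 2x2.
size = 28
size = conv_output_size(size, kernel=3, padding=1)  # conv keeps 28x28
size = conv_output_size(size, kernel=2, stride=2)   # pooling halves to 14x14
size = conv_output_size(size, kernel=3, padding=1)  # conv keeps 14x14
size = conv_output_size(size, kernel=2, stride=2)   # pooling halves to 7x7
print(size)
```

With, say, 64 channels after the last pooling layer, the flattened vector fed to the first linear layer would have 64 × 7 × 7 = 3136 features.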
Visualize the Embeddings: Run `python src/mnist_classifier/visualize.py --model-checkpoint <path_to_model_checkpoint>` to visualize the model's predictions and embeddings. Replace `<path_to_model_checkpoint>` with the path to your saved model checkpoint (e.g., `models/my_model.pth`). The script generates a t-SNE visualization of the embeddings and saves it as `embeddings.png` in the `reports/figures/` folder.
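If several training runs have left multiple checkpoints in `models/`, a small helper can pick the newest one to pass as `--model-checkpoint`. This is a minimal sketch: `latest_checkpoint` is a hypothetical name, and it assumes checkpoints use the `.pth` extension from the example above (pass a different `pattern` if they do not).

```python
"""Hypothetical helper: find the most recently modified checkpoint in models/."""
from pathlib import Path


def latest_checkpoint(models_dir: str = "models", pattern: str = "*.pth") -> Path:
    """Return the newest file in models_dir matching pattern, by modification time."""
    candidates = [p for p in Path(models_dir).glob(pattern) if p.is_file()]
    if not candidates:
        raise FileNotFoundError(f"no {pattern} files found in {models_dir!r}")
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

For example, `str(latest_checkpoint())` run from the repository root yields a path that can be substituted for `<path_to_model_checkpoint>`.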
The directory structure of the project looks like this:

    ├── .github/                  # GitHub Actions and Dependabot
    │   ├── dependabot.yaml
    │   └── workflows/
    │       └── tests.yaml
    ├── configs/                  # Configuration files
    ├── data/                     # Data directory
    │   ├── processed
    │   └── raw
    ├── dockerfiles/              # Dockerfiles
    │   ├── api.Dockerfile
    │   └── train.Dockerfile
    ├── docs/                     # Documentation
    │   ├── mkdocs.yml
    │   └── source/
    │       └── index.md
    ├── models/                   # Trained models
    ├── notebooks/                # Jupyter notebooks
    ├── reports/                  # Reports
    │   └── figures/
    ├── src/                      # Source code
    │   └── mnist_classifier/
    │       ├── __init__.py
    │       ├── api.py
    │       ├── data.py
    │       ├── evaluate.py
    │       ├── models.py
    │       ├── train.py
    │       └── visualize.py
    ├── tests/                    # Tests
    │   ├── __init__.py
    │   ├── test_api.py
    │   ├── test_data.py
    │   └── test_model.py
    ├── .gitignore
    ├── .pre-commit-config.yaml
    ├── LICENSE
    ├── pyproject.toml            # Python project file
    ├── README.md                 # Project README
    ├── requirements.txt          # Project requirements
    ├── requirements_dev.txt      # Development requirements
    └── tasks.py                  # Project tasks

Created using mlops_template, a cookiecutter template for getting started with Machine Learning Operations (MLOps).