This project focuses on training and evaluating neural networks for handwritten digit classification using the MNIST dataset. The code is organized into three main files: S5.ipynb, model.py, and utils.py.
`S5.ipynb`: this Jupyter Notebook contains the main code for the MNIST classification task and is structured as follows:
- Imports and Setup
  - Import the necessary libraries and modules.
  - Check for CUDA availability.
  - Define data transformations and create data loaders.
- Data Visualization
  - Display sample images from the training dataset.
- Model Summary
  - Use `torchsummary` to display summaries of the neural network models (`Net` and `Net2`).
- Model Training
  - Train the `Net` model using the SGD optimizer, a learning rate scheduler, and negative log-likelihood loss.
  - Print training and testing accuracies over epochs.
Visualize Training Progress
- Plot training and testing losses, as well as training and testing accuracies.
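The training loop described above can be sketched as follows. This is a minimal, self-contained illustration, not the notebook's code. It makes three assumptions:

- random MNIST-shaped tensors stand in for the real MNIST data loaders;
- a hypothetical `TinyNet` stands in for `Net`;
- the learning rate, momentum, and `StepLR` settings are illustrative values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TinyNet(nn.Module):
    """Hypothetical stand-in for Net: any model that returns log-probabilities works."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return F.log_softmax(self.fc(x.flatten(1)), dim=1)


def run_epoch(model, loader, optimizer=None):
    """Train for one pass if an optimizer is given, else evaluate; return (avg loss, accuracy %)."""
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # NLL loss on log-softmax outputs equals cross-entropy on raw logits.
            loss = F.nll_loss(output, target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * data.size(0)
            correct += output.argmax(dim=1).eq(target).sum().item()
            seen += data.size(0)
    return total_loss / seen, 100.0 * correct / seen


# Assumption: synthetic N x 1 x 28 x 28 data so the sketch runs without downloading MNIST.
torch.manual_seed(0)
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)
test_loader = DataLoader(TensorDataset(images, labels), batch_size=64)

model = TinyNet().to(device)
# Illustrative hyperparameters, not the notebook's values.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# StepLR multiplies the learning rate by gamma every step_size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}
for epoch in range(1, 4):
    train_loss, train_acc = run_epoch(model, train_loader, optimizer)
    test_loss, test_acc = run_epoch(model, test_loader)
    scheduler.step()  # once per epoch, so StepLR counts epochs rather than batches
    for key, value in zip(history, (train_loss, train_acc, test_loss, test_acc)):
        history[key].append(value)
    print(f"Epoch {epoch}: train acc {train_acc:.2f}%, test acc {test_acc:.2f}%")
```

Because `Net` ends in a log-softmax, pairing it with `F.nll_loss` gives the usual cross-entropy objective. The per-epoch lists in `history` are the kind of data the plotting step consumes.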
`model.py`: this file defines the neural network models used for the MNIST classification task:
- `Net` class
  - A convolutional neural network with four convolutional layers and two fully connected layers.
  - Implements the forward pass with ReLU activations and a log-softmax output.
- `Net2` class
  - Similar to `Net`, but with the bias terms of its convolutional and fully connected layers disabled (`bias=False`).
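The two model classes can be sketched as below. The channel widths, pooling positions, and 50-unit hidden layer are assumptions. They were chosen so that a 1x28x28 input reaches the classifier as a 256x4x4 feature map, and the actual `model.py` may differ. As a simplification, a single `bias` flag stands in for the separate `Net2` class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Four 3x3 convolutions plus two fully connected layers for 1x28x28 inputs.

    Channel widths are illustrative assumptions. bias=False mimics Net2, which
    drops the bias term from every convolutional and linear layer.
    """

    def __init__(self, bias: bool = True):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, bias=bias)     # 28 -> 26
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, bias=bias)    # 26 -> 24, pool -> 12
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, bias=bias)   # 12 -> 10
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, bias=bias)  # 10 -> 8, pool -> 4
        self.fc1 = nn.Linear(256 * 4 * 4, 50, bias=bias)
        self.fc2 = nn.Linear(50, 10, bias=bias)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(self.conv3(x))
        x = F.relu(F.max_pool2d(self.conv4(x), 2))
        x = x.flatten(1)                          # (N, 256 * 4 * 4)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)  # log-probabilities for NLL loss


def count_params(model: nn.Module) -> int:
    """Total number of parameters in the model."""
    return sum(p.numel() for p in model.parameters())


net, net2 = Net(), Net(bias=False)  # net2 plays the role of Net2
out = net(torch.randn(2, 1, 28, 28))
print(out.shape)                               # torch.Size([2, 10])
print(count_params(net) - count_params(net2))  # 540 bias parameters removed
```

With these assumed widths, dropping the biases removes one parameter per output channel or unit (32 + 64 + 128 + 256 + 50 + 10 = 540). That is a small fraction of the total, so the comparison mainly shows where the bias terms live.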
`utils.py`: this file contains utility functions and data transformations:
- Data Transformations
  - Defines the training and testing data transformations using `torchvision.transforms`.
- Neural Network Training Functions
  - `GetCorrectPredCount`: returns the number of correct predictions.
  - `train`: trains the neural network.
  - `test`: evaluates the neural network on the test dataset.
- Graph Plotting Function
  - `allgraphs`: plots training and testing losses, as well as training and testing accuracies.
To run the project:

- Add `model.py` and `utils.py` as auxiliary files to the Colab environment.
- Open `S5.ipynb` and run its cells sequentially to train and evaluate the model.
- Ensure that the required libraries (`torch`, `torchvision`, `matplotlib`, `tqdm`, `torchsummary`) are installed.
- Use a GPU, if available, for faster training.
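Before running the notebook, a quick check like the one below can confirm that the required packages are importable. It is a small sketch that uses only the standard library's `importlib.util.find_spec`. The package list mirrors the libraries named in this README.

```python
import importlib.util

# Import names of the libraries this README relies on (torchsummary is used in the notebook).
REQUIRED = ["torch", "torchvision", "matplotlib", "tqdm", "torchsummary"]


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be imported in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
print("Missing:", missing or "none")
```

In Colab, a missing package can be installed from a notebook cell with `!pip install <package>`.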