Visual Galaxy Classification Using Convolutional Neural Networks
In this project we used two methods to classify images from the Galaxy10 DECaLS dataset. One method is a Convolutional Neural Network. The other one is a Random Forest.
Image created by @nicktky
This project uses the Galaxy10 DECaLS dataset by Leung, W. Henry and Bovy, Jo, which can be found under the following DOI: https://doi.org/10.5281/zenodo.10845025.
This project is built as an installable Python package for Python versions >= 3.12.
To install it, you first have to clone this repository. After that you can execute the following command inside of the cloned directory:
$ pip install -e .
Important
The packages required to run the commands and scripts are installed automatically. Some of the scripts require a PyTorch-supported GPU.
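Before running the GPU-dependent scripts, it can help to verify the environment. The following is a minimal sketch, not part of this package: the function names are made up for illustration, and it assumes that the GPU backend in question is CUDA.

```python
import sys


def python_is_supported(minimum=(3, 12)):
    """Return True if the running interpreter meets the minimum version."""
    return sys.version_info >= minimum


def gpu_is_available():
    """Return True if PyTorch can be imported and reports a usable CUDA device."""
    try:
        import torch  # assumed to be installed together with this package
    except ImportError:
        return False
    return torch.cuda.is_available()


if __name__ == "__main__":
    print(f"Python >= 3.12: {python_is_supported()}")
    print(f"CUDA GPU available: {gpu_is_available()}")
```

If the second line prints False, the GPU-dependent commands will most likely not run on this machine.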
Important
To use the results of our training and optimization, you can use the files in the best_models directory in this repository.
It contains:
- The vacation.sqlite3 file, which contains the Optuna optimization studies.
- The vacation_v2.py file, which is the exported best CNN model of the Optuna study.
- The rf_optimized.zip archive, which contains the joblib dump of the best Random Forest model.
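To get a first look at these files without installing anything extra, the sketch below uses only the Python standard library. It lists the tables of the SQLite file and the members of the zip archive. The helper names are hypothetical, and the script assumes it is run from the repository root so that the best_models directory is found.

```python
import sqlite3
import zipfile
from contextlib import closing
from pathlib import Path


def list_sqlite_tables(db_path):
    """Return the names of all tables stored in a SQLite database file."""
    # Open read-only so that a wrong path raises instead of creating an empty file.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
        ).fetchall()
    return [name for (name,) in rows]


def list_archive_members(zip_path):
    """Return the file names inside a zip archive without extracting it."""
    with zipfile.ZipFile(zip_path) as archive:
        return archive.namelist()


if __name__ == "__main__":
    models = Path("best_models")  # directory name taken from the text above
    db, rf_zip = models / "vacation.sqlite3", models / "rf_optimized.zip"
    if db.exists():
        print(list_sqlite_tables(db))
    if rf_zip.exists():
        print(list_archive_members(rf_zip))
```

Once the archive is extracted, the dump can typically be loaded with joblib.load, and the studies can be opened with optuna.load_study using a sqlite:/// storage URL. The study names and the file name inside the archive are not stated here, so look them up first. Because joblib dumps are pickle-based, only load files you trust.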
The training and analysis of the project can be reproduced using the built-in Command-Line Interface (CLI) of this project. With it you can:
- Create and process the datasets
- Start the hyperparameter optimization of the CNN
- Start the hyperparameter optimization of the Random Forest
- Create visualizations for the dataset, hyperparameter optimization and the evaluation
If you are unsure about the usage of a command or its arguments, you can use the --help flag
in order to get an overview of the command.
Example:
$ vacation --help
$ vacation optim cnn --help
To create and process the dataset, you first have to choose a directory where the data should be created.
Warning
The entire dataset collection requires a disk space of about 8.6 GiB.
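A quick way to check whether the chosen directory has enough room is sketched below. It is not part of the CLI. The 8.6 GiB threshold is taken from the warning above, and the function names are illustrative.

```python
import shutil
from pathlib import Path

REQUIRED_GIB = 8.6  # approximate size of the full dataset collection (see warning)


def free_space_gib(directory):
    """Return the free space of the filesystem holding `directory`, in GiB."""
    return shutil.disk_usage(directory).free / 2**30


def has_enough_space(directory, required_gib=REQUIRED_GIB):
    """Return True if at least `required_gib` GiB are free at `directory`."""
    return free_space_gib(directory) >= required_gib


if __name__ == "__main__":
    target = Path(".")  # run this inside the directory you chose
    print(f"{free_space_gib(target):.1f} GiB free, enough: {has_enough_space(target)}")
```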
Enter your chosen directory and execute the following command:
$ vacation dataset create ./
Feel free to adjust the arguments of this command, for example the memory consumption.
Tip
If you want the command to overwrite or redownload parts of the dataset, use the --overwrite or --redownload flags.
Important
These steps require GPU support!
After downloading and generating all necessary datasets, you can start the Optuna hyperparameter optimization. To do so, change into a directory where you want your Optuna study to be saved. The study won't take up much disk space.
Then you can use
$ vacation optim cnn PATH/TO/TRAIN_DATASET PATH/TO/VALID_DATASET --checkpoint-dir PATH/TO/DESIRED/CHECKPOINT_DIR
For the train and validation datasets, provide the paths of the Galaxy10_DECals_train.h5 and Galaxy10_DECals_valid.h5 files you created previously.
The --checkpoint-dir path can be chosen freely. The checkpoints also take up some disk space, but at about 1.2 GiB much less than the datasets.
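If you prefer to launch the optimization from a script, the command can be assembled as sketched below. This is only an illustration: the helper name and the placeholder directories are made up, and it assumes that both dataset files live in the same data directory. Only the subcommand, the two positional paths and the --checkpoint-dir option shown above are used.

```python
import shlex
from pathlib import Path

TRAIN_FILE = "Galaxy10_DECals_train.h5"
VALID_FILE = "Galaxy10_DECals_valid.h5"


def build_cnn_optim_command(data_dir, checkpoint_dir):
    """Return the argument list for `vacation optim cnn`, checking the inputs first."""
    data_dir = Path(data_dir)
    train, valid = data_dir / TRAIN_FILE, data_dir / VALID_FILE
    # Fail early with a clear message if a dataset file is missing.
    missing = [str(p) for p in (train, valid) if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"Missing dataset file(s): {', '.join(missing)}")
    return [
        "vacation", "optim", "cnn",
        str(train), str(valid),
        "--checkpoint-dir", str(checkpoint_dir),
    ]


if __name__ == "__main__":
    # "data" and "checkpoints" are placeholder directories for illustration.
    try:
        cmd = build_cnn_optim_command("data", "checkpoints")
        print(shlex.join(cmd))  # to execute: subprocess.run(cmd, check=True)
    except FileNotFoundError as err:
        print(err)
```

Building the command as a list avoids quoting problems when a path contains spaces.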
Important
These steps require GPU support!
To start the Random Forest optimization, use the following command in an arbitrary location:
$ vacation optim rf PATH/TO/TRAIN_DATASET PATH/TO/VALID_DATASET PATH/TO/NON_AUGMENTED_TRAIN_DATASET --checkpoint-dir PATH/TO/DESIRED/CHECKPOINT_DIR
For the train and validation datasets, provide the paths of the Galaxy10_DECals_train.h5 and Galaxy10_DECals_valid.h5 files you created previously. The non-augmented training dataset can be found in the same directory under the name Galaxy10_DECals_proc_train.
The --checkpoint-dir path can be chosen freely. The checkpoints also take up some disk space, but at about 515 MiB much less than the datasets.
The CLI has multiple visualizations:
- Plot of the class distribution, example images and augmentation examples of an HDF5 dataset using vacation dataset plot
- Plot of a HOG feature extraction example with vacation dataset hog
- Plots of the Random Forest test evaluation results with vacation rf eval
- Plots of the CNN hyperparameter optimization results with vacation cnn plot_metric and vacation optim plot_importance
- Plots of the CNN test evaluation results with vacation cnn eval
For further information on these commands use the --help flag.
Tip
The test dataset can be found under the file name Galaxy10_DECals_proc_test.h5 in the data directory.
Warning
Exported CNN models (.pt files) contain a parameter that stores the location of the train and validation datasets on the system where the model was created.
If you want to use the provided models, you will have to provide these values yourself. The relevant commands should offer a --dataset-directory option, which you have to set to your dataset directory (not a file!); check --help for the exact name. If you call the functions from the code directly, you can pass the datasets as arguments to the functions.