The goal of this group project is to develop and compare different deep learning computer vision methods for the semantic segmentation of dead trees in aerial images. The dataset used for this project is hosted on Kaggle and contains RGB and NRG (NIR-RG) images.
This project is implemented with PyTorch and PyTorch Lightning, and it includes over 60 deep learning methods for semantic segmentation powered by Segmentation Models PyTorch. We also try zero-shot segmentation and fine-tuning with SAM2.
This project has a Zotero library that contains references to papers, articles, and other resources relevant to the project. You can access the Zotero library at the following link:
https://www.zotero.org/groups/6056458/cvers
This project is organized into several directories and files, each serving a specific purpose. Below is an overview of the project structure:
CV9517_Group-Project/
├── assets/ # Directory for assets (e.g., images, icons)
├── checkpoints/ # Directory for model checkpoints (Local storage)
├── data/ # Directory for data processing scripts
│ ├── __init__.py # Init file for data processing module
│ ├── datamodule.py # Script for lightning datamodule (Supports merged, RGB, and NRG modalities)
│ ├── dataset.py # Script for dataset class
│ ├── transforms.py # Script for data transformations
│ └── utils.py # Utility functions for data processing
├── datasplits # Directory for data splits csv files
│ └── data_split_42_70_10.csv # seed 42, train 70%, val 10%, test 20%
├── gradio/ # Directory for Gradio app script
│ └── app.py # Script for Gradio app (RGB modality only for now)
├── lighting_modules # Directory for lightning modules
│ ├── __init__.py # Init file for lightning modules
│ ├── sam2_module # Script for SAM2 lightning module (Future work)
│ ├── segmentation_module # Base lightning module for this dead tree segmentation project
│ ├── smp_module # Script for Segmentation Models PyTorch lightning module
│ └── u2net_module # Script for U2Net lightning module
├── logs/ # Directory for logs (Local storage, e.g., TensorBoard logs)
├── models/ # Directory for model scripts
│ ├── __init__.py # Init file for models module
│ ├── smp_models_util.py # Utility functions for models of Segmentation Models PyTorch
│ └── u2net.py # Script for U2Net model for semantic segmentation
├── notebooks/ # Directory for Jupyter notebooks (It is for simple demonstration and testing)
│ └─ segmentation_models.ipynb # Notebook for testing Segmentation Models PyTorch
├── outputs/ # Directory for output files (e.g., model predictions, evaluation logs)
├── sam2/ # Directory for SAM2 (You should clone files from the SAM2 official GitHub repository)
├── scripts/ # Directory for scripts
│ ├── sam2_fine_tune.py # Script for fine-tuning SAM2 model
│ ├── sam2_ft_lightning.py # Script for fine-tuning SAM2 model using PyTorch Lightning (Future work)
│ ├── sam2_inference_vis.py # Script for inference using SAM2 model (Visualize the results)
│ ├── sam2_zero_shot.py # Script for test zero-shot segmentation using SAM2
│ ├── test_sam2_ft.py # Script for testing fine-tuned SAM2 model
│ ├── test_smp.py # Script for testing Segmentation Models PyTorch model
│ ├── train_smp.py # Script for training Segmentation Models PyTorch model (Over 60 models, with different architectures & feature extractors, on 3 modalities)
│ └── train_u2net.py # Script for training U2Net model (Future work)
├── utils/ # Directory for utility scripts
│ ├── __init__.py # Init file for utils module
│ ├── callbacks.py # Script for abstracting lightning trainer callbacks
│ ├── logger.py # Script for abstracting logger construction for each kind of activity
│ └── paths.py # Script for paths management
├── .gitignore # Git ignore file to exclude unnecessary files from version control
├── environment.yaml # Conda environment file for dependencies
├── README.md # Project overview and instructions
└── requirements.txt # Pip requirements file
```bash
git clone https://github.com/ParzHe/CV9517_Group-Project.git
cd CV9517_Group-Project
```

We suggest using Linux to try this project. If you are using Windows, we suggest using WSL (Windows Subsystem for Linux).
If you do not have conda installed, you can install it from the Anaconda website.
- Create a `CVers` conda environment using the provided `environment.yaml` file:

  ```bash
  # in the root directory of the project
  conda env create -f ./environment.yaml
  ```

- Activate the `CVers` environment:

  ```bash
  conda activate CVers
  ```

- Install `sam2`:

  ```bash
  git clone https://github.com/facebookresearch/sam2.git && cd sam2
  pip install -e .
  ```

- Download the SAM2 model weights with the following commands:

  ```bash
  cd sam2/checkpoints
  bash download_ckpts.sh
  ```

- Then go back to the root directory of the project:

  ```bash
  cd ../..
  ```
Our environment can be activated using the following command:

```bash
conda activate CVers # For the PyTorch Lightning environment
```
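After activation, a quick import check can confirm that the main dependencies are available. This is a minimal sketch, assuming the packages from `environment.yaml` plus the locally installed `sam2` repository:

```python
# Minimal import check for the CVers environment; exact versions depend on environment.yaml.
import torch
import lightning
import segmentation_models_pytorch as smp
from sam2.build_sam import build_sam2  # only importable after `pip install -e .` in the sam2 repo

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("lightning:", lightning.__version__)
print("segmentation-models-pytorch:", smp.__version__)
```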
Person in Charge: Lintao He 何林涛

Participants (listed in alphabetical order by name): Bowei Cheng 程柏威, Chencan Que 阙晨灿, Zhen Yang 杨震, Zitong Wei 魏子童
You can find the papers by directly clicking the architecture and feature extractor names below.
- Unet
- Unet++: Does not support Mix Vision Transformer (MixViT) as the feature extractor (encoder).
- Linknet: Does not support Mix Vision Transformer (MixViT) as the feature extractor (encoder).
- FPN
- PSPNet
- PAN: Does not support DenseNet as the feature extractor (encoder).
- DeepLabV3: Does not support DenseNet as the feature extractor (encoder).
- DeepLabV3+: Does not support DenseNet as the feature extractor (encoder).
- UperNet
- SegFormer
We selected the following feature extractors (encoders) because they have comparable numbers of parameters.
- DenseNet161: Does not support the DeepLabV3, DeepLabV3+, and PAN architectures.
- EfficientNet-B5
- Mix Vision Transformer (MixViT): Does not support the Unet++ and Linknet architectures.
- ResNet50
- ResNeXt50_32x4d
- SE-ResNet50 and SE-ResNeXt50_32x4d
Person in Charge: Zhen Yang 杨震
Participant: Lintao He 何林涛
Refer to:
- The GitHub repository
- Enabling Meta’s SAM 2 model for Geospatial AI on satellite imagery
- arXiv paper: "Customized SAM 2 for Referring Remote Sensing Image Segmentation"
- arXiv paper: "Zero-Shot Tree Detection and Segmentation from Aerial Forest Imagery"
Dev Environment: CVers or other environments as needed
Note
You need to manually clone the SAM2 repository and download the model weights as described in Subsection 2.2.2.
The code will automatically download the dataset from Kaggle using the kagglehub package. If you want to delete the dataset, go to the ~/.cache/kagglehub/datasets directory. The data split CSV files are stored in the datasplits directory. You can set different seeds and split ratios by modifying the parameters of AerialDeadTreeSegDataModule in the training or test scripts.
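For example, a datamodule with a different seed and split could be constructed roughly as follows. This is a hypothetical sketch: the import path matches data/datamodule.py, but the parameter names shown here (seed, split ratios, modality, batch size) are assumptions and may not match the actual AerialDeadTreeSegDataModule signature.

```python
# Hypothetical sketch: parameter names are assumptions; check data/datamodule.py
# for the real AerialDeadTreeSegDataModule signature.
from data.datamodule import AerialDeadTreeSegDataModule

dm = AerialDeadTreeSegDataModule(
    seed=123,          # a different seed produces a different datasplits CSV
    train_ratio=0.7,   # 70% train
    val_ratio=0.1,     # 10% validation (the remaining 20% is the test split)
    modality="rgb",    # "merged", "rgb", or "nrg"
    batch_size=32,
)
dm.setup("fit")
```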
For Segmentation Models PyTorch, the pre-trained weights will be downloaded automatically. If you want to delete them, go to the ~/.cache/huggingface/hub directory and delete them as you wish.
For SAM2, the pre-trained weights must be downloaded manually as described in Subsection 2.2.2.
To train the Segmentation Models PyTorch models, you can use the following commands in the root directory of the project:

```bash
conda activate CVers # Activate the CVers environment
python scripts/train_smp.py
```

This will train all the architectures and feature extractors specified in the arch_list and encoder_only variables in the scripts/train_smp.py file. By default, training is performed on the merged, RGB, and NRG modalities.
Tip
You can modify the BATCH_SIZE and ACCUMULATE_GRAD_BATCHES variables in the scripts/train_smp.py file to adjust the batch size and gradient accumulation. The default values are BATCH_SIZE = 32 and ACCUMULATE_GRAD_BATCHES = 1. If you encounter out-of-memory (OOM) errors, you can try reducing the batch size and increasing the gradient accumulation.
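For example, the following adjustment (using the variable names from scripts/train_smp.py) halves the per-step batch size while keeping the effective batch size at 32:

```python
# In scripts/train_smp.py: smaller per-step batches, compensated by gradient accumulation.
BATCH_SIZE = 16
ACCUMULATE_GRAD_BATCHES = 2  # effective batch size = 16 * 2 = 32
```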
4.2.1.1 Train on a specific architecture or feature extractor:
If you want to train a specific architecture or feature extractor, you can modify the arch_list and encoder_only variables in the scripts/train_smp.py file. For example, to train the Unet architecture with DenseNet161 as the feature extractor, you can set:
```python
arch_list = ['Unet']
encoder_only = ['densenet161']
```

4.2.1.2 Train on a specific modality:
If you want to train on a specific modality, you can modify the modality_list variable in the scripts/train_smp.py file. For example, to train only on the RGB modality, you can set:
```python
modality_list = ['rgb']
```

4.2.1.3 Loss Function:
The default loss function is a combination of JaccardLoss and FocalLoss. You can change the loss function in the scripts/train_smp.py file by modifying the LOSS1 and LOSS2 variables. For example, to use only JaccardLoss, you can set:
```python
LOSS1 = 'JaccardLoss'
LOSS2 = None
```
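For reference, both default losses are available in Segmentation Models PyTorch and can be combined roughly as sketched below; the exact weighting and wiring inside scripts/train_smp.py may differ.

```python
# Sketch of combining the two default losses with Segmentation Models PyTorch;
# the actual combination used in scripts/train_smp.py may be weighted differently.
import segmentation_models_pytorch as smp

jaccard = smp.losses.JaccardLoss(mode="binary")
focal = smp.losses.FocalLoss(mode="binary")

def combined_loss(logits, targets):
    # unweighted sum of the two terms (assumed)
    return jaccard(logits, targets) + focal(logits, targets)
```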
4.2.1.4 Early Stopping:

The training script will automatically stop training if the validation loss does not improve for 30 epochs. This helps avoid overfitting and saves training time. You can modify the EARLY_STOP_PATIENCE variable in the scripts/train_smp.py file to change the patience value.
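A sketch of the corresponding Lightning callback is shown below; the monitored metric name is an assumption, and the project's utils/callbacks.py may configure it differently.

```python
# Sketch of an early-stopping callback with the project's patience value;
# the monitored metric name ("val_loss") is an assumption.
from lightning.pytorch.callbacks import EarlyStopping

EARLY_STOP_PATIENCE = 30
early_stopping = EarlyStopping(monitor="val_loss", patience=EARLY_STOP_PATIENCE, mode="min")
# The callback is then passed to the Lightning Trainer via `callbacks=[early_stopping, ...]`.
```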
4.2.1.5 Checkpoints Saving:
The training script will automatically save the best 2 model checkpoints in the checkpoints/ directory; each model's folder contains a subfolder per modality (merged, RGB, NRG) in which its checkpoints are saved. The checkpoints are saved with the following naming format (a sketch of a matching checkpoint callback follows the list below):

```
smp_{encoder}_{arch}/{modality}_{target size}_{version suffix}/{epoch}-{val per_image_mIou}.ckpt
```
Where:

- `{encoder}`: The feature extractor (encoder) name.
- `{arch}`: The architecture name.
- `{modality}`: The modality name (merged, rgb, nrg).
- `{target size}`: The target size of the input images.
- `{version suffix}`: The version suffix, which can be changed as needed.
- `{epoch}`: The epoch number of the saved model.
- `{val per_image_mIou}`: The validation per-image mean Intersection over Union (mIoU) value.
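A Lightning checkpoint callback along these lines could produce the layout above. This is a sketch only: the directory path is a hypothetical example, the monitored metric name is an assumption, and utils/callbacks.py may configure things differently.

```python
# Sketch of a checkpoint callback matching the naming scheme above; the dirpath is a
# hypothetical example and the monitored metric name is an assumption.
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/smp_resnet50_Unet/rgb_256_v0",  # smp_{encoder}_{arch}/{modality}_{size}_{version}
    filename="{epoch}-{val_per_image_mIou:.4f}",         # metric must be logged under this name
    monitor="val_per_image_mIou",
    mode="max",
    save_top_k=2,                                        # keep the best 2 checkpoints
)
```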
Note
- The training script will automatically search for a suggested learning rate using the learning-rate finder (`lr_find`) from the `lightning` library (see the sketch after this note).
- The training script will automatically log the training and validation metrics to the `logs/` directory, which can be viewed using TensorBoard. A logging summary is also written to the corresponding `checkpoints/` directory.
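Below is a self-contained sketch of how the Lightning learning-rate finder is typically used. The tiny module and random data are placeholders, not project code, and the actual integration in scripts/train_smp.py may differ.

```python
# Self-contained sketch of Lightning's learning-rate finder; the dummy module and data
# only illustrate the API, not the project's segmentation model.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

class TinyModule(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr                      # attribute the tuner varies during the search
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

loader = DataLoader(TensorDataset(torch.randn(1024, 4), torch.randn(1024, 1)), batch_size=8)
model = TinyModule()
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
lr_finder = Tuner(trainer).lr_find(model, train_dataloaders=loader)
print("suggested learning rate:", lr_finder.suggestion())
```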
To fine-tune the SAM2 model, you can use the following commands in the root directory of the project:

```bash
conda activate CVers # Activate the CVers environment
python scripts/sam2_fine_tune.py
```

To test the Segmentation Models PyTorch models, you can use the following commands:
```bash
conda activate CVers # Activate the CVers environment
python scripts/test_smp.py
```

This script will automatically load the best model checkpoints from the checkpoints/ directory and perform inference on the test dataset split with the specified modality. The results are written to a CSV file in the outputs/smp_test_results directory.
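A minimal sketch for inspecting the generated results is shown below; it assumes pandas is available, and the exact file and column names depend on what scripts/test_smp.py writes.

```python
# Load the most recent results CSV produced by the test script (file name is not fixed here).
import glob
import pandas as pd

latest_csv = sorted(glob.glob("outputs/smp_test_results/*.csv"))[-1]
df = pd.read_csv(latest_csv)
print(df.head())
```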
4.3.2.1 Zero-Shot Segmentation
For zero-shot segmentation using SAM2, you can use the following commands:
```bash
conda activate CVers # Activate the CVers environment
python scripts/sam2_zero_shot.py
```

This script will perform zero-shot segmentation using the SAM2 model on the test dataset split. The results will be saved in the outputs/sam2_zs_inference directory.
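For orientation, a hedged sketch of single-image prediction with the SAM2 image predictor is shown below. The checkpoint and config names follow the official repository's conventions, the input file name and point prompt are purely illustrative, and scripts/sam2_zero_shot.py may prompt and post-process differently.

```python
# Hedged sketch of SAM2 prediction on one image; names of the config, checkpoint, and
# input image are assumptions, and the point prompt is illustrative only.
import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

checkpoint = "sam2/checkpoints/sam2.1_hiera_large.pt"  # downloaded by download_ckpts.sh
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"       # config name may vary with the SAM2 version

predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
image = np.array(Image.open("example_rgb_tile.png").convert("RGB"))  # hypothetical input image
predictor.set_image(image)

# One foreground point prompt at the image centre (illustrative only)
point = np.array([[image.shape[1] // 2, image.shape[0] // 2]])
masks, scores, _ = predictor.predict(point_coords=point, point_labels=np.array([1]))
```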
4.3.2.2 Fine-Tuning Test
To test the fine-tuned SAM2 model, you can use the following commands:
```bash
conda activate CVers # Activate the CVers environment
python scripts/test_sam2_ft.py
```

This script will load the fine-tuned SAM2 model and perform inference on the test dataset split. The results will be saved in the outputs/sam2_ft_inference directory.
- Per Image IoU: IoU computed for each image individually and then averaged
- Dataset IoU: IoU computed over the pixels of the whole dataset pooled together (see the sketch after this list)
- F1 Score: Harmonic mean of precision and recall
- F2 Score: Weighted harmonic mean of precision and recall that puts more emphasis on recall
- Accuracy: Ratio of correct pixel predictions to the total predictions
- Precision: Ratio of true positive predictions to the total predicted positives
- Recall: Ratio of true positive predictions to the total actual positives
- Sensitivity: True positive rate, same as recall
- Specificity: True negative rate, ratio of true negative predictions to the total actual negatives
- Test Time (Seconds): Time taken to perform inference on the test dataset
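To make the difference between the two IoU variants concrete, here is a small self-contained sketch using random dummy masks (not project code):

```python
# Sketch contrasting per-image IoU and dataset IoU on binary masks, using plain PyTorch.
import torch

def iou(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    inter = (pred & gt).sum(dim=(-2, -1)).float()
    union = (pred | gt).sum(dim=(-2, -1)).float()
    return (inter + eps) / (union + eps)

preds = torch.randint(0, 2, (8, 256, 256), dtype=torch.bool)  # dummy predicted masks (N, H, W)
gts = torch.randint(0, 2, (8, 256, 256), dtype=torch.bool)    # dummy ground-truth masks

per_image_iou = iou(preds, gts).mean()                        # IoU per image, then averaged
dataset_iou = iou(preds.flatten().unsqueeze(0), gts.flatten().unsqueeze(0))  # all pixels pooled
print(float(per_image_iou), float(dataset_iou))
```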
| Metric | RGB-NIR | NIR-RG | RGB |
|---|---|---|---|
| Per Image IoU | 0.4481 | 0.4290 | 0.4334 |
| Dataset IoU | 0.4605 | 0.4379 | 0.4472 |
| F1 Score | 0.6020 | 0.5847 | 0.5847 |
| F2 Score | 0.6089 | 0.5938 | 0.5898 |
| Accuracy | 0.9839 | 0.9830 | 0.9836 |
| Precision | 0.6342 | 0.6139 | 0.6239 |
| Recall | 0.6235 | 0.6107 | 0.6038 |
| Sensitivity | 0.6235 | 0.6107 | 0.6038 |
| Specificity | 0.9922 | 0.9918 | 0.9924 |
Each value is the mean over all tested architectures and feature extractors for the corresponding modality (RGB-NIR, NIR-RG, RGB).
| Architecture | Per Image IoU (mean) | Per Image IoU (max) | Dataset IoU (mean) | Dataset IoU (max) | Test Time (s, mean) | Test Time (s, max) |
|---|---|---|---|---|---|---|
| DeepLabV3 | 0.4216 | 0.4505 | 0.4358 | 0.4602 | 3.0977 | 5.7786 |
| DeepLabV3Plus | 0.4199 | 0.4528 | 0.4324 | 0.4655 | 2.5329 | 3.7419 |
| FPN | 0.4469 | 0.4759 | 0.4562 | 0.4911 | 2.7698 | 3.8235 |
| Linknet | 0.4272 | 0.4719 | 0.4400 | 0.4861 | 2.8194 | 4.2640 |
| PAN | 0.4427 | 0.4624 | 0.4538 | 0.4762 | 4.9664 | 37.4820 |
| PSPNet | 0.4202 | 0.4410 | 0.4369 | 0.4580 | 1.8923 | 3.0223 |
| Segformer | 0.4409 | 0.4630 | 0.4525 | 0.4802 | 2.6860 | 4.1915 |
| UPerNet | 0.4488 | 0.4742 | 0.4577 | 0.4812 | 2.8821 | 4.2409 |
| Unet | 0.4400 | 0.4810 | 0.4503 | 0.4930 | 2.8149 | 3.8503 |
| Unet++ | 0.4497 | 0.4807 | 0.4587 | 0.4970 | 3.5455 | 5.4500 |
| Unet with scse | 0.4441 | 0.4790 | 0.4567 | 0.4959 | 3.1273 | 4.1764 |
Each row reports the mean and maximum over all feature extractors and all three modalities for the given architecture. The Unet with scse architecture is a modified version of Unet with Squeeze-and-Excitation (SE) blocks in the decoder.
| Backbone | Per Image IoU (mean) | Per Image IoU (max) | Dataset IoU (mean) | Dataset IoU (max) | Test Time (s, mean) | Test Time (s, max) |
|---|---|---|---|---|---|---|
| Densenet-161 | 0.4292 | 0.4546 | 0.4391 | 0.4745 | 3.7932 | 5.4500 |
| EfficientNet-b5 | 0.4521 | 0.4810 | 0.4658 | 0.4970 | 4.8207 | 37.4820 |
| MixViT-b2 | 0.4354 | 0.4759 | 0.4442 | 0.4911 | 3.1613 | 5.7786 |
| Resnet50 | 0.4243 | 0.4580 | 0.4366 | 0.4762 | 2.1381 | 3.1704 |
| ResneXt50_32x4d | 0.4320 | 0.4696 | 0.4427 | 0.4794 | 2.3084 | 9.5732 |
| SE-Resnet50 | 0.4397 | 0.4781 | 0.4532 | 0.4835 | 2.4685 | 3.2726 |
| SE-ResneXt50_32x4d | 0.4428 | 0.4686 | 0.4549 | 0.4812 | 2.4627 | 3.2390 |
Each row reports the mean and maximum over all architectures and all three modalities for the given feature extractor. MixViT-b2 is a smaller variant of the Mix Vision Transformer (MixViT).
To run the Gradio demo, you can use the following commands:

```bash
conda activate CVers # Activate the CVers environment
python gradio/app.py
```

Then open your web browser and go to http://localhost:7860 to see the demo. The demo allows you to upload an image and perform inference using the Segmentation Models PyTorch (SMP) models (SAM2 support is planned as future work). The results are displayed on the web page as follows:
Note
- The Gradio demo only supports local checkpoints, so you need to run the `scripts/train_smp.py` script to train the models and save the checkpoints in the `checkpoints/` directory before running the Gradio demo.
- The Gradio demo currently only supports the RGB modality, so you need to use RGB images from the dataset. The NRG and merged modalities will be supported in future work.
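For reference, a minimal sketch of the kind of interface gradio/app.py provides is shown below. The real app loads a trained SMP checkpoint and runs actual inference; the placeholder function and title here are illustrative only.

```python
# Minimal sketch of a Gradio segmentation interface; the real app replaces the placeholder
# function with inference from a trained SMP checkpoint.
import gradio as gr
import numpy as np

def segment(image: np.ndarray) -> np.ndarray:
    mask = np.zeros(image.shape[:2], dtype=np.uint8)  # placeholder: real app runs model inference
    return mask

demo = gr.Interface(fn=segment, inputs=gr.Image(), outputs=gr.Image(), title="Dead Tree Segmentation")
demo.launch()  # serves on http://localhost:7860 by default
```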
We plan to implement the following features in the future:
- More Models: We will add more models to the project, including U2Net and other segmentation models. This will help us to compare the performance of different models on the same dataset.
- SAM2 zero-shot with prompt: We will explore the use of prompt techniques to improve the zero-shot segmentation capabilities of the SAM2 model in this task.
- Different Loss Functions: We will experiment with different loss functions to improve the performance of the models. This includes trying out different combinations of loss functions and hyperparameters.
