
Simple Segmentation Toolkit

End-to-end pipeline for semantic segmentation: a proof-of-concept toolkit for quickly going from zero-shot auto-labels to a trained PyTorch model.

Features

  • Zero-shot Auto-Labeling: Use Grounded SAM to generate masks without any training data
  • Interactive Review GUI: Accept/reject masks with live ontology editing
  • Model Training: Train efficient segmentation models with real-time visualization
  • Inference Script: Test trained models on new RGB images

Requirements

  • GPU with at least 8GB VRAM (required for Grounded SAM during labeling)
  • CUDA-compatible GPU recommended for training
  • Python 3.10

Installation

Prerequisites: GPU with at least 8GB VRAM for Grounded SAM labeling (training can run with less VRAM or on CPU).

1. Create Conda Environment

conda create -n sst python=3.10 -y
conda activate sst

2. Install Dependencies

pip install -r requirements.txt

Note: For GPU support, ensure CUDA is installed. If needed, install PyTorch with CUDA:

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
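
To confirm that PyTorch can see your GPU before labeling or training, you can run a quick check. This is a generic PyTorch snippet, not part of the toolkit's scripts:

# Optional sanity check: verify that PyTorch detects a CUDA-capable GPU.
import torch

print(torch.__version__)
print(torch.cuda.is_available())   # should print True for GPU labeling/training
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))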

Quick Start Workflow

Step 1: Prepare Your RGB Images

The toolkit includes an example_dataset/ folder with sample RGB images captured from a TurtleBot 4 in the Gazebo Baylands world. To get started, create the data directory and copy the sample images into it:

mkdir -p data/raw_images
cp example_dataset/*.png data/raw_images/

For your own data, replace example_dataset/ with your own folder containing RGB images:

mkdir -p data/raw_images
cp /path/to/your/images/*.png data/raw_images/
# or
cp /path/to/your/images/*.jpg data/raw_images/

No labels? No problem! The toolkit will generate them for you using Grounded SAM.
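
If you want to double-check your inputs first, a small sketch like the one below (not part of the toolkit; it only assumes OpenCV is installed) lists the images in data/raw_images/ and flags any that cannot be read:

# Optional sanity check: make sure every image in data/raw_images/ is readable.
from pathlib import Path
import cv2

for path in sorted(Path("data/raw_images").glob("*")):
    if path.suffix.lower() not in {".png", ".jpg", ".jpeg"}:
        continue
    image = cv2.imread(str(path))
    if image is None:
        print(f"WARNING: could not read {path}")
    else:
        print(f"{path.name}: {image.shape[1]}x{image.shape[0]}")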


Step 2: Auto-Label with SAM GUI

Run the interactive labeling GUI to generate and review segmentation masks using Roboflow's Autodistill with Grounded SAM:

python3 autolabel.py --input_dir data/raw_images

Note: This step requires a GPU with at least 8GB VRAM. Grounded SAM combines Grounding DINO for text-based detection with SAM for segmentation.

How Images Move During Labeling

The labeling GUI processes images from data/raw_images/ and moves them based on your decisions:

When you ACCEPT an image (press N or Enter):

  1. The image is moved from data/raw_images/ to data/labeling/accepted/images/
  2. The generated mask is saved to data/labeling/accepted/masks/ with a _mask.png suffix
  3. The mask contains class IDs: 0 (background), 1 (first class), 2 (second class), etc. (a snippet for inspecting these masks follows below)
  4. The original filename is kept, with a unique UUID appended to prevent overwrites

When you DISCARD an image (press S):

  1. The image is moved from data/raw_images/ to data/labeling/discarded/
  2. No mask is saved

Result: After labeling, data/raw_images/ will be empty, and all images will be in either accepted/ or discarded/.
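
To verify the mask format described above, you can open any accepted mask and list the class IDs it contains. This is a minimal sketch; the filename is a placeholder, so substitute one of your own masks:

# Inspect an accepted mask: pixel values are class IDs (0 = background).
import cv2
import numpy as np

mask = cv2.imread("data/labeling/accepted/masks/example_mask.png", cv2.IMREAD_GRAYSCALE)
assert mask is not None, "mask not found"
ids, counts = np.unique(mask, return_counts=True)
for class_id, pixel_count in zip(ids, counts):
    print(f"class {class_id}: {pixel_count} pixels")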

GUI Controls

  • N / Enter: Accept current mask and move to next image
  • S: Skip/discard current image
  • R: Reprocess current image with updated ontology
  • Q / Esc: Quit

Runtime Ontology Editing

You can edit class prompts and colors in real time using the built-in YAML editor:

  1. Modify the ontology text in the GUI
  2. Click "Apply Ontology"
  3. Press R to reprocess with new settings

Default classes (defined in configs/ontology.yaml):

  • sidewalk - Blue overlay
  • grass - Green overlay

Step 3: Train a Segmentation Model

Once you have accepted masks in data/labeling/accepted/, train a model using SuperGradients with the interactive training GUI:

python3 train.py --epochs 100

How Images Move During Training

When training starts, the script automatically:

  1. Reads accepted images from data/labeling/accepted/images/ and data/labeling/accepted/masks/
  2. Copies them (doesn't move) to data/training/ and splits into:
    • data/training/train/ - 90% of data (used for training)
    • data/training/val/ - 10% of data (used for validation)
    • data/training/test/ - 0% by default (optional holdout set)
      (These are the default splits, but you can configure the percentages using command-line arguments.)
  3. Each split has images/ and masks/ subdirectories
  4. Original files in accepted/ remain unchanged - you can retrain anytime

Note: Training splits are regenerated each time you run train.py, so you can adjust split ratios without losing data.
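
For reference, the split step amounts to shuffling the accepted image/mask pairs and copying each pair into one of the split folders. The sketch below illustrates the idea with the default 90/10 ratio; it is not the toolkit's actual implementation and assumes masks follow the <image_stem>_mask.png naming shown above:

# Illustration of a 90/10 train/val split over accepted image/mask pairs.
import random
import shutil
from pathlib import Path

images = sorted(Path("data/labeling/accepted/images").glob("*.png"))
random.shuffle(images)
n_train = int(0.9 * len(images))

for i, image_path in enumerate(images):
    split = "train" if i < n_train else "val"
    mask_path = Path("data/labeling/accepted/masks") / f"{image_path.stem}_mask.png"
    for src, sub in [(image_path, "images"), (mask_path, "masks")]:
        dst = Path("data/training") / split / sub
        dst.mkdir(parents=True, exist_ok=True)
        shutil.copy(src, dst / src.name)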

Training Options

# Basic training (default: DDRNet-39)
python3 train.py --epochs 100

# Different model architecture
python3 train.py --model pp_lite_b_seg --epochs 50

# Adjust data splits (default: 90% train, 10% val, 0% test)
python3 train.py --train_split 0.8 --val_split 0.2

# Disable data augmentation
python3 train.py --no-augmentation

# Train without GUI (console only)
python3 train.py --no-gui --epochs 100

# Smaller input size (faster training, less memory)
python3 train.py --input_size 256

# Smaller batch size (less memory)
python3 train.py --batch_size 2

Configuration

Class definitions and model settings are configured in configs/ontology.yaml:

ontology:
  classes:
    - name: sidewalk
      prompt: sidewalk
      color: [255, 0, 0]  # BGR: Blue
    
    - name: grass
      prompt: grass, lawn
      color: [0, 255, 0]  # BGR: Green

model:
  device: cuda               # cuda or cpu
  max_image_dim: 1024        # Max image dimension (for SAM, saves GPU memory)
  mask_opacity: 0.15         # Mask overlay opacity in GUI
  border_width: 1            # Mask border width

The num_classes for training is automatically calculated as len(classes) + 1 (includes background class 0).
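
For example, the default two-class ontology trains a model with three output classes. A minimal sketch of that calculation (assumes PyYAML is installed and the file layout shown above):

# Derive num_classes from configs/ontology.yaml (+1 for background, class 0).
import yaml

with open("configs/ontology.yaml") as f:
    config = yaml.safe_load(f)

num_classes = len(config["ontology"]["classes"]) + 1
print(num_classes)   # 3 with the default sidewalk + grass ontology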

Training GUI

The GUI shows real-time training progress:

  • Left Panel (40%): Loss and mIoU curves with logarithmic scaling
  • Right Panel (60%): Live prediction samples from train/val sets
  • Bottom: Training log with epoch metrics

Model Checkpoints

Trained models are saved to:

models/checkpoints/{model_name}_custom/
├── best_model.pth      # Best validation loss
└── latest_model.pth    # Most recent epoch
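
If you need to pick up the newest checkpoint programmatically (for example, to pass it to infer.py), a small glob over this layout works. This is a hypothetical helper, not part of the toolkit's scripts:

# Find the most recently written best_model.pth across all trained models.
from pathlib import Path

checkpoints = sorted(
    Path("models/checkpoints").glob("*_custom/best_model.pth"),
    key=lambda p: p.stat().st_mtime,
)
print(checkpoints[-1] if checkpoints else "no checkpoints found")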

Data Augmentation

Data augmentation is enabled by default to prevent overfitting with limited data. Augmentations include:

  • Geometric transforms: flips, rotations, scale/shift
  • Color transforms: brightness, contrast, hue/saturation
  • Noise and blur: Gaussian noise, motion blur
  • Weather effects: random shadows

Disable with --no-augmentation if needed.
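
For illustration, a comparable pipeline can be written with Albumentations; the exact transforms and parameters the toolkit uses may differ:

# Sketch of an augmentation pipeline similar to the one described above
# (Albumentations; the same transform is applied jointly to image and mask).
import numpy as np
import albumentations as A

transform = A.Compose([
    A.HorizontalFlip(p=0.5),                                    # flips
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1,
                       rotate_limit=15, p=0.5),                 # rotation, scale/shift
    A.RandomBrightnessContrast(p=0.3),                          # brightness, contrast
    A.HueSaturationValue(p=0.3),                                # hue/saturation
    A.GaussNoise(p=0.2),                                        # Gaussian noise
    A.MotionBlur(p=0.2),                                        # motion blur
    A.RandomShadow(p=0.2),                                      # random shadows
])

# Dummy inputs for illustration; in practice these come from the training loader.
image = np.zeros((512, 512, 3), dtype=np.uint8)
mask = np.zeros((512, 512), dtype=np.uint8)
augmented = transform(image=image, mask=mask)
aug_image, aug_mask = augmented["image"], augmented["mask"]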

Supported Models

All SuperGradients segmentation models are supported. The default is DDRNet-39 for its excellent accuracy/speed tradeoff.

See the full list of available models in the SuperGradients source: https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/common/object_names.py#L211

Popular choices:

  • ddrnet_39 (default) - Good accuracy/speed balance
  • pp_lite_b_seg - Very fast
  • pp_lite_t_seg - Fastest
  • regseg - Good for edge devices
  • stdc - STDC segmentation

Step 4: Run Inference

Test your trained model on new RGB images:

# Random image from data folders + auto-detect latest model
python3 infer.py

# Specific image
python3 infer.py data/raw_images/my_image.png

# Specific checkpoint
python3 infer.py my_image.png --checkpoint models/checkpoints/ddrnet_39_custom/best_model.pth

The script will:

  1. Load the model checkpoint
  2. Run inference with proper preprocessing (ImageNet normalization)
  3. Display the segmentation overlay (70% image + 30% colored mask)

Press any key to close the result window.
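
The overlay step boils down to an alpha blend between the input image and a color-coded class mask. A minimal sketch of that blend, using the default BGR colors from configs/ontology.yaml (the prediction array here is a placeholder for the model output):

# Blend the input image (70%) with a colorized class mask (30%).
import cv2
import numpy as np

image = cv2.imread("test_image.png")                       # BGR, HxWx3
prediction = np.zeros(image.shape[:2], dtype=np.uint8)     # class IDs from the model

colors = {1: (255, 0, 0), 2: (0, 255, 0)}                  # BGR: sidewalk, grass
color_mask = np.zeros_like(image)
for class_id, bgr in colors.items():
    color_mask[prediction == class_id] = bgr

overlay = cv2.addWeighted(image, 0.7, color_mask, 0.3, 0)
cv2.imshow("segmentation", overlay)
cv2.waitKey(0)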


Step 5: Convert Model to ONNX (For ROS2 Deployment)

For ROS2 deployment, convert the trained model to ONNX format to avoid the super-gradients dependency in production:

python3 convert_to_onnx.py

This creates models/model.onnx which can be used with ONNX Runtime (much lighter than PyTorch + super-gradients). The ROS2 semantic_segmentation_node package uses this ONNX model for inference.

Why ONNX?

  • No super-gradients dependency (~2GB lighter)
  • Python 3.12+ compatible (super-gradients has compatibility issues on newer Python versions)
  • Faster inference with ONNX Runtime
  • Single file deployment

Note: The conversion script looks for models/checkpoints/*/latest_model.pth and configs/ontology.yaml. Copy the ONNX file to your ROS2 package's models folder.
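
Once converted, inference in a deployment environment only needs ONNX Runtime. A minimal sketch, assuming the exported graph takes a normalized NCHW float32 input and outputs per-class logits (matching the training preprocessing described above; depending on the export, you may also need to resize the image to the training input size):

# Run the exported model with ONNX Runtime (no PyTorch / super-gradients needed).
import cv2
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("models/model.onnx")
input_name = session.get_inputs()[0].name

image = cv2.imread("test_image.png")
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
rgb = (rgb - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]     # ImageNet normalization
tensor = rgb.transpose(2, 0, 1)[np.newaxis].astype(np.float32)  # NCHW

logits = session.run(None, {input_name: tensor})[0]
prediction = np.argmax(logits, axis=1)[0]                       # [H, W] class IDs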


Using Your Model in Code

import torch
import cv2
import numpy as np
from super_gradients.training import models

# Load model
num_classes = 3  # 2 classes + background
model = models.get("ddrnet_39", num_classes=num_classes)
checkpoint = torch.load("models/checkpoints/ddrnet_39_custom/best_model.pth")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load and preprocess RGB image
image = cv2.imread("test_image.png")
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_tensor = torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0

# Apply ImageNet normalization
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
input_tensor = (input_tensor - mean) / std

# Run inference
with torch.no_grad():
    output = model(input_tensor.unsqueeze(0))
    if isinstance(output, (list, tuple)):
        output = output[0]
    prediction = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
    # prediction shape: [H, W] with class IDs

# prediction[i, j] contains:
#   0 for background
#   1 for first class
#   2 for second class, etc.

Data Flow Summary

Your images
    ↓ (copy to data/raw_images/)
data/raw_images/*.png
    ↓ (autolabel.py - review with GUI)
    ├─ Accept → data/labeling/accepted/images/*.png
    │           data/labeling/accepted/masks/*_mask.png
    │               ↓ (train.py - auto-split)
    │           data/training/train/{images,masks}/
    │           data/training/val/{images,masks}/
    │               ↓ (training)
    │           models/checkpoints/{model}_custom/*.pth
    │               ↓ (infer.py)
    │           Segmentation predictions on new images
    │
    └─ Discard → data/labeling/discarded/*.png

Key Points:

  • Labeling moves images out of raw_images/
  • Training copies from accepted/ to training/
  • You can always retrain from accepted/ with different settings

Tips

Labeling

  • Review all masks carefully before accepting
  • Use specific prompts for better SAM results
  • Discard poor quality images early
  • Aim for 100+ accepted images for reasonable model performance

Training

  • Monitor validation loss - stop if it plateaus or increases
  • Start with 50-100 epochs
  • Try different models if accuracy is insufficient
  • Data augmentation helps with limited data (keep it enabled)

Inference

  • Input must be RGB images (same as training)
  • Preprocessing must match training (ImageNet normalization)
  • Use best_model.pth for best performance
  • Use latest_model.pth to resume training
