
Monte Carlo Tree Search for Classifier Chains


This repository provides an implementation of Monte Carlo Tree Search (MCTS) for inference in Multi-Label Classifier Chains, a novel approach developed as part of a Bachelor Thesis at Ecole Polytechnique.

Classifier Chains are a popular method for multi-label classification, but they traditionally use a greedy approach for inference, which can lead to suboptimal predictions. This project frames the inference problem as a search problem and uses MCTS to explore the label space more intelligently, leading to significant performance improvements over the greedy baseline and achieving results competitive with state-of-the-art methods.
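
To make the gap concrete, here is a minimal sketch (illustrative only, not this library's API) contrasting greedy chain inference with an exhaustive search for the jointly most probable label vector. It assumes a fitted scikit-learn ClassifierChain whose binary base estimators expose predict_proba with classes_ == [0, 1]; MCTS aims for the quality of the exhaustive search at a fraction of its exponential cost.

# Illustrative sketch: greedy vs. exhaustive inference for a classifier chain.
# Assumes a fitted sklearn ClassifierChain with binary base estimators (classes_ == [0, 1]).
from itertools import product

import numpy as np
from sklearn.multioutput import ClassifierChain


def chain_log_score(chain: ClassifierChain, x: np.ndarray, y_chain: np.ndarray) -> float:
    """Sum of log P(y_i | x, y_1, ..., y_{i-1}) along the chain order."""
    score = 0.0
    for i, est in enumerate(chain.estimators_):
        x_aug = np.hstack([x, y_chain[:i]]).reshape(1, -1)
        score += np.log(est.predict_proba(x_aug)[0][int(y_chain[i])] + 1e-12)
    return score


def greedy_inference(chain: ClassifierChain, x: np.ndarray) -> np.ndarray:
    """Standard CC inference: pick the locally most probable label at each step."""
    y_chain = np.zeros(len(chain.estimators_))
    for i, est in enumerate(chain.estimators_):
        x_aug = np.hstack([x, y_chain[:i]]).reshape(1, -1)
        y_chain[i] = est.predict_proba(x_aug)[0].argmax()
    return y_chain[np.argsort(chain.order_)]  # map back to the original label order


def exhaustive_inference(chain: ClassifierChain, x: np.ndarray) -> np.ndarray:
    """Score all 2^L label vectors and return the jointly most probable one (exponential cost)."""
    L = len(chain.estimators_)
    best = max((np.array(y) for y in product([0, 1], repeat=L)),
               key=lambda y: chain_log_score(chain, x, y))
    return best[np.argsort(chain.order_)]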

For a detailed explanation of the method, please see the full Bachelor Thesis Report.

Key Features

  • Novel Inference Strategy: A new application of Monte Carlo Tree Search to improve predictions for Classifier Chains.
  • High Performance: Outperforms the standard greedy Classifier Chain and achieves results competitive with state-of-the-art methods like Monte Carlo Classifier Chains (MCC).
  • Flexible Policies: Easily experiment with different MCTS selection and exploration policies, such as UCB and Epsilon-Greedy (see the sketch after this list).
  • Visualization Tools: Includes tools to visualize the MCTS search tree, providing insight into the decision-making process.
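
To make these policies concrete, here is a hedged sketch of what UCB and epsilon-greedy selection typically compute during the tree-search phase; these are illustrative stand-ins, not the classes shipped in mcts_inference.policy.

# Illustrative MCTS selection policies (not the package's own implementations).
# `values[i]` is the total return accumulated by child i over `visits[i]` visits.
import math
import random
from typing import Sequence


def ucb_select(values: Sequence[float], visits: Sequence[int], c: float = 2.0) -> int:
    """Pick the child maximising mean value + c * sqrt(ln(total visits) / visits_i)."""
    total = sum(visits)

    def score(i: int) -> float:
        if visits[i] == 0:
            return float("inf")  # always expand unvisited children first
        return values[i] / visits[i] + c * math.sqrt(math.log(total) / visits[i])

    return max(range(len(values)), key=score)


def epsilon_greedy_select(values: Sequence[float], visits: Sequence[int], eps: float = 0.2) -> int:
    """With probability eps explore a random child, otherwise exploit the best mean value."""
    if random.random() < eps:
        return random.randrange(len(values))
    return max(range(len(values)), key=lambda i: values[i] / max(visits[i], 1))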

Quick Start

The following example shows how to train a ClassifierChain and use MCTS for inference on a synthetic dataset.

from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.multioutput import ClassifierChain
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import hamming_loss

from mcts_inference import MCTS, MCTSConfig, Constraint
from mcts_inference.policy import UCB

# 1. Create a synthetic dataset
X, Y = make_multilabel_classification(n_samples=100, n_features=20, n_classes=5, n_labels=2, random_state=0)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

# 2. Train a standard Classifier Chain
base_classifier = LogisticRegression(solver="liblinear")
chain = ClassifierChain(base_classifier).fit(X_train, Y_train)

# 3. Use MCTS for inference
config = MCTSConfig(
    n_classes=Y.shape[1],
    selection_policy=UCB(c=2.0),
    constraint=Constraint(max_iter=True, n_iter=100)
)
y_pred_mcts = MCTS(X_test, chain, config)

# 4. Compare with greedy inference
y_pred_greedy = chain.predict(X_test)

print(f"Hamming Loss (Greedy): {hamming_loss(Y_test, y_pred_greedy):.4f}")
print(f"Hamming Loss (MCTS):   {hamming_loss(Y_test, y_pred_mcts):.4f}")

Installation

To get started, clone the repository and install it in editable mode. This will also install all the required dependencies from requirements.txt.

git clone https://github.com/rompoggi/MCTS_ClassifierChain.git
cd MCTS_ClassifierChain
pip install -e .

You may need to use pip3 depending on your Python installation.

Results

The MCTS-based approach was benchmarked against several other methods, including standard Classifier Chains (CC), Probabilistic Classifier Chains (PCC), and Monte Carlo Classifier Chains (MCC). The tables below show performance rankings across multiple datasets (1 = best, so a lower average rank is better). Our method (MUCB(2)) achieves the second-best average rank, close to the state of the art, without extensive hyperparameter tuning.

For details on how to reproduce these results, please refer to the notebooks in the data/ directory.
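
For reference, the two metrics behind the rankings can be computed with scikit-learn as in the short sketch below (assuming the usual definitions: exact match is accuracy over whole label vectors, Hamming score is 1 minus the Hamming loss).

# Sketch of the evaluation metrics, assuming their standard definitions.
import numpy as np
from sklearn.metrics import accuracy_score, hamming_loss

Y_true = np.array([[1, 0, 1], [0, 1, 1]])
Y_pred = np.array([[1, 0, 0], [0, 1, 1]])

exact_match = accuracy_score(Y_true, Y_pred)         # fraction of rows predicted exactly right
hamming_score = 1.0 - hamming_loss(Y_true, Y_pred)   # fraction of individual labels correct
print(f"Exact match: {exact_match:.2f}, Hamming score: {hamming_score:.2f}")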

Ranking by Exact Match Score

Dataset     PCC   CC    MCC   MUCB(2)  MEPS(0.2)  MEPS(0.5)  MTMS(1,1)  M1UCB(2)  M1EPS(0.2)
Music       3     6     3     1        9          5          7          2         8
Scene       1     8     1     3        6          5          6          4         9
Flags       3     9     3     5        1          2          6          8         6
Foodtruck   2     3     1     5        8          6          4          9         7
Yeast       1     3     1     4        7          6          5          9         8
Birds       9     2     1     3        6          5          4          8         7
Genbase     9     2     1     4        6          5          3          8         7
avg. rank   4.00  4.71  1.57  3.57     6.14       4.85       5.00       6.85      7.43
Ranking by Hamming Score

Dataset     PCC   CC    MCC   MUCB(2)  MEPS(0.2)  MEPS(0.5)  MTMS(1,1)  M1UCB(2)  M1EPS(0.2)
Music       4     3     5     1        8          7          6          2         9
Scene       3     7     2     1        8          6          5          4         9
Flags       3     7     3     6        8          1          2          5         9
Foodtruck   2     3     1     4        8          6          5          9         7
Yeast       2     1     2     4        7          6          5          9         8
Birds       9     2     1     3        6          5          4          8         7
Genbase     9     2     1     4        6          5          3          8         7
avg. rank   4.57  3.57  2.00  3.29     7.29       5.14       4.29       6.43      8.00

Repository Overview

  • src/mcts_inference/: Source code for the MCTS implementation, policies, and related utilities.
  • examples/: Jupyter notebooks demonstrating how to use the library and comparing it with other methods.
  • data/: Datasets, preprocessing notebooks, and evaluation results.
  • tests/: Unit tests for the project.

Testing

The project uses pytest for testing. You can run the tests from the root directory:

pytest

Note: While the core data structures and utilities are well-tested, the main inference functions currently have limited test coverage. Contributions to improve this are welcome.

Contributing

Contributions are more than welcome; the aim is to further study and improve the method presented here. Please feel free to open an issue or submit a pull request.

Contact

For questions about the project, please contact Romain Poggi (romain.poggi@polytechnique.edu).

For questions related to the theoretical aspects of the method, you can also contact Professor Jesse Read (jesse.read@polytechnique.edu).

License

This project is licensed under the MIT License. See the LICENSE file for more details.
