PyTorch implementation of a prototype-based sequence tagger following a single geometric principle: score labels by negative squared distance to prototypes plus Markov transitions. Full theory and experiments are in report/report.pdf (build from report/report.tex).
```
distance-based-neural-tagger/
├── README.md
├── requirements.txt
├── data/                            # Code-mixed train/test data
├── figures/                         # (gitignored) Training plots and scatterplots — generate with scripts below
├── checkpoints/                     # (gitignored: *.pt) Model checkpoints and tracking data
├── docs/                            # (gitignored) Generated documentation — optional, theory is in report/
├── report/                          # LaTeX report (theory + experiments)
│   ├── report.tex
│   └── report.pdf                   # Build with: cd report && pdflatex report.tex
├── scripts/                         # Entry-point scripts (run from repo root)
│   ├── train.py                     # Train on dummy data
│   ├── train_with_tracking.py
│   ├── create_scatterplots.py
│   ├── visualize_training.py
│   └── generate_code_mixed_data.py
├── prototype_sequence_tagger.py     # Core model
├── chunk_relation_model.py          # Chunk-relation model
├── code_mixed_data_loader.py        # Data loader for code-mixed data
└── data_generator.py                # Dummy sequence generator
```
Note: `docs/`, `figures/`, and `checkpoints/*.pt` are in `.gitignore` and are not in the repo. Generate figures and checkpoints with the scripts below. Run all scripts from the repository root, e.g. `python scripts/train.py`.
This implementation models sequence tagging tasks (POS, NER, chunks, etc.) using a unified geometric principle:
- Embed tokens into a Hilbert space using an encoder (BiLSTM or MLP)
- Associate each label with a prototype (learnable centroid) in that space
- Score labels by negative squared distance to prototypes: `score(k|z) = -||z - c_k||²`
- Use Markov transitions for sequence structure
- Geometric prototype-based classification: Labels are scored by distance to prototypes in representation space
- Markov transitions: Transition matrix models label-to-label dependencies
- Multiple decoding methods: Greedy and Viterbi decoding
- Prototype initialization: Can initialize prototypes from class centroids
- Clean gradient structure: Direct control over curvature via prototype geometry
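To illustrate the decoding side, here is a minimal Viterbi sketch over precomputed emission and transition score tensors. The function name and shapes are assumptions for illustration, not the repo's `decode` API:

```python
import torch

def viterbi(emissions: torch.Tensor, transitions: torch.Tensor) -> list:
    """Minimal Viterbi sketch. emissions: (T, K), transitions: (K, K)."""
    T, K = emissions.shape
    score = emissions[0].clone()   # best score for each label at t=0
    backptr = []
    for t in range(1, T):
        # total[j, k] = score of ending in j at t-1, moving to k, emitting at t
        total = score.unsqueeze(1) + transitions + emissions[t]
        score, idx = total.max(dim=0)   # best previous label for each k
        backptr.append(idx)
    # Trace back the highest-scoring path
    best = int(score.argmax())
    path = [best]
    for idx in reversed(backptr):
        best = int(idx[best])
        path.append(best)
    return path[::-1]
```

Greedy decoding would instead take `argmax` at each step independently; Viterbi maximizes the joint emission-plus-transition score over the whole sequence.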
- Activate the virtual environment:

  ```bash
  source venv/bin/activate
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

From the repo root, train with default parameters:

```bash
python scripts/train.py
```

Train with custom parameters:

```bash
python scripts/train.py \
    --vocab_size 100 \
    --num_labels 5 \
    --hidden_dim 128 \
    --batch_size 32 \
    --num_epochs 20 \
    --lr 0.001 \
    --use_bilstm
```

- `--vocab_size`: Vocabulary size (default: 50)
- `--num_labels`: Number of label classes (default: 5)
- `--embedding_dim`: Token embedding dimension (default: 100)
- `--hidden_dim`: Hidden dimension for encoder and prototypes (default: 128)
- `--num_layers`: Number of LSTM layers (default: 1)
- `--dropout`: Dropout rate (default: 0.1)
- `--use_bilstm`: Use BiLSTM encoder (default: True)
- `--no_bilstm`: Use MLP encoder instead
- `--train_samples`: Number of training samples (default: 1000)
- `--val_samples`: Number of validation samples (default: 200)
- `--batch_size`: Batch size (default: 32)
- `--num_epochs`: Number of training epochs (default: 20)
- `--lr`: Learning rate (default: 0.001)
- `--seed`: Random seed (default: 42)
- `--save_dir`: Directory to save checkpoints (default: `checkpoints/` in repo root)
- Encoder:
  - Option 1: Bidirectional LSTM with projection
  - Option 2: Multi-layer MLP
- Prototypes:
  - Learnable parameters `C ∈ ℝ^{num_labels × hidden_dim}`
  - Each label has a prototype (centroid) in representation space
- Transition matrix:
  - Learnable `A ∈ ℝ^{num_labels × num_labels}`
  - Models label-to-label transitions
- Encode tokens: `h_t = encoder(x_t)`
- Compute emissions: `e_t(k) = -||h_t - c_k||²`
- Add transitions: `s_t(k) = e_t(k) + A[y_{t-1}, k]`
- Apply softmax: `p(y_t = k) = softmax(s_t(k))`
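These steps can be sketched in plain PyTorch for a single time step (a standalone illustration with made-up shapes and variable names, not the repo's internal code):

```python
import torch

torch.manual_seed(0)
batch, hidden_dim, num_labels = 4, 8, 5
h = torch.randn(batch, hidden_dim)            # encoded tokens h_t
C = torch.randn(num_labels, hidden_dim)       # prototypes c_k
A = torch.randn(num_labels, num_labels)       # transition matrix
prev_labels = torch.randint(0, num_labels, (batch,))

# Emissions: negative squared Euclidean distance to each prototype
emissions = -torch.cdist(h, C).pow(2)         # (batch, num_labels)

# Add the transition row indexed by the previous label, then normalize
scores = emissions + A[prev_labels]           # (batch, num_labels)
probs = torch.softmax(scores, dim=-1)         # p(y_t = k)
```

Note that `torch.cdist` returns Euclidean distances, so squaring it gives `||h_t - c_k||²` for every token/prototype pair at once.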
Cross-entropy loss with teacher forcing:

`L = -Σ_t log p(y_t | X, y_{t-1})`

The gradient w.r.t. the representation `h_t` is:

`∂L/∂h_t = 2(μ_t - c_{y_t})`

where `μ_t = Σ_k p_t(k) c_k` is the predicted mean prototype.

The Hessian (curvature) is:

`H_{h_t} = 4 Cov_{p_t}(c)`
This means the curvature in representation space is directly controlled by the covariance of prototypes under the predictive distribution.
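The gradient identity is easy to verify numerically with autograd. This is a standalone sketch using emission-only scores at a single position (no transitions), which is where the identity holds as stated:

```python
import torch

torch.manual_seed(0)
num_labels, hidden_dim = 5, 8
h = torch.randn(hidden_dim, requires_grad=True)
C = torch.randn(num_labels, hidden_dim)
y = 2  # gold label

# Cross-entropy under distance-based scores: s(k) = -||h - c_k||^2
scores = -((h - C) ** 2).sum(dim=-1)
p = torch.softmax(scores, dim=-1)
loss = -torch.log(p[y])
loss.backward()

# Closed form: dL/dh = 2 (mu - c_y), with mu = sum_k p_k c_k
mu = (p.detach().unsqueeze(-1) * C).sum(dim=0)
assert torch.allclose(h.grad, 2 * (mu - C[y]), atol=1e-5)
```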
figures/ and docs/ are gitignored. To create them locally:
- Code-mixed training with tracking (writes `checkpoints/tracking_data.json` and optionally `checkpoints/*.pt`): `python scripts/train_with_tracking.py`
- Training plots (writes `figures/*.png`): `python scripts/visualize_training.py`
- Scatter plots (writes `figures/scatterplots/*.png`): `python scripts/create_scatterplots.py`
The report (`report/report.tex`) contains the full theory and references these figures. Build the PDF with `cd report && pdflatex report.tex` (run twice to resolve references).
- Root: `prototype_sequence_tagger.py` (core model), `chunk_relation_model.py`, `code_mixed_data_loader.py`, `data_generator.py`
- `scripts/`: `train.py` (dummy data), `train_with_tracking.py` (code-mixed + tracking), `create_scatterplots.py`, `visualize_training.py`, `generate_code_mixed_data.py`
- `report/`: LaTeX report (theory + example runs + figure references)

`docs/` and `figures/` are gitignored and generated by the scripts above.
```python
from prototype_sequence_tagger import PrototypeSequenceTagger
from data_generator import DummyDataGenerator
import torch

# Create model
model = PrototypeSequenceTagger(
    vocab_size=50,
    num_labels=5,
    hidden_dim=128,
    use_bilstm=True,
)

# Generate dummy data
generator = DummyDataGenerator(vocab_size=50, num_labels=5)
input_ids, labels, mask = generator.generate_batch(batch_size=10)

# Forward pass
logits, loss = model(input_ids, labels=labels, mask=mask)

# Decode
predictions = model.decode(input_ids, mask=mask, method='greedy')
```

- The model supports both teacher forcing (training) and autoregressive decoding (inference)
- Prototypes can be initialized from class centroids of labeled data
- The implementation matches the theory described in the report (`report/report.tex`)
- Dummy data is generated with some transition structure to simulate realistic sequences
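The centroid initialization mentioned above can be sketched as follows. The helper name is hypothetical (not the repo's API); the result would be copied into the model's prototype parameter before training:

```python
import torch

def init_prototypes_from_centroids(embeddings: torch.Tensor,
                                   labels: torch.Tensor,
                                   num_labels: int) -> torch.Tensor:
    """Return (num_labels, hidden_dim) per-class means of labeled embeddings."""
    hidden_dim = embeddings.shape[-1]
    protos = torch.zeros(num_labels, hidden_dim)
    for k in range(num_labels):
        mask = labels == k
        if mask.any():  # leave unseen classes at zero
            protos[k] = embeddings[mask].mean(dim=0)
    return protos
```

Starting prototypes at class centroids places each `c_k` near its class's data, so initial emission scores `-||h - c_k||²` already separate the labels instead of starting from random geometry.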
- Theory and experiments: `report/report.pdf` (build from `report/report.tex`)
- Report source: `report/report.tex` (full derivation, gradient/Hessian, relation modeling, example runs, and figure references)