ndrco/directed-speech-ru

Directed Speech Classifier (RU)

Binary classifier that detects whether an utterance is addressed to an AI assistant ("directed speech") from text (usually ASR output).

What this project is for

This project fine-tunes ruElectra-small (https://huggingface.co/ai-forever/ruElectra-small) as a directed-speech classifier.

Task: binary classification

  • 1 -> the phrase is directed to the AI assistant
  • 0 -> the phrase is not directed to the AI assistant

Intended production use: as a filtering stage in front of a downstream assistant pipeline.

Typical flow:

  1. Speech recognition system (for example, Whisper-based ASR) produces a text stream.
  2. This classifier scores each phrase: p(directed).
  3. Only phrases predicted as directed are forwarded to the downstream assistant/NLU/LLM stack.
  4. Non-directed phrases are filtered out.
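
The flow above can be sketched as a small gating function. This is a minimal illustration, not code from this repo: `score_directed` and `toy_scorer` are hypothetical stand-ins for the fine-tuned classifier, and the 0.7 threshold simply mirrors the value used in the Quickstart evaluation command.

```python
from typing import Callable, Iterable, Iterator


def filter_directed(
    phrases: Iterable[str],
    score_directed: Callable[[str], float],
    threshold: float = 0.7,
) -> Iterator[str]:
    """Yield only the phrases whose p(directed) reaches the threshold."""
    for phrase in phrases:
        if score_directed(phrase) >= threshold:
            yield phrase  # step 3: forwarded to the downstream stack
        # step 4: anything below the threshold is silently dropped


def toy_scorer(text: str) -> float:
    """Hypothetical scorer used as a placeholder, NOT the real model."""
    return 0.95 if "ассистент" in text.lower() else 0.05


asr_stream = [
    "Ассистент, включи свет на кухне",
    "мам, а где мои ключи",
]
forwarded = list(filter_directed(asr_stream, toy_scorer))
```

In a real deployment the scorer would wrap the fine-tuned model, and the threshold would be tuned on validation data.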

Why this helps:

  • reduces accidental activations from background speech or human-to-human dialog
  • lowers unnecessary load on downstream assistant components
  • improves robustness of always-on or attention-mode assistant scenarios

Pipeline in this repo:

  1. synthetic dataset generation with ASR-like noise (this step can be skipped if using the already generated CSVs in data/)
  2. fine-tuning ruElectra-small
  3. inference + evaluation + top FP/FN error analysis
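
Step 1 mentions ASR-like noise. What that means in this repo is defined by the generator script; the sketch below only illustrates the general idea under stated assumptions (lowercasing, punctuation removal, occasional word drops). `add_asr_noise` is a hypothetical helper, not a function from scripts/generate_data_v3.py.

```python
import random
import string

# Characters to strip; ASR output often lacks punctuation (assumption).
_PUNCT = string.punctuation + "«»"


def add_asr_noise(text: str, drop_prob: float = 0.1, seed: int = 0) -> str:
    """Lowercase, strip punctuation, and randomly drop some words."""
    rng = random.Random(seed)  # seeded so the output is reproducible
    cleaned = text.lower().translate(str.maketrans("", "", _PUNCT))
    words = cleaned.split()
    kept = [w for w in words if rng.random() >= drop_prob]
    # Never return an empty utterance: fall back to all cleaned words.
    return " ".join(kept or words)


noisy = add_asr_noise("Ассистент, включи свет!", drop_prob=0.0)
```

With `drop_prob=0.0` only the deterministic part applies, so the result differs from the input only in case and punctuation.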

Example metrics (from test_report.json):

  • Accuracy: 0.9959
  • ROC-AUC: 0.9999
  • Confusion matrix: [[237, 2], [0, 248]]
  • Split (group): train 8447 / val 1523 / test 487 rows (groups: train 321 / val 60 / test 21)
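
The headline accuracy can be re-derived from the confusion matrix. The sketch below assumes the common scikit-learn layout (rows are true labels, columns are predicted labels, class order 0 then 1). The README does not state the layout, so the precision and recall readings are conditional on that assumption; accuracy is the same either way.

```python
# Confusion matrix as reported above.
# Layout [[TN, FP], [FN, TP]] is an assumption, not stated in the report.
cm = [[237, 2], [0, 248]]
(tn, fp), (fn, tp) = cm

total = tn + fp + fn + tp
accuracy = (tn + tp) / total
precision = tp / (tp + fp)
recall = tp / (tp + fn)

print(f"n={total} accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f}")
```

The total of 487 matches the test split size, and 485/487 rounds to the reported 0.9959.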

Install

python -m venv .venv
source .venv/bin/activate  # Linux/Mac
# .venv\Scripts\activate  # Windows

pip install -r requirements.txt

Quickstart

1) Generate dataset

python scripts/generate_data_v3.py

The script creates: data_v3.csv, data_v3_train.csv, data_v3_val.csv, data_v3_test.csv.

2) Train

python scripts/train_ruelectra_directed_v2.py \
  --train data_v3_train.csv --val data_v3_val.csv --test data_v3_test.csv \
  --out ./directed-ruElectra-small

3) Inference (interactive)

python scripts/infer_directed_v2.py --model ./directed-ruElectra-small

Enter text to get p(directed) and the predicted class.

4) Evaluate + top errors

python scripts/infer_directed_v2.py --model ./directed-ruElectra-small \
  --eval data_v3_test.csv --topk 20 --threshold 0.7
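
One plausible way to rank "top" errors is by confidence: false positives with the highest p(directed) and false negatives with the lowest. The sketch below assumes that ranking and uses invented scores; `top_errors` is a hypothetical helper and may differ from what infer_directed_v2.py does internally.

```python
from typing import List, Tuple

Row = Tuple[str, int, float]  # (text, true label, p(directed))


def top_errors(rows: List[Row], threshold: float = 0.7, topk: int = 20):
    """Return the most confident false positives and false negatives."""
    fps = [r for r in rows if r[1] == 0 and r[2] >= threshold]
    fns = [r for r in rows if r[1] == 1 and r[2] < threshold]
    fps.sort(key=lambda r: r[2], reverse=True)  # highest p first
    fns.sort(key=lambda r: r[2])                # lowest p first
    return fps[:topk], fns[:topk]


# Illustrative rows with invented scores, not real model output.
rows = [
    ("включи музыку", 1, 0.98),
    ("он сказал включи музыку", 0, 0.91),
    ("сколько времени", 1, 0.35),
    ("мы вчера ходили в кино", 0, 0.02),
]
fps, fns = top_errors(rows, threshold=0.7, topk=2)
```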

Data format

Minimum CSV:

  • text - string
  • label - 0/1

Optional:

  • group_id - used for group split in the training script.
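
A minimal sketch of the format check and of what a group split means: rows sharing a group_id are kept together so near-duplicate phrases do not leak across splits. This is an illustration under stated assumptions, not the training script's logic: `load_rows` and `group_split` are hypothetical helpers, and only a train/test split is shown for brevity.

```python
import csv
import io
import random
from collections import defaultdict


def load_rows(csv_text: str):
    """Parse CSV text and check the minimum schema: text, label in {0, 1}."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    for row in rows:
        if not row.get("text") or row.get("label") not in {"0", "1"}:
            raise ValueError(f"bad row: {row}")
    return rows


def group_split(rows, test_frac=0.2, seed=0):
    """Assign whole groups to train or test so no group_id spans both."""
    groups = defaultdict(list)
    for i, row in enumerate(rows):
        # Rows without group_id become singleton groups (assumption).
        groups[row.get("group_id") or f"_row{i}"].append(row)
    ids = sorted(groups)
    random.Random(seed).shuffle(ids)  # seeded for reproducibility
    n_test = max(1, round(len(ids) * test_frac))
    test_ids = set(ids[:n_test])
    train = [r for g in ids if g not in test_ids for r in groups[g]]
    test = [r for g in ids if g in test_ids for r in groups[g]]
    return train, test


sample = (
    "text,label,group_id\n"
    "включи свет,1,g1\n"
    "включи свет пожалуйста,1,g1\n"
    "передай соль,0,g2\n"
    "где соль,0,g2\n"
)
train, test = group_split(load_rows(sample), test_frac=0.5)
```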

Threshold vs argmax

By default, the predicted class is 1 when p(directed) >= threshold and 0 otherwise. To take the argmax over the logits instead, pass --argmax to infer_directed_v2.py.
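
The two decision rules can disagree, as the sketch below shows. It assumes the model emits two logits in the order [not directed, directed]; the 0.7 threshold mirrors the evaluation command above and is not necessarily the script's default. `p_directed` and `predict` are hypothetical helpers.

```python
import math


def p_directed(logits):
    """Softmax probability of class 1 from [not_directed, directed] logits."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    return exps[1] / sum(exps)


def predict(logits, threshold=0.7, use_argmax=False):
    """Return 1 (directed) or 0, by argmax or by thresholding p(directed)."""
    if use_argmax:
        return int(logits[1] > logits[0])
    return int(p_directed(logits) >= threshold)


logits = [0.0, 0.5]  # p(directed) is about 0.62
by_threshold = predict(logits, threshold=0.7)  # below 0.7, so class 0
by_argmax = predict(logits, use_argmax=True)   # larger logit wins, class 1
```

For a two-class softmax, argmax is equivalent to a threshold of 0.5 (ties aside), so a threshold above 0.5 trades some recall for fewer false activations.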

Repository structure

  • scripts/ - data generation / train / infer
  • docs/ - dataset/model/reproducibility notes
  • data/ - dataset generated by scripts/generate_data_v3.py, with real examples appended at the end

License

See LICENSE.
