Binary classifier that detects whether an utterance is addressed to an AI assistant ("directed speech") from text (usually ASR output).
This project is built to fine-tune ruElectra-small:
https://huggingface.co/ai-forever/ruElectra-small
Task: binary classification
- 1 -> the phrase is directed to the AI assistant
- 0 -> the phrase is not directed to the AI assistant
Intended production use: as a filtering stage in front of a downstream assistant pipeline.
Typical flow:
- Speech recognition system (for example, Whisper-based ASR) produces a text stream.
- This classifier scores each phrase: p(directed).
- Only phrases predicted as directed are forwarded to the fine-tuned assistant/NLU/LLM stack.
- Non-directed phrases are filtered out.
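The flow above can be sketched as a small filtering function. This is a minimal illustration, not the repo's code: `filter_directed` and `dummy_score` are hypothetical names, and `dummy_score` is a stand-in rule rather than the fine-tuned model. In a real setup, `score_fn` would wrap the classifier and return p(directed) for a phrase.

```python
from typing import Callable, Iterable, Iterator


def filter_directed(
    phrases: Iterable[str],
    score_fn: Callable[[str], float],
    threshold: float = 0.5,
) -> Iterator[tuple[str, float]]:
    """Yield (phrase, p_directed) only for phrases scored at or above threshold."""
    for phrase in phrases:
        p_directed = score_fn(phrase)
        if p_directed >= threshold:
            yield phrase, p_directed


def dummy_score(phrase: str) -> float:
    """Hypothetical stand-in for the classifier; NOT a real model."""
    return 0.95 if phrase.lower().startswith("assistant") else 0.05


# Phrases as they might arrive from an ASR text stream (made-up examples).
asr_stream = [
    "assistant turn on the kitchen light",
    "did you see the game last night",
    "assistant what is the weather tomorrow",
]

# Only phrases that pass the filter would reach the downstream stack.
forwarded = list(filter_directed(asr_stream, dummy_score, threshold=0.7))
for phrase, p in forwarded:
    print(f"{p:.2f}  {phrase}")
```

Keeping the scorer as a plain callable means the filtering logic can be tested without loading the model.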
Why this helps:
- reduces accidental activations from background speech or human-to-human dialog
- lowers unnecessary load on downstream assistant components
- improves robustness of always-on or attention-mode assistant scenarios
Pipeline in this repo:
- synthetic dataset generation with ASR-like noise (this step can be skipped if using the already generated CSVs in data/)
- fine-tuning ruElectra-small
- inference + evaluation + top FP/FN error analysis
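One plausible reading of "top FP/FN error analysis" is listing the most confident mistakes first. The sketch below works under that assumption and is not taken from the repo's scripts: `Scored` and `top_errors` are hypothetical names, and the rows are made-up examples. False positives are ranked by highest p(directed), false negatives by lowest.

```python
from typing import NamedTuple


class Scored(NamedTuple):
    text: str
    label: int         # 1 = directed, 0 = not directed
    p_directed: float  # classifier score for class 1


def top_errors(rows, threshold=0.5, topk=20):
    """Return (false_positives, false_negatives), most confident mistakes first."""
    fps = [r for r in rows if r.label == 0 and r.p_directed >= threshold]
    fns = [r for r in rows if r.label == 1 and r.p_directed < threshold]
    fps.sort(key=lambda r: r.p_directed, reverse=True)  # highest score on a negative
    fns.sort(key=lambda r: r.p_directed)                # lowest score on a positive
    return fps[:topk], fns[:topk]


# Made-up scored examples, for illustration only.
rows = [
    Scored("assistant set a timer", 1, 0.98),
    Scored("he asked the assistant yesterday", 0, 0.91),
    Scored("set a timer please", 1, 0.40),
    Scored("what a nice day", 0, 0.02),
    Scored("call mom", 1, 0.65),
    Scored("she said turn it off", 0, 0.72),
]

fps, fns = top_errors(rows, threshold=0.7, topk=2)
for kind, errors in (("FP", fps), ("FN", fns)):
    for r in errors:
        print(f"{kind}  p={r.p_directed:.2f}  {r.text}")
```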
Example metrics (from test_report.json):
- Accuracy: 0.9959
- ROC-AUC: 0.9999
- Confusion matrix: [[237, 2], [0, 248]]
- Split (group): train 8447 / val 1523 / test 487 (groups: {'train': 321, 'val': 60, 'test': 21})
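As a quick sanity check, the reported accuracy and test size can be recomputed from the confusion matrix. The sketch assumes the matrix follows the scikit-learn layout (rows are true labels, columns are predicted labels, class order [0, 1]). Accuracy and the total do not depend on that assumption, but precision and recall do.

```python
# Confusion matrix as reported in the metrics list.
# Assumption: rows = true labels, columns = predicted labels, class order [0, 1].
cm = [[237, 2], [0, 248]]
(tn, fp), (fn, tp) = cm

total = tn + fp + fn + tp
accuracy = (tp + tn) / total
precision = tp / (tp + fp)  # of phrases predicted directed, how many truly are
recall = tp / (tp + fn)     # of truly directed phrases, how many were caught
f1 = 2 * precision * recall / (precision + recall)

print(f"test size: {total}")
print(f"accuracy:  {accuracy:.4f}")
print(f"precision: {precision:.4f}")
print(f"recall:    {recall:.4f}")
print(f"f1:        {f1:.4f}")
```

Under the stated layout, the two errors are both false positives (non-directed phrases let through) and no directed phrase is dropped.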
python -m venv .venv
source .venv/bin/activate # Linux/Mac
# .venv\Scripts\activate # Windows
pip install -r requirements.txt
python scripts/generate_data_v3.py
The script creates: data_v3.csv, data_v3_train.csv, data_v3_val.csv, data_v3_test.csv.
python scripts/train_ruelectra_directed_v2.py \
    --train data_v3_train.csv --val data_v3_val.csv --test data_v3_test.csv \
    --out ./directed-ruElectra-small
python scripts/infer_directed_v2.py --model ./directed-ruElectra-small
Enter text to get p(directed) and the predicted class.
python scripts/infer_directed_v2.py --model ./directed-ruElectra-small \
    --eval data_v3_test.csv --topk 20 --threshold 0.7
Minimum CSV:
- text - string
- label - 0/1
Optional:
- group_id - used for group split in the training script.
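To illustrate why group_id matters, here is a minimal sketch of a group-aware split: every row with the same group_id goes to the same split, so near-duplicate phrases from one group cannot leak between train and test. This is an assumption about the general technique, not the training script's actual logic. `group_split`, its fractions, and the sample rows are hypothetical.

```python
import csv
import io
import random
from collections import defaultdict


def group_split(rows, val_frac=0.15, test_frac=0.05, seed=42):
    """Assign whole groups to train/val/test so no group_id spans two splits."""
    by_group = defaultdict(list)
    for row in rows:
        by_group[row["group_id"]].append(row)

    # Shuffle group ids (not rows) with a fixed seed for reproducibility.
    group_ids = sorted(by_group)
    random.Random(seed).shuffle(group_ids)

    n_test = max(1, round(len(group_ids) * test_frac))
    n_val = max(1, round(len(group_ids) * val_frac))
    assignment = {
        "test": group_ids[:n_test],
        "val": group_ids[n_test:n_test + n_val],
        "train": group_ids[n_test + n_val:],
    }
    return {
        name: [row for gid in gids for row in by_group[gid]]
        for name, gids in assignment.items()
    }


# Tiny in-memory CSV with the text, label, and group_id columns (made-up rows).
sample = io.StringIO(
    "text,label,group_id\n"
    "assistant turn on the light,1,g1\n"
    "assistant light on please,1,g1\n"
    "we talked about the light,0,g2\n"
    "the light was on all night,0,g2\n"
    "assistant what time is it,1,g3\n"
    "what time did he come,0,g4\n"
)
rows = list(csv.DictReader(sample))
splits = group_split(rows, val_frac=0.25, test_frac=0.25)
for name, part in splits.items():
    print(name, sorted({r["group_id"] for r in part}), len(part))
```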
By default, prediction is based on p(directed) >= threshold.
To use logits argmax, pass --argmax to infer_directed_v2.py.
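The difference between the two decision rules can be shown with a small sketch. It assumes the model head emits two logits ordered [not directed, directed]; that ordering, the function names, and the example logits are assumptions for illustration, not details confirmed by the scripts. With two classes, argmax is equivalent to a fixed cut at p(directed) = 0.5, while a higher threshold makes the filter stricter.

```python
import math


def p_directed(logits):
    """Softmax probability of class 1 from two logits [not_directed, directed]."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    return exps[1] / sum(exps)


def predict_threshold(logits, threshold=0.5):
    """Predict 1 when p(directed) >= threshold."""
    return int(p_directed(logits) >= threshold)


def predict_argmax(logits):
    """Predict the class with the larger logit (ties go to class 0 here)."""
    return int(logits[1] > logits[0])


logits = [0.2, 0.8]  # hypothetical logits for one phrase
p = p_directed(logits)

print(f"p(directed) = {p:.4f}")
print("argmax:        ", predict_argmax(logits))
print("threshold 0.5: ", predict_threshold(logits, 0.5))
print("threshold 0.7: ", predict_threshold(logits, 0.7))
```

In this example the phrase is accepted by argmax but rejected at a threshold of 0.7, which is the trade-off to tune when false activations are more costly than missed requests.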
- scripts/ - data generation / train / infer
- docs/ - dataset/model/reproducibility notes
- data/ - dataset generated by scripts/generate_data_v3.py and enriched with real examples at the end
See LICENSE.