Repository for the paper "Confidence-Driven Multi-Scale Model Selection for Cost-Efficient Inference".
- (Optional) Create a virtual environment.
- Install necessary packages via `pip install -r requirements.txt`.
- Download datasets via Hugging Face; a Hugging Face token is required. Set the env variable `HF_TOKEN` or alternatively put your key in the `hf_token.key` file. After setting up the token, you can run `python huggingface_download.py` to download the MMLU and PopQA datasets.
First, we need to run the evaluation using only the P(T) confidence.
For example, the following command evaluates the LLaMA 3 3B model on the MMLU dataset.
If you want to evaluate on PopQA, make sure to change the prompt path (`prompt/popqa-chat.txt`) and the task type (`generative`).
```
python run_eval.py --model_name llama_3b \
    --dataset_name mmlu \
    --dataset_split test \
    --prompt_path prompt/mmlu.txt \
    --task_type multiple_choice
```

The evaluation result will be stored in `./eval_results/{dataset_name}/{model}.tsv`.
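As a sketch of how the per-sample results in such a TSV might be consumed downstream (the `correct` column name is an assumption; check the actual TSV header written by `run_eval.py`):

```python
import csv

def accuracy_from_tsv(path: str, correct_col: str = "correct") -> float:
    """Compute overall accuracy from a per-sample TSV with a 0/1 correctness column."""
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f, delimiter="\t"))
    return sum(int(r[correct_col]) for r in rows) / len(rows)
```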
After obtaining the evaluation results (the ground truth for training the classifier), we still need the classifier's input: the model's hidden states.
```
python extract_hidden_states.py --model_name llama_3b \
    --dataset_name mmlu \
    --dataset_split test \
    --prompt_path prompt/mmlu.txt
```

Now we can train the MLP using the following command:
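Conceptually, the classifier input is one hidden-state vector per example, pooled from a transformer layer's output. A minimal NumPy sketch of last-token pooling follows; the actual layer and pooling choices in `extract_hidden_states.py` may differ:

```python
import numpy as np

def pool_last_token(hidden_states: np.ndarray, lengths: np.ndarray) -> np.ndarray:
    """Select the hidden state of each sequence's last real (non-padding) token.

    hidden_states: (batch, seq_len, dim) array from one transformer layer.
    lengths: (batch,) number of non-padding tokens per sequence.
    """
    batch = np.arange(hidden_states.shape[0])
    return hidden_states[batch, lengths - 1]  # (batch, dim)
```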
```
python train_classifier.py --model_name llama_3b \
    --dataset_name mmlu \
    --dataset_split test \
    --dataset_group_by subject
```

For PopQA, note that the `--dataset_group_by` flag should be `prop` instead.
The trained model will be stored in `./mlp/{dataset_name}/{model_name}`.
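The P(IK) classifier is an MLP over pooled hidden states. A minimal sketch with scikit-learn is shown below on synthetic stand-in data; the architecture and hyperparameters in the repository's `train_classifier.py` are not reproduced here:

```python
import numpy as np
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))      # stand-in for hidden-state features
y = (X[:, 0] > 0).astype(int)       # stand-in for "model answered correctly"

clf = MLPClassifier(hidden_layer_sizes=(32,), max_iter=500, random_state=0)
clf.fit(X, y)
p_ik = clf.predict_proba(X)[:, 1]   # P(IK): probability the model knows the answer
```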
Use `analysis.py` to view the aggregated results:

```
python analysis.py --models llama_3b,llama_8b \
    --dataset_name mmlu \
    --use_ik
```

Example output:

```
Model: llama_3b
Overall (1430) Accuracy: 0.6392
Confident (587) Accuracy: 0.8620
Not Confident (843) Accuracy: 0.4840

Model: llama_8b
Overall (1430) Accuracy: 0.6986
Confident (734) Accuracy: 0.9019
Not Confident (696) Accuracy: 0.4842

Total evaluated samples: 1430

| Models               | Accuracy | Cost  |
| -------------------- | -------- | ----- |
| llama_3b             | 0.6392   | 4290  |
| llama_8b             | 0.6986   | 11440 |
| llama_3b -> llama_8b | 0.6902   | 10662 |

Query Distribution:
source
llama_8b    843
llama_3b    587
```
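The cascade row (`llama_3b -> llama_8b`) can be understood with a simple cost model: the small model answers the queries it is confident about and routes the rest to the large model. The sketch below uses illustrative per-query costs and subgroup accuracies, not the exact accounting in `analysis.py`:

```python
def cascade_stats(n_confident, acc_small_conf, n_routed, acc_large_routed,
                  cost_small, cost_large):
    """Accuracy and cost of a two-model cascade.

    Assumes every query pays the small model's cost, and routed queries
    additionally pay the large model's cost.
    """
    n_total = n_confident + n_routed
    correct = n_confident * acc_small_conf + n_routed * acc_large_routed
    cost = n_total * cost_small + n_routed * cost_large
    return correct / n_total, cost

# Illustrative numbers (routed-subset accuracy and cost units are assumptions):
acc, cost = cascade_stats(587, 0.8620, 843, 0.57, cost_small=3, cost_large=8)
```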
- The results vary due to the random seed used while training the MLP; your results might be better or worse than those reported in the paper.
- `--use_ik`: This flag tells the script to use the P(IK) classifier's result.
- `--use_whole_dataset`: Set this flag to evaluate the whole dataset (primarily for testing P(T)'s effect).
- `--confidence_threshold=0.9`: Customize the confidence threshold.