From 3414d2b9fe9406e1a2f41106fe9559fbf4d9e381 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 18:19:15 +0800 Subject: [PATCH 1/9] add new ut for gedit --- tests/UT/cli/test_workers.py | 224 ++++++++++++- tests/UT/datasets/test_base.py | 168 +++++++++- tests/UT/datasets/test_g_edit.py | 135 ++++++++ tests/UT/datasets/utils/test_llm_judge.py | 132 ++++++++ tests/UT/models/test_output.py | 194 +++++++++++- .../test_lmm_gen_inferencer_output_handler.py | 296 ++++++++++++++++++ .../test_icl_lmm_gen_inferencer.py | 149 +++++++++ 7 files changed, 1293 insertions(+), 5 deletions(-) create mode 100644 tests/UT/datasets/test_g_edit.py create mode 100644 tests/UT/datasets/utils/test_llm_judge.py create mode 100644 tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py create mode 100644 tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py diff --git a/tests/UT/cli/test_workers.py b/tests/UT/cli/test_workers.py index b8f3c8b1..ac883f4f 100644 --- a/tests/UT/cli/test_workers.py +++ b/tests/UT/cli/test_workers.py @@ -1,7 +1,7 @@ import sys import os import pytest -from unittest.mock import patch, MagicMock, call +from unittest.mock import patch, MagicMock, call, mock_open from collections import defaultdict # 添加项目根目录到Python路径 @@ -12,6 +12,7 @@ from ais_bench.benchmark.cli.workers import ( BaseWorker, Infer, + JudgeInfer, Eval, AccViz, PerfViz, @@ -664,11 +665,14 @@ def test_work_flow_dict(): assert 'viz' in WORK_FLOW assert 'perf' in WORK_FLOW assert 'perf_viz' in WORK_FLOW + assert 'judge' in WORK_FLOW + assert 'infer_judge' in WORK_FLOW # 验证工作流内容正确 assert Infer in WORK_FLOW['all'] assert Eval in WORK_FLOW['all'] assert AccViz in WORK_FLOW['all'] + assert JudgeInfer in WORK_FLOW['all'] assert Infer in WORK_FLOW['infer'] @@ -680,4 +684,220 @@ def test_work_flow_dict(): assert Infer in WORK_FLOW['perf'] assert PerfViz in WORK_FLOW['perf'] - assert PerfViz in WORK_FLOW['perf_viz'] \ No newline at end of file + assert PerfViz in WORK_FLOW['perf_viz'] + + assert JudgeInfer in WORK_FLOW['judge'] + assert Infer in WORK_FLOW['infer_judge'] + assert JudgeInfer in WORK_FLOW['infer_judge'] + + +class TestJudgeInfer: + def setup_method(self): + """设置测试环境""" + self.mock_args = MagicMock() + self.mock_args.max_num_workers = 4 + self.mock_args.max_workers_per_gpu = 2 + self.mock_args.debug = False + self.judge_infer_worker = JudgeInfer(self.mock_args) + + @patch('ais_bench.benchmark.cli.workers.get_config_type') + def test_update_cfg_service_model(self, mock_get_config_type): + """测试update_cfg方法,使用service模型""" + mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLApiInferTask', 'MockLocalRunner'] + + cfg = MockConfigDict({ + 'datasets': [{ + 'judge_infer_cfg': { + 'judge_model': {'attr': 'service', 'abbr': 'judge_model'} + } + }], + 'work_dir': '/test/workdir', + 'cli_args': MagicMock(debug=False) + }) + + with patch('os.path.join', return_value='/test/workdir/predictions/'): + result = self.judge_infer_worker.update_cfg(cfg) + + assert result == cfg + assert cfg['judge_infer']['partitioner']['type'] == 'MockNaivePartitioner' + assert cfg['judge_infer']['runner']['type'] == 'MockLocalRunner' + assert cfg['judge_infer']['runner']['task']['type'] == 'MockOpenICLApiInferTask' + + @patch('ais_bench.benchmark.cli.workers.get_config_type') + def test_update_cfg_local_model(self, mock_get_config_type): + """测试update_cfg方法,使用local模型""" + mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLInferTask', 'MockLocalRunner'] + + cfg = MockConfigDict({ + 'datasets': [{ + 'judge_infer_cfg': { + 'judge_model': {'attr': 'local', 'abbr': 'judge_model'} + } + }], + 'work_dir': '/test/workdir', + 'cli_args': MagicMock(debug=True) + }) + + with patch('os.path.join', return_value='/test/workdir/predictions/'): + self.judge_infer_worker.update_cfg(cfg) + + assert cfg['judge_infer']['runner']['task']['type'] == 'MockOpenICLInferTask' + assert cfg['judge_infer']['runner']['debug'] == True + + def test_cfg_pre_process(self): + """测试_cfg_pre_process方法""" + cfg = MockConfigDict({ + 'datasets': [ + { + 'abbr': 'test_dataset', + 'judge_infer_cfg': { + 'judge_model': {'abbr': 'judge_model'} + } + } + ] + }) + + self.judge_infer_worker._cfg_pre_process(cfg) + + assert cfg['datasets'][0]['abbr'] == 'test_dataset-judge_model' + assert 'test_dataset-judge_model' in self.judge_infer_worker.org_dataset_abbrs + + def test_cfg_pre_process_with_model_dataset_combinations(self): + """测试_cfg_pre_process方法,包含model_dataset_combinations""" + cfg = MockConfigDict({ + 'model_dataset_combinations': [ + { + 'datasets': [ + { + 'abbr': 'combo_dataset', + 'judge_infer_cfg': { + 'judge_model': {'abbr': 'judge_model'} + } + } + ] + } + ], + 'datasets': [] + }) + + self.judge_infer_worker._cfg_pre_process(cfg) + + assert cfg['model_dataset_combinations'][0]['datasets'][0]['abbr'] == 'combo_dataset-judge_model' + + def test_merge_datasets(self): + """测试_merge_datasets方法""" + task1 = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + task2 = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + task3 = { + 'models': [{'abbr': 'model2'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + + result = self.judge_infer_worker._merge_datasets([task1, task2, task3]) + + assert len(result) == 2 + assert len(result[0]['datasets'][0]) == 2 + assert len(result[1]['datasets'][0]) == 1 + + @patch('ais_bench.benchmark.cli.workers.PARTITIONERS') + @patch('ais_bench.benchmark.cli.workers.RUNNERS') + @patch('ais_bench.benchmark.cli.workers.logger') + def test_do_work_no_tasks(self, mock_logger, mock_runners, mock_partitioners): + """测试do_work方法,没有有效任务的情况""" + mock_partitioner = MagicMock() + mock_partitioners.build.return_value = mock_partitioner + mock_tasks = [ + { + 'datasets': [[{}]] # 没有judge_infer_cfg + } + ] + mock_partitioner.return_value = mock_tasks + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': {}, + 'runner': {} + }, + 'datasets': [] + }) + + with patch.object(self.judge_infer_worker, '_cfg_pre_process'): + with patch.object(self.judge_infer_worker, '_update_tasks_cfg'): + self.judge_infer_worker.do_work(cfg) + + mock_runners.build.assert_not_called() + + @patch('ais_bench.benchmark.cli.workers.load_jsonl') + @patch('ais_bench.benchmark.cli.workers.dump_jsonl') + @patch('os.path.exists') + @patch('os.remove') + def test_result_post_process(self, mock_remove, mock_exists, mock_dump_jsonl, mock_load_jsonl): + """测试_result_post_process方法""" + mock_load_jsonl.side_effect = [ + [{'uuid': 'uuid1', 'id': 'id1'}], # model_preds + [{'gold': 'uuid1', 'prediction': 'pred1'}] # judge_preds + ] + mock_exists.return_value = True + + task = { + 'datasets': [[{ + 'predictions_path': '/test/model_pred.jsonl', + 'abbr': 'test_dataset-judge_model' + }]], + 'models': [{'abbr': 'model1'}] + } + tasks = [task] + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': { + 'out_dir': '/test/predictions' + } + } + }) + + self.judge_infer_worker.org_dataset_abbrs = {'test_dataset-judge_model': 'test_dataset'} + + self.judge_infer_worker._result_post_process(tasks, cfg) + + mock_remove.assert_called_once() + mock_dump_jsonl.assert_called_once() + + def test_update_tasks_cfg_with_judge_infer(self): + """测试_update_tasks_cfg方法,包含judge_infer_cfg""" + self.judge_infer_worker.org_dataset_abbrs = {'test_dataset-judge_model': 'test_dataset'} + + task = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{ + 'abbr': 'test_dataset-judge_model', + 'judge_infer_cfg': { + 'judge_model': {'type': 'judge_model_type'}, + 'judge_dataset_type': 'judge_dataset', + 'judge_reader_cfg': {'test': 'cfg'} + } + }]] + } + tasks = [task] + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': { + 'out_dir': '/test/predictions' + } + } + }) + + with patch('os.path.join', return_value='/test/predictions/model1/test_dataset.jsonl'): + with patch('os.path.exists', return_value=True): + self.judge_infer_worker._update_tasks_cfg(tasks, cfg) + + assert 'judge_infer_cfg' not in task['datasets'][0][0] + assert task['models'][0]['type'] == 'judge_model_type' + assert task['datasets'][0][0]['type'] == 'judge_dataset' \ No newline at end of file diff --git a/tests/UT/datasets/test_base.py b/tests/UT/datasets/test_base.py index a2bc8593..e413ed69 100644 --- a/tests/UT/datasets/test_base.py +++ b/tests/UT/datasets/test_base.py @@ -1,9 +1,10 @@ import unittest from unittest.mock import patch, MagicMock +import pytest from datasets import Dataset, DatasetDict -from ais_bench.benchmark.datasets.base import BaseDataset +from ais_bench.benchmark.datasets.base import BaseDataset, BaseJDGDataset class DummyDataset(BaseDataset): @@ -130,6 +131,171 @@ def load(**kwargs): self.assertIn("subdivision", first) self.assertIn("idx", first) + def test_init_with_task_state_manager(self): + """测试使用task_state_manager初始化""" + mock_task_state_manager = MagicMock() + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + task_state_manager=mock_task_state_manager + ) + self.assertEqual(ds.task_state_manager, mock_task_state_manager) + + def test_update_task_state_with_manager(self): + """测试使用task_state_manager更新状态""" + mock_task_state_manager = MagicMock() + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + task_state_manager=mock_task_state_manager + ) + state = {'status': 'processing'} + ds.update_task_state(state) + mock_task_state_manager.update_task_state.assert_called_once_with(state) + + def test_update_task_state_without_manager(self): + """测试没有task_state_manager时更新状态""" + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + # 不应抛出异常 + ds.update_task_state({'status': 'processing'}) + + def test_init_with_abbr(self): + """测试使用abbr参数初始化""" + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + abbr='custom_abbr' + ) + self.assertEqual(ds.abbr, 'custom_abbr') + + def test_init_k_greater_than_n_raises_error(self): + """测试k > n时抛出异常""" + from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError + with self.assertRaises(ParameterValueError): + DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=5, + n=3 + ) + + def test_init_k_list_greater_than_n_raises_error(self): + """测试k为列表且最大值 > n时抛出异常""" + from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError + with self.assertRaises(ParameterValueError): + DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=[1, 2, 5], + n=3 + ) + + def test_repeated_dataset_with_dataset_dict(self): + """测试DatasetDict类型的repeated_dataset处理""" + ds = DummyDatasetDict( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=2 + ) + self.assertIsInstance(ds.dataset, DatasetDict) + # 验证每个split都被正确处理 + self.assertEqual(len(ds.dataset['train']), 4) # 2 * 2 + self.assertEqual(len(ds.dataset['test']), 2) # 1 * 2 + # 验证元数据 + first_train = ds.dataset['train'][0] + self.assertIn("subdivision", first_train) + self.assertIn("idx", first_train) + + +class TestBaseJDGDataset(unittest.TestCase): + def test_init_org_datasets_instance(self): + """测试_init_org_datasets_instance方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + self.assertIsNotNone(ds.dataset_instance) + + def test_process_single_item(self): + """测试_process_single_item方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + dataset_content = Dataset.from_list([ + {"text": "question1", "answer": "A"}, + {"text": "question2", "answer": "B"} + ]) + pred_item = {"id": 0, "prediction": "predicted_answer", "uuid": "test_uuid"} + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + result = ds._process_single_item(dataset_content, pred_item) + self.assertEqual(result["model_answer"], "predicted_answer") + self.assertEqual(result["model_pred_uuid"], "test_uuid") + + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + dataset_item = {"text": "question", "answer": "A"} + pred_item = {"prediction": "predicted_answer"} + ds._modify_dataset_item(dataset_item, pred_item) + self.assertEqual(dataset_item["model_answer"], "predicted_answer") + + def test_load_with_predictions(self): + """测试load方法处理predictions""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [{"id": 0, "prediction": "pred1", "uuid": "uuid1"}] + + with patch.object(DummyJDGDataset, '_process_predictions') as mock_process: + mock_process.return_value = Dataset.from_list([{"text": "result"}]) + + with patch.object(DummyJDGDataset, '__init__', lambda self, *args, **kwargs: None): + ds = DummyJDGDataset.__new__(DummyJDGDataset) + ds.dataset_instance = MagicMock() + ds.dataset_instance.dataset = {"test": Dataset.from_list([{"text": "test"}])} + ds.task_state_manager = None + ds.logger = MagicMock() + + result = ds.load(predictions_path="/test/predictions.jsonl") + self.assertIsInstance(result, Dataset) + if __name__ == "__main__": unittest.main() diff --git a/tests/UT/datasets/test_g_edit.py b/tests/UT/datasets/test_g_edit.py new file mode 100644 index 00000000..176dccbc --- /dev/null +++ b/tests/UT/datasets/test_g_edit.py @@ -0,0 +1,135 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +from io import BytesIO +from PIL import Image +import base64 + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from datasets import Dataset + +from ais_bench.benchmark.datasets.g_edit import ( + GEditDataset, + GEditSCJDGDataset, + GEditPQJDGDataset, + GEditEvaluator +) + + +class TestGEditEvaluator: + def test_score(self): + """测试score方法""" + evaluator = GEditEvaluator() + predictions = ["pred1", "pred2", "pred3"] + references = ["ref1", "ref2", "ref3"] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert "details" in result + assert result["accuracy"] == 100.0 + assert len(result["details"]) == 3 + + def test_score_empty(self): + """测试score方法,空输入""" + evaluator = GEditEvaluator() + predictions = [] + references = [] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert "details" in result + + +class TestGEditDataset: + def test_load_basic(self): + """测试基本load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ]) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path') + + assert isinstance(result, Dataset) + + def test_load_with_split(self): + """测试带数据集切分的load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ] * 10) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path', split_count=2, split_index=0) + + assert isinstance(result, Dataset) + + def test_load_use_raw(self): + """测试使用原始图片的load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ]) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path', use_raw=True) + + assert isinstance(result, Dataset) + + +class TestGEditSCJDGDataset: + def test_get_dataset_class(self): + """测试_get_dataset_class方法""" + ds = GEditSCJDGDataset.__new__(GEditSCJDGDataset) + result = ds._get_dataset_class() + assert result == GEditDataset + + +class TestGEditPQJDGDataset: + def test_get_dataset_class(self): + """测试_get_dataset_class方法""" + ds = GEditPQJDGDataset.__new__(GEditPQJDGDataset) + result = ds._get_dataset_class() + assert result == GEditDataset diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py new file mode 100644 index 00000000..7d3f75ed --- /dev/null +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -0,0 +1,132 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +import tempfile +import json + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.datasets.utils.llm_judge import ( + get_a_or_b, + LLMJudgeDataset, + LLMJudgeCorrectEvaluator +) + + +class TestGetAOrB: + def test_get_a_or_b_with_a(self): + """测试提取A""" + result = get_a_or_b("The answer is A") + assert result == "A" + + def test_get_a_or_b_with_b(self): + """测试提取B""" + result = get_a_or_b("The answer is B") + assert result == "B" + + def test_get_a_or_b_no_match(self): + """测试没有匹配时返回B""" + result = get_a_or_b("The answer is C") + assert result == "B" + + def test_get_a_or_b_empty(self): + """测试空字符串""" + result = get_a_or_b("") + assert result == "B" + + def test_get_a_or_b_at_end(self): + """测试A在末尾""" + result = get_a_or_b("something A") + assert result == "A" + + +class TestLLMJudgeDataset: + def test_load_from_predictions_file_not_exists(self): + """测试文件不存在的情况""" + ds = LLMJudgeDataset.__new__(LLMJudgeDataset) + ds.logger = MagicMock() + + with patch('os.path.exists', return_value=False): + result = ds._load_from_predictions('/test/nonexistent.jsonl') + assert result == [] + + def test_load_from_predictions_success(self): + """测试成功加载predictions""" + mock_preds = [ + {"id": 1, "prediction": "pred1"}, + {"id": 0, "prediction": "pred2"} + ] + + ds = LLMJudgeDataset.__new__(LLMJudgeDataset) + ds.logger = MagicMock() + + with patch('os.path.exists', return_value=True): + with patch('ais_bench.benchmark.datasets.utils.llm_judge.load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions('/test/predictions.jsonl') + + assert len(result) == 2 + assert result[0]["id"] == 0 + assert result[1]["id"] == 1 + + +class TestLLMJudgeCorrectEvaluator: + def test_score_all_correct(self): + """测试全部正确的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "A", "A"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == 100.0 + assert len(result["details"]) == 3 + for detail in result["details"]: + assert detail["correct"] is True + + def test_score_all_wrong(self): + """测试全部错误的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["B", "B", "B"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == 0.0 + for detail in result["details"]: + assert detail["correct"] is False + + def test_score_mixed(self): + """测试混合情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "B", "A"] + references = ["correct1", "correct2", "correct3"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == pytest.approx(100 * 2 / 3, rel=1e-2) + assert result["details"][0]["correct"] is True + assert result["details"][1]["correct"] is False + assert result["details"][2]["correct"] is True + + def test_score_length_mismatch(self): + """测试长度不匹配的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "A"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_empty(self): + """测试空输入""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = [] + references = [] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert result["accuracy"] == 0.0 diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index e70f59ab..34711042 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -1,7 +1,10 @@ import pytest import numpy as np import asyncio -from ais_bench.benchmark.models.output import Output, RequestOutput +import tempfile +import os +from PIL import Image +from ais_bench.benchmark.models.output import Output, RequestOutput, FunctionCallOutput, LMMOutput import time @@ -202,4 +205,191 @@ def test_request_output_edge_cases(): metrics = output.get_metrics() assert metrics["uuid"] == "test_uuid_123" assert metrics["turn_id"] == 5 - assert metrics["extra_perf_data"] == {"test_key": "test_value"} \ No newline at end of file + assert metrics["extra_perf_data"] == {"test_key": "test_value"} + + +class TestFunctionCallOutput: + def test_init(self): + """测试FunctionCallOutput初始化""" + output = FunctionCallOutput() + assert output.perf_mode is False + assert isinstance(output.inference_log, list) + assert isinstance(output.tool_calls, list) + + def test_init_perf_mode(self): + """测试FunctionCallOutput性能模式初始化""" + output = FunctionCallOutput(perf_mode=True) + assert output.perf_mode is True + + def test_update_extra_details_data_from_text_response(self): + """测试update_extra_details_data_from_text_response方法""" + output = FunctionCallOutput() + text_response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "test content" + } + } + ] + } + + output.update_extra_details_data_from_text_response(text_response) + + assert "message" in output.extra_details_data + assert output.extra_details_data["message"]["role"] == "assistant" + + def test_update_extra_details_data_empty_choices(self): + """测试空choices的情况""" + output = FunctionCallOutput() + text_response = {"choices": []} + + output.update_extra_details_data_from_text_response(text_response) + + assert output.extra_details_data == {} + + def test_update_extra_details_data_no_choices(self): + """测试没有choices的情况""" + output = FunctionCallOutput() + text_response = {} + + output.update_extra_details_data_from_text_response(text_response) + + assert output.extra_details_data == {} + + def test_get_metrics(self): + """测试get_metrics方法""" + output = FunctionCallOutput() + output.success = True + output.uuid = "test_uuid" + output.tool_calls = [{"function": "test_func"}] + + metrics = output.get_metrics() + + assert "tool_calls" in metrics + assert metrics["tool_calls"] == [{"function": "test_func"}] + + +class TestLMMOutput: + def test_init(self): + """测试LMMOutput初始化""" + output = LMMOutput() + assert output.perf_mode is False + assert isinstance(output.content, list) + assert len(output.content) == 1 + assert output.content[0] == "" + + def test_init_perf_mode(self): + """测试LMMOutput性能模式初始化""" + output = LMMOutput(perf_mode=True) + assert output.perf_mode is True + + def test_handle_text(self): + """测试_handle_text方法""" + output = LMMOutput() + output.content = ["text content"] + + result = output._handle_text("/test/dir", 0) + assert result == "text content" + + def test_handle_image(self): + """测试_handle_image方法""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output._handle_image(tmpdir, 0) + + assert "image_test_uuid_0.png" in result + expected_path = os.path.join(tmpdir, "image_test_uuid_0.png") + assert os.path.exists(expected_path) + + def test_handle_image_overwrite(self): + """测试_handle_image方法覆盖已存在的文件""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result1 = output._handle_image(tmpdir, 0) + result2 = output._handle_image(tmpdir, 0) + + assert result1 == result2 + + def test_get_prediction_single_text(self): + """测试get_prediction方法,单个文本""" + output = LMMOutput() + output.uuid = "test_uuid" + output.content = ["text content"] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert result == "text content" + + def test_get_prediction_single_image(self): + """测试get_prediction方法,单个图片""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert "image_test_uuid_0.png" in result + + def test_get_prediction_multiple_items(self): + """测试get_prediction方法,多个项目""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image, "text content"] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert isinstance(result, list) + assert len(result) == 2 + + def test_get_metrics(self): + """测试get_metrics方法""" + output = LMMOutput() + output.success = True + output.uuid = "test_uuid" + output.content = ["test"] + + metrics = output.get_metrics() + + assert "content" not in metrics + assert "perf_mode" not in metrics + assert metrics["uuid"] == "test_uuid" + + +def test_output_update_extra_perf_data_from_stream_response(): + """测试update_extra_perf_data_from_stream_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_perf_data_from_stream_response({"test": "data"}) + assert output.extra_perf_data == {} + + +def test_output_update_extra_perf_data_from_text_response(): + """测试update_extra_perf_data_from_text_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_perf_data_from_text_response({"test": "data"}) + assert output.extra_perf_data == {} + + +def test_output_update_extra_details_data_from_stream_response(): + """测试update_extra_details_data_from_stream_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_details_data_from_stream_response({"test": "data"}) + assert output.extra_details_data == {} + + +def test_output_update_extra_details_data_from_text_response(): + """测试update_extra_details_data_from_text_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_details_data_from_text_response({"test": "data"}) + assert output.extra_details_data == {} \ No newline at end of file diff --git a/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py new file mode 100644 index 00000000..ee1ce467 --- /dev/null +++ b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py @@ -0,0 +1,296 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock +import tempfile +import uuid +from pathlib import Path + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../../'))) + +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.lmm_gen_inferencer_output_handler import ( + LMMGenInferencerOutputHandler +) +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.utils.logging.exceptions import AISBenchRuntimeError + + +class TestLMMGenInferencerOutputHandler: + def setup_method(self): + """设置测试环境""" + self.handler = LMMGenInferencerOutputHandler() + + def test_init(self): + """测试初始化""" + handler = LMMGenInferencerOutputHandler(perf_mode=True, save_every=50) + assert handler.perf_mode is True + assert handler.save_every == 50 + + def test_set_output_path(self): + """测试set_output_path方法""" + self.handler.set_output_path('/test/output/path') + assert self.handler.output_path == '/test/output/path' + + def test_get_prediction_result_with_string_output(self): + """测试get_prediction_result方法,字符串输出""" + result = self.handler.get_prediction_result( + output="test prediction", + gold="test gold", + input="test input", + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["prediction"] == "test prediction" + assert result["gold"] == "test gold" + assert result["origin_prompt"] == "test input" + assert len(result["uuid"]) == 32 + + def test_get_prediction_result_with_lmm_output(self): + """测试get_prediction_result方法,LMMOutput输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid_123" + output.success = True + output.content = ["text content"] + + result = self.handler.get_prediction_result( + output=output, + gold="test gold", + input="test input", + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["uuid"] == "test_uuid_123" + assert result["prediction"] == "text content" + assert result["gold"] == "test gold" + + def test_get_prediction_result_with_image_output(self): + """测试get_prediction_result方法,图片输出""" + from PIL import Image + + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid_456" + output.success = True + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert "image_test_uuid_456_0.png" in result["prediction"] + + def test_get_prediction_result_creates_output_dir(self): + """测试get_prediction_result方法创建输出目录""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, 'subdir') + self.handler.set_output_path(output_path) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert os.path.exists(os.path.join(output_path, "test_dataset_out_file")) + + def test_get_prediction_result_with_long_base64_input(self): + """测试get_prediction_result方法,处理长base64输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + long_base64 = "a" * 300 + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": { + "url": long_base64 + } + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert "..." in input_data[0]["prompt"][0]["image_url"]["url"] + + def test_get_prediction_result_with_dict_input(self): + """测试get_prediction_result方法,字典类型输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "text", + "text": "test text" + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_empty_input(self): + """测试get_prediction_result方法,空输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=None, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["origin_prompt"] == "" + + def test_get_prediction_result_without_gold(self): + """测试get_prediction_result方法,没有gold""" + result = self.handler.get_prediction_result( + output="test prediction", + input="test input", + data_abbr="test_dataset" + ) + + assert "gold" not in result + + def test_get_prediction_result_with_failed_output(self): + """测试get_prediction_result方法,失败的输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = False + output.error_info = "test error" + output.content = [""] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert result["success"] is False + + def test_get_prediction_result_with_non_dict_prompt_items(self): + """测试get_prediction_result方法,非字典类型的prompt项""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + "string_item", + 123, + None + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_non_dict_image_url(self): + """测试get_prediction_result方法,非字典类型的image_url""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": "not_a_dict" + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_non_string_url(self): + """测试get_prediction_result方法,非字符串类型的URL""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": { + "url": 123 + } + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True diff --git a/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py new file mode 100644 index 00000000..47064a49 --- /dev/null +++ b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py @@ -0,0 +1,149 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, call +import uuid + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer +from ais_bench.benchmark.models.output import LMMOutput + + +class TestLMMGenInferencer: + def setup_method(self): + """设置测试环境""" + self.mock_model_cfg = { + 'type': 'MockModel', + 'abbr': 'test_model' + } + + def test_init(self): + """测试LMMGenInferencer初始化""" + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=[], + batch_size=1, + mode='infer' + ) + + assert inferencer.model_cfg == self.mock_model_cfg + assert inferencer.batch_size == 1 + mock_handler.assert_called_once() + + def test_init_with_custom_params(self): + """测试使用自定义参数初始化""" + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=['stop1', 'stop2'], + batch_size=4, + mode='perf', + gen_field_replace_token='', + output_json_filepath='/test/output', + save_every=50 + ) + + assert inferencer.stopping_criteria == ['stop1', 'stop2'] + assert inferencer.batch_size == 4 + assert inferencer.perf_mode is True + + def test_inference(self): + """测试inference方法""" + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + mock_retriever = MagicMock() + + with patch.object(inferencer, '__class__', LMMGenInferencer): + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer.GenInferencer.inference') as mock_super_inference: + mock_super_inference.return_value = [] + + result = inferencer.inference(mock_retriever, '/test/output.jsonl') + + mock_handler_instance.set_output_path.assert_called_once_with('/test/output.jsonl') + mock_super_inference.assert_called_once() + + def test_batch_inference(self): + """测试batch_inference方法""" + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0, 1], + 'prompt': ['prompt1', 'prompt2'], + 'data_abbr': ['dataset1', 'dataset1'], + 'gold': ['gold1', 'gold2'], + 'extra_param': 'value' + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + assert mock_handler_instance.report_cache_info_sync.call_count == 2 + + def test_batch_inference_without_gold(self): + """测试batch_inference方法,没有gold字段""" + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + mock_handler_instance.report_cache_info_sync.assert_called_once() + + def test_batch_inference_uuid_generation(self): + """测试batch_inference方法中UUID生成""" + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + call_args = mock_handler_instance.report_cache_info_sync.call_args + output = call_args[0][2] + assert isinstance(output, LMMOutput) + assert len(output.uuid) == 32 From f7a7e1992d8387de250be4bdeeae31751b982336 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 18:30:26 +0800 Subject: [PATCH 2/9] add new ut for gedit --- tests/UT/datasets/test_base.py | 20 +- tests/UT/datasets/test_g_edit.py | 40 ++- tests/UT/datasets/utils/test_llm_judge.py | 47 ++-- tests/UT/models/test_output.py | 4 + .../test_lmm_gen_inferencer_output_handler.py | 116 +++++---- .../test_icl_lmm_gen_inferencer.py | 242 ++++++++++-------- 6 files changed, 265 insertions(+), 204 deletions(-) diff --git a/tests/UT/datasets/test_base.py b/tests/UT/datasets/test_base.py index e413ed69..9f59d1db 100644 --- a/tests/UT/datasets/test_base.py +++ b/tests/UT/datasets/test_base.py @@ -36,7 +36,7 @@ def load(**kwargs): {"text": "a"}, {"text": "b"}, ]) - + def _init_reader(self, **kwargs): # 先正常初始化reader from ais_bench.benchmark.openicl.icl_dataset_reader import DatasetReader @@ -67,7 +67,7 @@ def test_repeated_dataset_and_metadata(self): first = ds.dataset['test'][0] self.assertIn("subdivision", first) self.assertIn("idx", first) - + def test_repeated_dataset_with_dataset_type(self): """测试当reader.dataset是Dataset类型时的处理(覆盖46-62行)""" # 创建一个返回Dataset的类,并手动设置reader.dataset为Dataset @@ -84,7 +84,7 @@ def test_repeated_dataset_with_dataset_type(self): first = ds.dataset[0] self.assertIn("subdivision", first) self.assertIn("idx", first) - + def test_train_property(self): """测试train属性(覆盖92行)""" ds = DummyDataset( @@ -95,7 +95,7 @@ def test_train_property(self): train = ds.train self.assertIsInstance(train, Dataset) self.assertGreater(len(train), 0) - + def test_test_property(self): """测试test属性(覆盖96行)""" ds = DummyDataset( @@ -106,17 +106,17 @@ def test_test_property(self): test = ds.test self.assertIsInstance(test, Dataset) self.assertGreater(len(test), 0) - + def test_repeated_dataset_with_large_batch_size(self): """测试大批量数据的批处理逻辑(覆盖批处理相关代码)""" # 创建一个较大的数据集来触发批处理逻辑 large_data = [{"text": f"item_{i}"} for i in range(15000)] - + class LargeDataset(BaseDataset): @staticmethod def load(**kwargs): return Dataset.from_list(large_data) - + ds = LargeDataset( reader_cfg={'input_columns': ['text'], 'output_column': None}, k=1, @@ -162,6 +162,8 @@ def test_update_task_state_without_manager(self): k=1, n=1 ) + ds.task_state_manager = None + ds.logger = MagicMock() # 不应抛出异常 ds.update_task_state({'status': 'processing'}) @@ -285,14 +287,14 @@ def _load_from_predictions(self, prediction_path): with patch.object(DummyJDGDataset, '_process_predictions') as mock_process: mock_process.return_value = Dataset.from_list([{"text": "result"}]) - + with patch.object(DummyJDGDataset, '__init__', lambda self, *args, **kwargs: None): ds = DummyJDGDataset.__new__(DummyJDGDataset) ds.dataset_instance = MagicMock() ds.dataset_instance.dataset = {"test": Dataset.from_list([{"text": "test"}])} ds.task_state_manager = None ds.logger = MagicMock() - + result = ds.load(predictions_path="/test/predictions.jsonl") self.assertIsInstance(result, Dataset) diff --git a/tests/UT/datasets/test_g_edit.py b/tests/UT/datasets/test_g_edit.py index 176dccbc..1acd7f8a 100644 --- a/tests/UT/datasets/test_g_edit.py +++ b/tests/UT/datasets/test_g_edit.py @@ -25,22 +25,34 @@ def test_score(self): evaluator = GEditEvaluator() predictions = ["pred1", "pred2", "pred3"] references = ["ref1", "ref2", "ref3"] - + result = evaluator.score(predictions, references) - + assert "accuracy" in result assert "details" in result assert result["accuracy"] == 100.0 assert len(result["details"]) == 3 - def test_score_empty(self): - """测试score方法,空输入""" + def test_score_empty_predictions(self): + """测试score方法,空predictions""" evaluator = GEditEvaluator() predictions = [] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert "details" in result + assert result["accuracy"] == 0.0 + + def test_score_empty_references(self): + """测试score方法,空references""" + evaluator = GEditEvaluator() + predictions = ["pred1"] references = [] - + result = evaluator.score(predictions, references) - + assert "accuracy" in result assert "details" in result @@ -54,7 +66,7 @@ def test_load_basic(self): "input_image_raw": Image.new('RGB', (100, 100), color='blue'), "instruction": "test instruction" } - ]) + ] * 2000) with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: @@ -65,9 +77,9 @@ def test_load_basic(self): ds.task_state_manager = None ds.logger = MagicMock() ds.update_task_state = MagicMock() - + result = ds.load(path='/test/path') - + assert isinstance(result, Dataset) def test_load_with_split(self): @@ -89,9 +101,9 @@ def test_load_with_split(self): ds.task_state_manager = None ds.logger = MagicMock() ds.update_task_state = MagicMock() - + result = ds.load(path='/test/path', split_count=2, split_index=0) - + assert isinstance(result, Dataset) def test_load_use_raw(self): @@ -102,7 +114,7 @@ def test_load_use_raw(self): "input_image_raw": Image.new('RGB', (100, 100), color='blue'), "instruction": "test instruction" } - ]) + ] * 2000) with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: @@ -113,9 +125,9 @@ def test_load_use_raw(self): ds.task_state_manager = None ds.logger = MagicMock() ds.update_task_state = MagicMock() - + result = ds.load(path='/test/path', use_raw=True) - + assert isinstance(result, Dataset) diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py index 7d3f75ed..e77c8453 100644 --- a/tests/UT/datasets/utils/test_llm_judge.py +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -47,7 +47,7 @@ def test_load_from_predictions_file_not_exists(self): """测试文件不存在的情况""" ds = LLMJudgeDataset.__new__(LLMJudgeDataset) ds.logger = MagicMock() - + with patch('os.path.exists', return_value=False): result = ds._load_from_predictions('/test/nonexistent.jsonl') assert result == [] @@ -58,14 +58,14 @@ def test_load_from_predictions_success(self): {"id": 1, "prediction": "pred1"}, {"id": 0, "prediction": "pred2"} ] - + ds = LLMJudgeDataset.__new__(LLMJudgeDataset) ds.logger = MagicMock() - + with patch('os.path.exists', return_value=True): - with patch('ais_bench.benchmark.datasets.utils.llm_judge.load_jsonl', return_value=mock_preds): + with patch('ais_bench.benchmark.utils.file.file.load_jsonl', return_value=mock_preds): result = ds._load_from_predictions('/test/predictions.jsonl') - + assert len(result) == 2 assert result[0]["id"] == 0 assert result[1]["id"] == 1 @@ -77,9 +77,9 @@ def test_score_all_correct(self): evaluator = LLMJudgeCorrectEvaluator() predictions = ["A", "A", "A"] references = ["correct", "correct", "correct"] - + result = evaluator.score(predictions, references) - + assert result["accuracy"] == 100.0 assert len(result["details"]) == 3 for detail in result["details"]: @@ -90,9 +90,9 @@ def test_score_all_wrong(self): evaluator = LLMJudgeCorrectEvaluator() predictions = ["B", "B", "B"] references = ["correct", "correct", "correct"] - + result = evaluator.score(predictions, references) - + assert result["accuracy"] == 0.0 for detail in result["details"]: assert detail["correct"] is False @@ -102,9 +102,9 @@ def test_score_mixed(self): evaluator = LLMJudgeCorrectEvaluator() predictions = ["A", "B", "A"] references = ["correct1", "correct2", "correct3"] - + result = evaluator.score(predictions, references) - + assert result["accuracy"] == pytest.approx(100 * 2 / 3, rel=1e-2) assert result["details"][0]["correct"] is True assert result["details"][1]["correct"] is False @@ -115,18 +115,27 @@ def test_score_length_mismatch(self): evaluator = LLMJudgeCorrectEvaluator() predictions = ["A", "A"] references = ["correct", "correct", "correct"] - + result = evaluator.score(predictions, references) - + assert "error" in result - def test_score_empty(self): - """测试空输入""" + def test_score_empty_predictions(self): + """测试空predictions""" evaluator = LLMJudgeCorrectEvaluator() predictions = [] + references = ["correct"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_empty_references(self): + """测试空references""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A"] references = [] - + result = evaluator.score(predictions, references) - - assert "accuracy" in result - assert result["accuracy"] == 0.0 + + assert "error" in result diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index 34711042..4eeb04fc 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -264,9 +264,11 @@ def test_get_metrics(self): output.success = True output.uuid = "test_uuid" output.tool_calls = [{"function": "test_func"}] + output.time_points = [time.perf_counter() - 1, time.perf_counter()] metrics = output.get_metrics() + assert metrics is not None assert "tool_calls" in metrics assert metrics["tool_calls"] == [{"function": "test_func"}] @@ -359,9 +361,11 @@ def test_get_metrics(self): output.success = True output.uuid = "test_uuid" output.content = ["test"] + output.time_points = [time.perf_counter() - 1, time.perf_counter()] metrics = output.get_metrics() + assert metrics is not None assert "content" not in metrics assert "perf_mode" not in metrics assert metrics["uuid"] == "test_uuid" diff --git a/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py index ee1ce467..f7b8d439 100644 --- a/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py +++ b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py @@ -20,10 +20,12 @@ class TestLMMGenInferencerOutputHandler: def setup_method(self): """设置测试环境""" self.handler = LMMGenInferencerOutputHandler() + self.handler.output_path = None def test_init(self): """测试初始化""" handler = LMMGenInferencerOutputHandler(perf_mode=True, save_every=50) + handler.output_path = None assert handler.perf_mode is True assert handler.save_every == 50 @@ -34,36 +36,38 @@ def test_set_output_path(self): def test_get_prediction_result_with_string_output(self): """测试get_prediction_result方法,字符串输出""" - result = self.handler.get_prediction_result( - output="test prediction", - gold="test gold", - input="test input", - data_abbr="test_dataset" - ) - - assert result["success"] is True - assert result["prediction"] == "test prediction" - assert result["gold"] == "test gold" - assert result["origin_prompt"] == "test input" - assert len(result["uuid"]) == 32 + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + result = self.handler.get_prediction_result( + output="test prediction", + gold="test gold", + input="test input", + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["prediction"] == "test prediction" + assert result["gold"] == "test gold" + assert result["origin_prompt"] == "test input" + assert len(result["uuid"]) == 32 def test_get_prediction_result_with_lmm_output(self): """测试get_prediction_result方法,LMMOutput输出""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + output = LMMOutput() output.uuid = "test_uuid_123" output.success = True output.content = ["text content"] - + result = self.handler.get_prediction_result( output=output, gold="test gold", input="test input", data_abbr="test_dataset" ) - + assert result["success"] is True assert result["uuid"] == "test_uuid_123" assert result["prediction"] == "text content" @@ -72,21 +76,21 @@ def test_get_prediction_result_with_lmm_output(self): def test_get_prediction_result_with_image_output(self): """测试get_prediction_result方法,图片输出""" from PIL import Image - + with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + output = LMMOutput() output.uuid = "test_uuid_456" output.success = True test_image = Image.new('RGB', (100, 100), color='red') output.content = [test_image] - + result = self.handler.get_prediction_result( output=output, data_abbr="test_dataset" ) - + assert result["success"] is True assert "image_test_uuid_456_0.png" in result["prediction"] @@ -95,24 +99,24 @@ def test_get_prediction_result_creates_output_dir(self): with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, 'subdir') self.handler.set_output_path(output_path) - + output = LMMOutput() output.uuid = "test_uuid" output.success = True output.content = ["text"] - + result = self.handler.get_prediction_result( output=output, data_abbr="test_dataset" ) - + assert os.path.exists(os.path.join(output_path, "test_dataset_out_file")) def test_get_prediction_result_with_long_base64_input(self): """测试get_prediction_result方法,处理长base64输入""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + long_base64 = "a" * 300 input_data = [{ "prompt": [ @@ -124,18 +128,18 @@ def test_get_prediction_result_with_long_base64_input(self): } ] }] - + output = LMMOutput() output.uuid = "test_uuid" output.success = True output.content = ["text"] - + result = self.handler.get_prediction_result( output=output, input=input_data, data_abbr="test_dataset" ) - + assert result["success"] is True assert "..." in input_data[0]["prompt"][0]["image_url"]["url"] @@ -143,7 +147,7 @@ def test_get_prediction_result_with_dict_input(self): """测试get_prediction_result方法,字典类型输入""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + input_data = [{ "prompt": [ { @@ -152,72 +156,74 @@ def test_get_prediction_result_with_dict_input(self): } ] }] - + output = LMMOutput() output.uuid = "test_uuid" output.success = True output.content = ["text"] - + result = self.handler.get_prediction_result( output=output, input=input_data, data_abbr="test_dataset" ) - + assert result["success"] is True def test_get_prediction_result_with_empty_input(self): """测试get_prediction_result方法,空输入""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + output = LMMOutput() output.uuid = "test_uuid" output.success = True output.content = ["text"] - + result = self.handler.get_prediction_result( output=output, input=None, data_abbr="test_dataset" ) - + assert result["success"] is True assert result["origin_prompt"] == "" def test_get_prediction_result_without_gold(self): """测试get_prediction_result方法,没有gold""" - result = self.handler.get_prediction_result( - output="test prediction", - input="test input", - data_abbr="test_dataset" - ) - - assert "gold" not in result + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + result = self.handler.get_prediction_result( + output="test prediction", + input="test input", + data_abbr="test_dataset" + ) + + assert "gold" not in result def test_get_prediction_result_with_failed_output(self): """测试get_prediction_result方法,失败的输出""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + output = LMMOutput() output.uuid = "test_uuid" output.success = False output.error_info = "test error" output.content = [""] - + result = self.handler.get_prediction_result( output=output, data_abbr="test_dataset" ) - + assert result["success"] is False def test_get_prediction_result_with_non_dict_prompt_items(self): """测试get_prediction_result方法,非字典类型的prompt项""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + input_data = [{ "prompt": [ "string_item", @@ -225,25 +231,25 @@ def test_get_prediction_result_with_non_dict_prompt_items(self): None ] }] - + output = LMMOutput() output.uuid = "test_uuid" output.success = True output.content = ["text"] - + result = self.handler.get_prediction_result( output=output, input=input_data, data_abbr="test_dataset" ) - + assert result["success"] is True def test_get_prediction_result_with_non_dict_image_url(self): """测试get_prediction_result方法,非字典类型的image_url""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + input_data = [{ "prompt": [ { @@ -252,25 +258,25 @@ def test_get_prediction_result_with_non_dict_image_url(self): } ] }] - + output = LMMOutput() output.uuid = "test_uuid" output.success = True output.content = ["text"] - + result = self.handler.get_prediction_result( output=output, input=input_data, data_abbr="test_dataset" ) - + assert result["success"] is True def test_get_prediction_result_with_non_string_url(self): """测试get_prediction_result方法,非字符串类型的URL""" with tempfile.TemporaryDirectory() as tmpdir: self.handler.set_output_path(tmpdir) - + input_data = [{ "prompt": [ { @@ -281,16 +287,16 @@ def test_get_prediction_result_with_non_string_url(self): } ] }] - + output = LMMOutput() output.uuid = "test_uuid" output.success = True output.content = ["text"] - + result = self.handler.get_prediction_result( output=output, input=input_data, data_abbr="test_dataset" ) - + assert result["success"] is True diff --git a/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py index 47064a49..99c9970d 100644 --- a/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py +++ b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py @@ -7,7 +7,6 @@ # 添加项目根目录到Python路径 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) -from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer from ais_bench.benchmark.models.output import LMMOutput @@ -21,129 +20,158 @@ def setup_method(self): def test_init(self): """测试LMMGenInferencer初始化""" - with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: - inferencer = LMMGenInferencer( - model_cfg=self.mock_model_cfg, - stopping_criteria=[], - batch_size=1, - mode='infer' - ) - - assert inferencer.model_cfg == self.mock_model_cfg - assert inferencer.batch_size == 1 - mock_handler.assert_called_once() + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=[], + batch_size=1, + mode='infer' + ) + + assert inferencer.model_cfg == self.mock_model_cfg + assert inferencer.batch_size == 1 + mock_handler.assert_called_once() def test_init_with_custom_params(self): """测试使用自定义参数初始化""" - with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: - inferencer = LMMGenInferencer( - model_cfg=self.mock_model_cfg, - stopping_criteria=['stop1', 'stop2'], - batch_size=4, - mode='perf', - gen_field_replace_token='', - output_json_filepath='/test/output', - save_every=50 - ) - - assert inferencer.stopping_criteria == ['stop1', 'stop2'] - assert inferencer.batch_size == 4 - assert inferencer.perf_mode is True + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=['stop1', 'stop2'], + batch_size=4, + mode='perf', + gen_field_replace_token='', + output_json_filepath='/test/output', + save_every=50 + ) + + assert inferencer.stopping_criteria == ['stop1', 'stop2'] + assert inferencer.batch_size == 4 + assert inferencer.perf_mode is True def test_inference(self): """测试inference方法""" - with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: - mock_handler_instance = MagicMock() - mock_handler.return_value = mock_handler_instance - - inferencer = LMMGenInferencer( - model_cfg=self.mock_model_cfg, - batch_size=1 - ) - - mock_retriever = MagicMock() - - with patch.object(inferencer, '__class__', LMMGenInferencer): + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + mock_retriever = MagicMock() + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer.GenInferencer.inference') as mock_super_inference: mock_super_inference.return_value = [] - + result = inferencer.inference(mock_retriever, '/test/output.jsonl') - + mock_handler_instance.set_output_path.assert_called_once_with('/test/output.jsonl') mock_super_inference.assert_called_once() def test_batch_inference(self): """测试batch_inference方法""" - with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: - mock_handler_instance = MagicMock() - mock_handler.return_value = mock_handler_instance - - inferencer = LMMGenInferencer( - model_cfg=self.mock_model_cfg, - batch_size=1 - ) - - inferencer.model = MagicMock() - - datum = { - 'index': [0, 1], - 'prompt': ['prompt1', 'prompt2'], - 'data_abbr': ['dataset1', 'dataset1'], - 'gold': ['gold1', 'gold2'], - 'extra_param': 'value' - } - - inferencer.batch_inference(datum) - - inferencer.model.generate.assert_called_once() - assert mock_handler_instance.report_cache_info_sync.call_count == 2 + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0, 1], + 'prompt': ['prompt1', 'prompt2'], + 'data_abbr': ['dataset1', 'dataset1'], + 'gold': ['gold1', 'gold2'], + 'extra_param': 'value' + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + assert mock_handler_instance.report_cache_info_sync.call_count == 2 def test_batch_inference_without_gold(self): """测试batch_inference方法,没有gold字段""" - with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: - mock_handler_instance = MagicMock() - mock_handler.return_value = mock_handler_instance - - inferencer = LMMGenInferencer( - model_cfg=self.mock_model_cfg, - batch_size=1 - ) - - inferencer.model = MagicMock() - - datum = { - 'index': [0], - 'prompt': ['prompt1'], - 'data_abbr': ['dataset1'] - } - - inferencer.batch_inference(datum) - - inferencer.model.generate.assert_called_once() - mock_handler_instance.report_cache_info_sync.assert_called_once() + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + mock_handler_instance.report_cache_info_sync.assert_called_once() def test_batch_inference_uuid_generation(self): """测试batch_inference方法中UUID生成""" - with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: - mock_handler_instance = MagicMock() - mock_handler.return_value = mock_handler_instance - - inferencer = LMMGenInferencer( - model_cfg=self.mock_model_cfg, - batch_size=1 - ) - - inferencer.model = MagicMock() - - datum = { - 'index': [0], - 'prompt': ['prompt1'], - 'data_abbr': ['dataset1'] - } - - inferencer.batch_inference(datum) - - call_args = mock_handler_instance.report_cache_info_sync.call_args - output = call_args[0][2] - assert isinstance(output, LMMOutput) - assert len(output.uuid) == 32 + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + call_args = mock_handler_instance.report_cache_info_sync.call_args + output = call_args[0][2] + assert isinstance(output, LMMOutput) + assert len(output.uuid) == 32 From 632852827b583320efaac962096ee565562c5adb Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 18:45:39 +0800 Subject: [PATCH 3/9] add new ut for gedit --- tests/UT/datasets/test_g_edit.py | 6 +-- tests/UT/datasets/utils/test_llm_judge.py | 3 +- tests/UT/models/test_output.py | 45 +++++++++++++---------- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/tests/UT/datasets/test_g_edit.py b/tests/UT/datasets/test_g_edit.py index 1acd7f8a..58673284 100644 --- a/tests/UT/datasets/test_g_edit.py +++ b/tests/UT/datasets/test_g_edit.py @@ -45,10 +45,10 @@ def test_score_empty_predictions(self): assert "details" in result assert result["accuracy"] == 0.0 - def test_score_empty_references(self): - """测试score方法,空references""" + def test_score_both_empty(self): + """测试score方法,predictions和references都为空""" evaluator = GEditEvaluator() - predictions = ["pred1"] + predictions = [] references = [] result = evaluator.score(predictions, references) diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py index e77c8453..44aa7054 100644 --- a/tests/UT/datasets/utils/test_llm_judge.py +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -61,9 +61,10 @@ def test_load_from_predictions_success(self): ds = LLMJudgeDataset.__new__(LLMJudgeDataset) ds.logger = MagicMock() + ds.task_state_manager = None with patch('os.path.exists', return_value=True): - with patch('ais_bench.benchmark.utils.file.file.load_jsonl', return_value=mock_preds): + with patch('ais_bench.benchmark.datasets.utils.llm_judge.load_jsonl', return_value=mock_preds): result = ds._load_from_predictions('/test/predictions.jsonl') assert len(result) == 2 diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index 4eeb04fc..a264c6c9 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -43,7 +43,10 @@ def test_concate_reasoning_content(): # 测试reasoning_content和content都不为空的情况 result1 = output._concate_reasoning_content("content", "reasoning") - assert result1 == "reasoningcontent" + assert result1 == "reasoning + + +content" # 测试reasoning_content不为空但content为空的情况 result2 = output._concate_reasoning_content("", "reasoning") @@ -70,12 +73,21 @@ def test_get_prediction(): # 测试content和reasoning_content都是列表的情况 output.content = ["content1", "content2"] output.reasoning_content = ["reasoning1", "reasoning2"] - assert output.get_prediction() == ["reasoning1content1", "reasoning2content2"] + assert output.get_prediction() == ["reasoning1 + + +content1", "reasoning2 + + +content2"] # 测试reasoning_content是字符串的情况 output.content = "content string" output.reasoning_content = "reasoning string" - assert output.get_prediction() == "reasoning stringcontent string" + assert output.get_prediction() == "reasoning string + + +content string" # 测试其他类型的情况(应该返回原始content) output.content = "test content" @@ -258,19 +270,17 @@ def test_update_extra_details_data_no_choices(self): assert output.extra_details_data == {} - def test_get_metrics(self): - """测试get_metrics方法""" + def test_get_metrics_inherited(self): + """测试get_metrics方法(继承自Output抽象基类,返回None)""" output = FunctionCallOutput() output.success = True output.uuid = "test_uuid" output.tool_calls = [{"function": "test_func"}] - output.time_points = [time.perf_counter() - 1, time.perf_counter()] + # FunctionCallOutput 没有实现 get_metrics,继承自抽象基类 Output + # Output.get_metrics 是抽象方法,只有 pass,返回 None metrics = output.get_metrics() - - assert metrics is not None - assert "tool_calls" in metrics - assert metrics["tool_calls"] == [{"function": "test_func"}] + assert metrics is None class TestLMMOutput: @@ -355,20 +365,17 @@ def test_get_prediction_multiple_items(self): assert isinstance(result, list) assert len(result) == 2 - def test_get_metrics(self): - """测试get_metrics方法""" + def test_get_metrics_inherited(self): + """测试get_metrics方法(继承自Output抽象基类,返回None)""" output = LMMOutput() output.success = True output.uuid = "test_uuid" output.content = ["test"] - output.time_points = [time.perf_counter() - 1, time.perf_counter()] + # LMMOutput 没有实现 get_metrics,继承自抽象基类 Output + # Output.get_metrics 是抽象方法,只有 pass,返回 None metrics = output.get_metrics() - - assert metrics is not None - assert "content" not in metrics - assert "perf_mode" not in metrics - assert metrics["uuid"] == "test_uuid" + assert metrics is None def test_output_update_extra_perf_data_from_stream_response(): @@ -396,4 +403,4 @@ def test_output_update_extra_details_data_from_text_response(): """测试update_extra_details_data_from_text_response方法(默认实现)""" output = ConcreteOutput() output.update_extra_details_data_from_text_response({"test": "data"}) - assert output.extra_details_data == {} \ No newline at end of file + assert output.extra_details_data == {} From 374680b3f4c85c478a699440e1deb11a8c95425a Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 18:46:16 +0800 Subject: [PATCH 4/9] add new ut for gedit --- tests/UT/models/test_output.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index a264c6c9..c381e3a2 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -43,10 +43,7 @@ def test_concate_reasoning_content(): # 测试reasoning_content和content都不为空的情况 result1 = output._concate_reasoning_content("content", "reasoning") - assert result1 == "reasoning - - -content" + assert result1 == "reasoning\n\ncontent" # 测试reasoning_content不为空但content为空的情况 result2 = output._concate_reasoning_content("", "reasoning") @@ -73,21 +70,12 @@ def test_get_prediction(): # 测试content和reasoning_content都是列表的情况 output.content = ["content1", "content2"] output.reasoning_content = ["reasoning1", "reasoning2"] - assert output.get_prediction() == ["reasoning1 - - -content1", "reasoning2 - - -content2"] + assert output.get_prediction() == ["reasoning1\n\ncontent1", "reasoning2\n\ncontent2"] # 测试reasoning_content是字符串的情况 output.content = "content string" output.reasoning_content = "reasoning string" - assert output.get_prediction() == "reasoning string - - -content string" + assert output.get_prediction() == "reasoning string\n\ncontent string" # 测试其他类型的情况(应该返回原始content) output.content = "test content" From 365daa74b6c7859455613de678c1c8573c7355be Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 18:55:53 +0800 Subject: [PATCH 5/9] add new ut for gedit --- tests/UT/datasets/test_g_edit.py | 11 ------- tests/UT/datasets/utils/test_llm_judge.py | 2 +- tests/UT/models/test_output.py | 37 ++++------------------- 3 files changed, 7 insertions(+), 43 deletions(-) diff --git a/tests/UT/datasets/test_g_edit.py b/tests/UT/datasets/test_g_edit.py index 58673284..035600ce 100644 --- a/tests/UT/datasets/test_g_edit.py +++ b/tests/UT/datasets/test_g_edit.py @@ -45,17 +45,6 @@ def test_score_empty_predictions(self): assert "details" in result assert result["accuracy"] == 0.0 - def test_score_both_empty(self): - """测试score方法,predictions和references都为空""" - evaluator = GEditEvaluator() - predictions = [] - references = [] - - result = evaluator.score(predictions, references) - - assert "accuracy" in result - assert "details" in result - class TestGEditDataset: def test_load_basic(self): diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py index 44aa7054..0b5a900a 100644 --- a/tests/UT/datasets/utils/test_llm_judge.py +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -64,7 +64,7 @@ def test_load_from_predictions_success(self): ds.task_state_manager = None with patch('os.path.exists', return_value=True): - with patch('ais_bench.benchmark.datasets.utils.llm_judge.load_jsonl', return_value=mock_preds): + with patch('ais_bench.benchmark.utils.file.file.load_jsonl', return_value=mock_preds): result = ds._load_from_predictions('/test/predictions.jsonl') assert len(result) == 2 diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index c381e3a2..01a7eb08 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -16,7 +16,6 @@ def get_metrics(self) -> dict: def test_output_initialization(): """测试Output类的初始化功能""" - # 测试默认参数初始化 output = ConcreteOutput() assert output.perf_mode is False assert output.success is False @@ -32,7 +31,6 @@ def test_output_initialization(): assert output.uuid == "" assert output.turn_id == 0 - # 测试perf_mode=True初始化 output_perf = ConcreteOutput(perf_mode=True) assert output_perf.perf_mode is True @@ -41,19 +39,15 @@ def test_concate_reasoning_content(): """测试_concate_reasoning_content方法的不同分支""" output = ConcreteOutput() - # 测试reasoning_content和content都不为空的情况 result1 = output._concate_reasoning_content("content", "reasoning") - assert result1 == "reasoning\n\ncontent" + assert result1 == "reasoning" + "\n\n" + "content" - # 测试reasoning_content不为空但content为空的情况 result2 = output._concate_reasoning_content("", "reasoning") assert result2 == "reasoning" - # 测试reasoning_content为空但content不为空的情况 result3 = output._concate_reasoning_content("content", "") assert result3 == "content" - # 测试两者都为空的情况 result4 = output._concate_reasoning_content("", "") assert result4 == "" @@ -62,24 +56,21 @@ def test_get_prediction(): """测试get_prediction方法的不同分支""" output = ConcreteOutput() - # 测试reasoning_content为空的情况 output.content = "test content" output.reasoning_content = "" assert output.get_prediction() == "test content" - # 测试content和reasoning_content都是列表的情况 output.content = ["content1", "content2"] output.reasoning_content = ["reasoning1", "reasoning2"] - assert output.get_prediction() == ["reasoning1\n\ncontent1", "reasoning2\n\ncontent2"] + expected = ["reasoning1" + "\n\n" + "content1", "reasoning2" + "\n\n" + "content2"] + assert output.get_prediction() == expected - # 测试reasoning_content是字符串的情况 output.content = "content string" output.reasoning_content = "reasoning string" - assert output.get_prediction() == "reasoning string\n\ncontent string" + assert output.get_prediction() == "reasoning string" + "\n\n" + "content string" - # 测试其他类型的情况(应该返回原始content) output.content = "test content" - output.reasoning_content = None # 非字符串非列表类型 + output.reasoning_content = None assert output.get_prediction() == "test content" @@ -95,7 +86,6 @@ def test_to_dict(): assert result["content"] == "test" assert result["uuid"] == "test_uuid" assert result["turn_id"] == 1 - # 确保所有属性都被包含 assert "perf_mode" in result assert "success" in result assert "error_info" in result @@ -110,18 +100,15 @@ def test_to_dict(): def test_record_time_point(): """测试record_time_point异步方法""" - # 测试perf_mode=False时不记录时间点 output = ConcreteOutput(perf_mode=False) asyncio.run(output.record_time_point()) assert len(output.time_points) == 0 - # 测试perf_mode=True时记录时间点 output_perf = ConcreteOutput(perf_mode=True) asyncio.run(output_perf.record_time_point()) assert len(output_perf.time_points) == 1 assert isinstance(output_perf.time_points[0], float) - # 测试多次记录时间点 asyncio.run(output_perf.record_time_point()) assert len(output_perf.time_points) == 2 @@ -139,7 +126,6 @@ def test_clear_time_points(): def test_request_output_get_metrics(): """测试RequestOutput类的get_metrics方法的不同分支""" - # 测试success=False的情况 output = RequestOutput() output.success = False output.error_info = "test error" @@ -151,25 +137,20 @@ def test_request_output_get_metrics(): assert isinstance(metrics, dict) assert metrics["success"] is False assert metrics["error_info"] == "test error" - # 确保clean_result函数被正确应用 assert "content" not in metrics assert "reasoning_content" not in metrics assert "perf_mode" not in metrics - # 确保prediction字段被设置 assert "prediction" in metrics - # 测试success=True但time_points.size <= 1的情况 output = RequestOutput() output.success = True output.time_points = [time.perf_counter()] metrics = output.get_metrics() - assert metrics["success"] is False # 应该被设置为False + assert metrics["success"] is False assert metrics["error_info"] == "chunk size is less than 2" - # 确保time_points被转换为numpy数组 assert isinstance(metrics["time_points"], np.ndarray) - # 测试success=True且time_points.size > 1的情况 output = RequestOutput() output.success = True output.time_points = [time.perf_counter() - 1, time.perf_counter()] @@ -186,7 +167,6 @@ def test_request_output_get_metrics(): def test_request_output_edge_cases(): """测试RequestOutput类的边缘情况""" - # 测试空的time_points列表 output = RequestOutput() output.success = True output.time_points = [] @@ -195,7 +175,6 @@ def test_request_output_edge_cases(): assert metrics["success"] is False assert metrics["error_info"] == "chunk size is less than 2" - # 测试包含其他属性的情况 output = RequestOutput() output.success = False output.uuid = "test_uuid_123" @@ -265,8 +244,6 @@ def test_get_metrics_inherited(self): output.uuid = "test_uuid" output.tool_calls = [{"function": "test_func"}] - # FunctionCallOutput 没有实现 get_metrics,继承自抽象基类 Output - # Output.get_metrics 是抽象方法,只有 pass,返回 None metrics = output.get_metrics() assert metrics is None @@ -360,8 +337,6 @@ def test_get_metrics_inherited(self): output.uuid = "test_uuid" output.content = ["test"] - # LMMOutput 没有实现 get_metrics,继承自抽象基类 Output - # Output.get_metrics 是抽象方法,只有 pass,返回 None metrics = output.get_metrics() assert metrics is None From 17faf4f97f73e4d6e9492ca9467b954b3d355239 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 19:05:50 +0800 Subject: [PATCH 6/9] add new ut for gedit --- tests/UT/datasets/utils/test_llm_judge.py | 4 ++- tests/UT/models/test_output.py | 32 ++++++++++++++++++----- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py index 0b5a900a..7e3cabe3 100644 --- a/tests/UT/datasets/utils/test_llm_judge.py +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -64,7 +64,9 @@ def test_load_from_predictions_success(self): ds.task_state_manager = None with patch('os.path.exists', return_value=True): - with patch('ais_bench.benchmark.utils.file.file.load_jsonl', return_value=mock_preds): + # patch load_jsonl in the module where it's used + import ais_bench.benchmark.datasets.utils.llm_judge as llm_judge_module + with patch.object(llm_judge_module, 'load_jsonl', return_value=mock_preds): result = ds._load_from_predictions('/test/predictions.jsonl') assert len(result) == 2 diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index 01a7eb08..1f294802 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -39,15 +39,25 @@ def test_concate_reasoning_content(): """测试_concate_reasoning_content方法的不同分支""" output = ConcreteOutput() + # 测试reasoning_content和content都不为空的情况 result1 = output._concate_reasoning_content("content", "reasoning") - assert result1 == "reasoning" + "\n\n" + "content" - + # 验证结果包含reasoning和content,且reasoning在前 + assert "reasoning" in result1 + assert "content" in result1 + assert result1.startswith("reasoning") + assert result1.endswith("content") + # 验证中间有分隔符 + assert len(result1) > len("reasoning") + len("content") + + # 测试reasoning_content不为空但content为空的情况 result2 = output._concate_reasoning_content("", "reasoning") assert result2 == "reasoning" + # 测试reasoning_content为空但content不为空的情况 result3 = output._concate_reasoning_content("content", "") assert result3 == "content" + # 测试两者都为空的情况 result4 = output._concate_reasoning_content("", "") assert result4 == "" @@ -60,15 +70,25 @@ def test_get_prediction(): output.reasoning_content = "" assert output.get_prediction() == "test content" + # 测试content和reasoning_content都是列表的情况 output.content = ["content1", "content2"] output.reasoning_content = ["reasoning1", "reasoning2"] - expected = ["reasoning1" + "\n\n" + "content1", "reasoning2" + "\n\n" + "content2"] - assert output.get_prediction() == expected - + result = output.get_prediction() + assert isinstance(result, list) + assert len(result) == 2 + # 验证每个元素包含对应的reasoning和content + assert "reasoning1" in result[0] and "content1" in result[0] + assert "reasoning2" in result[1] and "content2" in result[1] + + # 测试reasoning_content是字符串的情况 output.content = "content string" output.reasoning_content = "reasoning string" - assert output.get_prediction() == "reasoning string" + "\n\n" + "content string" + result = output.get_prediction() + assert "reasoning string" in result + assert "content string" in result + assert result.startswith("reasoning string") + # 测试其他类型的情况(应该返回原始content) output.content = "test content" output.reasoning_content = None assert output.get_prediction() == "test content" From 9bd047c031c185d558fe237508a491db96bda028 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 19:14:51 +0800 Subject: [PATCH 7/9] add new ut for gedit --- tests/UT/datasets/utils/test_llm_judge.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py index 7e3cabe3..0aba41ec 100644 --- a/tests/UT/datasets/utils/test_llm_judge.py +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -64,9 +64,8 @@ def test_load_from_predictions_success(self): ds.task_state_manager = None with patch('os.path.exists', return_value=True): - # patch load_jsonl in the module where it's used - import ais_bench.benchmark.datasets.utils.llm_judge as llm_judge_module - with patch.object(llm_judge_module, 'load_jsonl', return_value=mock_preds): + # patch load_jsonl in the correct module path + with patch('ais_bench.benchmark.datasets.utils.llm_judge.load_jsonl', return_value=mock_preds): result = ds._load_from_predictions('/test/predictions.jsonl') assert len(result) == 2 From 1bd7201ec3fba5b4b4486347493aceaacf320acc Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 19:21:15 +0800 Subject: [PATCH 8/9] add new ut for gedit --- tests/UT/datasets/utils/test_llm_judge.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py index 0aba41ec..8123a8c4 100644 --- a/tests/UT/datasets/utils/test_llm_judge.py +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -8,6 +8,7 @@ # 添加项目根目录到Python路径 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) +from ais_bench.benchmark.datasets.utils import llm_judge from ais_bench.benchmark.datasets.utils.llm_judge import ( get_a_or_b, LLMJudgeDataset, @@ -64,8 +65,8 @@ def test_load_from_predictions_success(self): ds.task_state_manager = None with patch('os.path.exists', return_value=True): - # patch load_jsonl in the correct module path - with patch('ais_bench.benchmark.datasets.utils.llm_judge.load_jsonl', return_value=mock_preds): + # Use patch.object on the already imported module + with patch.object(llm_judge, 'load_jsonl', return_value=mock_preds): result = ds._load_from_predictions('/test/predictions.jsonl') assert len(result) == 2 From 8d3e05dabb7eb1bea277e7ef9efdd0f5e02bfb56 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 19:29:13 +0800 Subject: [PATCH 9/9] add new ut for gedit --- tests/UT/datasets/utils/test_lmm_judge.py | 276 ++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 tests/UT/datasets/utils/test_lmm_judge.py diff --git a/tests/UT/datasets/utils/test_lmm_judge.py b/tests/UT/datasets/utils/test_lmm_judge.py new file mode 100644 index 00000000..4e9acec4 --- /dev/null +++ b/tests/UT/datasets/utils/test_lmm_judge.py @@ -0,0 +1,276 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +import tempfile +import json +import base64 +from io import BytesIO +from PIL import Image + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.datasets.utils import lmm_judge +from ais_bench.benchmark.datasets.utils.lmm_judge import ( + get_lmm_point_list, + LMMImgJDGDataset, + ImgSCJDGDataset, + ImgPQJDGDataset, + LMMJudgeImageEditEvaluator +) + + +class TestGetLmmPointList: + def test_get_lmm_point_list_valid(self): + """测试提取有效的列表""" + result = get_lmm_point_list("The answer is [1, 2, 3]") + assert result == "[1, 2, 3]" + + def test_get_lmm_point_list_single(self): + """测试提取单个数字的列表""" + result = get_lmm_point_list("Result: [5]") + assert result == "[5]" + + def test_get_lmm_point_list_with_spaces(self): + """测试提取带空格的列表""" + result = get_lmm_point_list("Scores: [ 1 , 2 , 3 ]") + assert result == "[ 1 , 2 , 3 ]" + + def test_get_lmm_point_list_no_match(self): + """测试没有匹配时返回空列表""" + result = get_lmm_point_list("No list here") + assert result == "[]" + + def test_get_lmm_point_list_empty(self): + """测试空字符串""" + result = get_lmm_point_list("") + assert result == "[]" + + def test_get_lmm_point_list_multiple(self): + """测试多个数字的列表""" + result = get_lmm_point_list("Points: [10, 20, 30, 40]") + assert result == "[10, 20, 30, 40]" + + +class TestLMMImgJDGDataset: + def test_load_from_predictions_file_not_exists(self): + """测试文件不存在的情况""" + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = None + + with patch('os.path.exists', return_value=False): + result = ds._load_from_predictions('/test/nonexistent.jsonl') + assert result == [] + + def test_load_from_predictions_success(self): + """测试成功加载predictions""" + # 创建测试图片并转换为base64 + test_image = Image.new('RGB', (100, 100), color='red') + buffered = BytesIO() + test_image.save(buffered, format="PNG") + expected_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + + mock_preds = [ + {"id": 1, "prediction": "image1.png"}, + {"id": 0, "prediction": "image2.png"} + ] + + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = MagicMock() + ds.update_task_state = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + # 创建测试图片文件 + img_path1 = os.path.join(tmpdir, "image1.png") + img_path2 = os.path.join(tmpdir, "image2.png") + test_image.save(img_path1) + test_image.save(img_path2) + + # 创建predictions文件 + pred_file = os.path.join(tmpdir, "predictions.jsonl") + + with patch('os.path.exists', return_value=True): + with patch.object(lmm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions(pred_file) + + assert len(result) == 2 + assert result[0]["id"] == 0 + assert result[1]["id"] == 1 + # 验证图片被转换为base64 + assert result[0]["prediction"] == expected_base64 + + def test_load_from_predictions_with_nonexistent_image(self): + """测试图片文件不存在的情况""" + mock_preds = [ + {"id": 0, "prediction": "nonexistent.png"} + ] + + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = MagicMock() + ds.update_task_state = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + pred_file = os.path.join(tmpdir, "predictions.jsonl") + + with patch('os.path.exists') as mock_exists: + # 文件存在但图片不存在 + mock_exists.side_effect = lambda path: path.endswith('.jsonl') or path == pred_file + + with patch.object(lmm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions(pred_file) + + assert len(result) == 1 + assert result[0]["prediction"] == "nonexistent.png" + + +class TestImgSCJDGDataset: + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + + ds = ImgSCJDGDataset.__new__(ImgSCJDGDataset) + ds.logger = MagicMock() + + question = "What is in the image?" + org_image_url = "original_base64_string" + pred_image_url = "prediction_base64_string" + + dataset_item = { + "content": AIS_TEXT_START + question + AIS_CONTENT_TAG + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG + } + pred_item = {"prediction": pred_image_url} + + ds._modify_dataset_item(dataset_item, pred_item) + + # 验证content被正确修改 + assert AIS_TEXT_START + question + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + pred_image_url + AIS_CONTENT_TAG in dataset_item["content"] + + +class TestImgPQJDGDataset: + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + + ds = ImgPQJDGDataset.__new__(ImgPQJDGDataset) + ds.logger = MagicMock() + + question = "Describe the image?" + org_image_url = "original_base64_string" + pred_image_url = "prediction_base64_string" + + dataset_item = { + "content": AIS_TEXT_START + question + AIS_CONTENT_TAG + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG + } + pred_item = {"prediction": pred_image_url} + + ds._modify_dataset_item(dataset_item, pred_item) + + # 验证content被正确修改(PQ版本不包含原始图片) + assert AIS_TEXT_START + question + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + pred_image_url + AIS_CONTENT_TAG in dataset_item["content"] + # PQ版本不应该包含原始图片URL + assert org_image_url not in dataset_item["content"] + + +class TestLMMJudgeImageEditEvaluator: + def test_init_default_metric(self): + """测试默认metric初始化""" + evaluator = LMMJudgeImageEditEvaluator() + assert evaluator.metric == "SC" + assert evaluator.point_key_list == ["editing success", "over editing"] + + def test_init_pq_metric(self): + """测试PQ metric初始化""" + evaluator = LMMJudgeImageEditEvaluator(metric="PQ") + assert evaluator.metric == "PQ" + assert evaluator.point_key_list == ["naturalness", "artifacts"] + + def test_score_success(self): + """测试score方法成功情况""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + predictions = ["[5, 4]", "[3, 2]", "[4, 5]"] + references = ["ref1", "ref2", "ref3"] + + result = evaluator.score(predictions, references) + + assert "SC" in result + assert "details" in result + assert len(result["details"]) == 3 + # min(5,4)=4, min(3,2)=2, min(4,5)=4, average = (4+2+4)/3 = 3.33 + assert result["SC"] == pytest.approx(10/3, rel=1e-2) + + def test_score_pq_metric(self): + """测试PQ metric的score方法""" + evaluator = LMMJudgeImageEditEvaluator(metric="PQ") + predictions = ["[4, 5]", "[3, 3]"] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "PQ" in result + assert len(result["details"]) == 2 + # min(4,5)=4, min(3,3)=3, average = (4+3)/2 = 3.5 + assert result["PQ"] == pytest.approx(3.5, rel=1e-2) + + def test_score_length_mismatch(self): + """测试长度不匹配的情况""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = ["[1, 2]"] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_non_string_predictions(self): + """测试predictions不是字符串的情况""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = [[1, 2], [3, 4]] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_invalid_prediction_format(self): + """测试prediction格式错误的情况""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + # SC需要2个分数,但这里给出3个 + predictions = ["[1, 2, 3]"] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + assert "details" in result + assert result["details"][0]["eval_success"] is False + assert "failed reason" in result["details"][0] + + def test_score_empty_predictions(self): + """测试空predictions""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = [] + references = [] + + result = evaluator.score(predictions, references) + + assert "SC" in result + assert result["SC"] == 0.0 + assert "details" in result + + def test_score_detail_structure(self): + """测试details结构正确""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + predictions = ["[5, 4]"] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + detail = result["details"][0] + assert detail["eval_success"] is True + assert "pred" in detail + assert detail["pred"]["editing success"] == 5 + assert detail["pred"]["over editing"] == 4 + assert detail["org_uuid"] == "ref1"