From 1293c17db619c5d56193bf147e76954b5732a6d1 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Thu, 5 Mar 2026 08:54:28 +0800 Subject: [PATCH] ut part 1 --- tests/UT/cli/test_workers.py | 224 +++++++++++++++++- tests/UT/datasets/test_base.py | 184 ++++++++++++++- tests/UT/datasets/test_g_edit.py | 136 +++++++++++ tests/UT/datasets/utils/test_llm_judge.py | 144 +++++++++++ tests/UT/datasets/utils/test_lmm_judge.py | 276 ++++++++++++++++++++++ 5 files changed, 954 insertions(+), 10 deletions(-) create mode 100644 tests/UT/datasets/test_g_edit.py create mode 100644 tests/UT/datasets/utils/test_llm_judge.py create mode 100644 tests/UT/datasets/utils/test_lmm_judge.py diff --git a/tests/UT/cli/test_workers.py b/tests/UT/cli/test_workers.py index b8f3c8b1..ac883f4f 100644 --- a/tests/UT/cli/test_workers.py +++ b/tests/UT/cli/test_workers.py @@ -1,7 +1,7 @@ import sys import os import pytest -from unittest.mock import patch, MagicMock, call +from unittest.mock import patch, MagicMock, call, mock_open from collections import defaultdict # 添加项目根目录到Python路径 @@ -12,6 +12,7 @@ from ais_bench.benchmark.cli.workers import ( BaseWorker, Infer, + JudgeInfer, Eval, AccViz, PerfViz, @@ -664,11 +665,14 @@ def test_work_flow_dict(): assert 'viz' in WORK_FLOW assert 'perf' in WORK_FLOW assert 'perf_viz' in WORK_FLOW + assert 'judge' in WORK_FLOW + assert 'infer_judge' in WORK_FLOW # 验证工作流内容正确 assert Infer in WORK_FLOW['all'] assert Eval in WORK_FLOW['all'] assert AccViz in WORK_FLOW['all'] + assert JudgeInfer in WORK_FLOW['all'] assert Infer in WORK_FLOW['infer'] @@ -680,4 +684,220 @@ def test_work_flow_dict(): assert Infer in WORK_FLOW['perf'] assert PerfViz in WORK_FLOW['perf'] - assert PerfViz in WORK_FLOW['perf_viz'] \ No newline at end of file + assert PerfViz in WORK_FLOW['perf_viz'] + + assert JudgeInfer in WORK_FLOW['judge'] + assert Infer in WORK_FLOW['infer_judge'] + assert JudgeInfer in WORK_FLOW['infer_judge'] + + +class TestJudgeInfer: + def setup_method(self): + """设置测试环境""" + self.mock_args = MagicMock() + self.mock_args.max_num_workers = 4 + self.mock_args.max_workers_per_gpu = 2 + self.mock_args.debug = False + self.judge_infer_worker = JudgeInfer(self.mock_args) + + @patch('ais_bench.benchmark.cli.workers.get_config_type') + def test_update_cfg_service_model(self, mock_get_config_type): + """测试update_cfg方法,使用service模型""" + mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLApiInferTask', 'MockLocalRunner'] + + cfg = MockConfigDict({ + 'datasets': [{ + 'judge_infer_cfg': { + 'judge_model': {'attr': 'service', 'abbr': 'judge_model'} + } + }], + 'work_dir': '/test/workdir', + 'cli_args': MagicMock(debug=False) + }) + + with patch('os.path.join', return_value='/test/workdir/predictions/'): + result = self.judge_infer_worker.update_cfg(cfg) + + assert result == cfg + assert cfg['judge_infer']['partitioner']['type'] == 'MockNaivePartitioner' + assert cfg['judge_infer']['runner']['type'] == 'MockLocalRunner' + assert cfg['judge_infer']['runner']['task']['type'] == 'MockOpenICLApiInferTask' + + @patch('ais_bench.benchmark.cli.workers.get_config_type') + def test_update_cfg_local_model(self, mock_get_config_type): + """测试update_cfg方法,使用local模型""" + mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLInferTask', 'MockLocalRunner'] + + cfg = MockConfigDict({ + 'datasets': [{ + 'judge_infer_cfg': { + 'judge_model': {'attr': 'local', 'abbr': 'judge_model'} + } + }], + 'work_dir': '/test/workdir', + 'cli_args': MagicMock(debug=True) + }) + + with patch('os.path.join', return_value='/test/workdir/predictions/'): + self.judge_infer_worker.update_cfg(cfg) + + assert cfg['judge_infer']['runner']['task']['type'] == 'MockOpenICLInferTask' + assert cfg['judge_infer']['runner']['debug'] == True + + def test_cfg_pre_process(self): + """测试_cfg_pre_process方法""" + cfg = MockConfigDict({ + 'datasets': [ + { + 'abbr': 'test_dataset', + 'judge_infer_cfg': { + 'judge_model': {'abbr': 'judge_model'} + } + } + ] + }) + + self.judge_infer_worker._cfg_pre_process(cfg) + + assert cfg['datasets'][0]['abbr'] == 'test_dataset-judge_model' + assert 'test_dataset-judge_model' in self.judge_infer_worker.org_dataset_abbrs + + def test_cfg_pre_process_with_model_dataset_combinations(self): + """测试_cfg_pre_process方法,包含model_dataset_combinations""" + cfg = MockConfigDict({ + 'model_dataset_combinations': [ + { + 'datasets': [ + { + 'abbr': 'combo_dataset', + 'judge_infer_cfg': { + 'judge_model': {'abbr': 'judge_model'} + } + } + ] + } + ], + 'datasets': [] + }) + + self.judge_infer_worker._cfg_pre_process(cfg) + + assert cfg['model_dataset_combinations'][0]['datasets'][0]['abbr'] == 'combo_dataset-judge_model' + + def test_merge_datasets(self): + """测试_merge_datasets方法""" + task1 = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + task2 = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + task3 = { + 'models': [{'abbr': 'model2'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + + result = self.judge_infer_worker._merge_datasets([task1, task2, task3]) + + assert len(result) == 2 + assert len(result[0]['datasets'][0]) == 2 + assert len(result[1]['datasets'][0]) == 1 + + @patch('ais_bench.benchmark.cli.workers.PARTITIONERS') + @patch('ais_bench.benchmark.cli.workers.RUNNERS') + @patch('ais_bench.benchmark.cli.workers.logger') + def test_do_work_no_tasks(self, mock_logger, mock_runners, mock_partitioners): + """测试do_work方法,没有有效任务的情况""" + mock_partitioner = MagicMock() + mock_partitioners.build.return_value = mock_partitioner + mock_tasks = [ + { + 'datasets': [[{}]] # 没有judge_infer_cfg + } + ] + mock_partitioner.return_value = mock_tasks + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': {}, + 'runner': {} + }, + 'datasets': [] + }) + + with patch.object(self.judge_infer_worker, '_cfg_pre_process'): + with patch.object(self.judge_infer_worker, '_update_tasks_cfg'): + self.judge_infer_worker.do_work(cfg) + + mock_runners.build.assert_not_called() + + @patch('ais_bench.benchmark.cli.workers.load_jsonl') + @patch('ais_bench.benchmark.cli.workers.dump_jsonl') + @patch('os.path.exists') + @patch('os.remove') + def test_result_post_process(self, mock_remove, mock_exists, mock_dump_jsonl, mock_load_jsonl): + """测试_result_post_process方法""" + mock_load_jsonl.side_effect = [ + [{'uuid': 'uuid1', 'id': 'id1'}], # model_preds + [{'gold': 'uuid1', 'prediction': 'pred1'}] # judge_preds + ] + mock_exists.return_value = True + + task = { + 'datasets': [[{ + 'predictions_path': '/test/model_pred.jsonl', + 'abbr': 'test_dataset-judge_model' + }]], + 'models': [{'abbr': 'model1'}] + } + tasks = [task] + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': { + 'out_dir': '/test/predictions' + } + } + }) + + self.judge_infer_worker.org_dataset_abbrs = {'test_dataset-judge_model': 'test_dataset'} + + self.judge_infer_worker._result_post_process(tasks, cfg) + + mock_remove.assert_called_once() + mock_dump_jsonl.assert_called_once() + + def test_update_tasks_cfg_with_judge_infer(self): + """测试_update_tasks_cfg方法,包含judge_infer_cfg""" + self.judge_infer_worker.org_dataset_abbrs = {'test_dataset-judge_model': 'test_dataset'} + + task = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{ + 'abbr': 'test_dataset-judge_model', + 'judge_infer_cfg': { + 'judge_model': {'type': 'judge_model_type'}, + 'judge_dataset_type': 'judge_dataset', + 'judge_reader_cfg': {'test': 'cfg'} + } + }]] + } + tasks = [task] + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': { + 'out_dir': '/test/predictions' + } + } + }) + + with patch('os.path.join', return_value='/test/predictions/model1/test_dataset.jsonl'): + with patch('os.path.exists', return_value=True): + self.judge_infer_worker._update_tasks_cfg(tasks, cfg) + + assert 'judge_infer_cfg' not in task['datasets'][0][0] + assert task['models'][0]['type'] == 'judge_model_type' + assert task['datasets'][0][0]['type'] == 'judge_dataset' \ No newline at end of file diff --git a/tests/UT/datasets/test_base.py b/tests/UT/datasets/test_base.py index a2bc8593..9f59d1db 100644 --- a/tests/UT/datasets/test_base.py +++ b/tests/UT/datasets/test_base.py @@ -1,9 +1,10 @@ import unittest from unittest.mock import patch, MagicMock +import pytest from datasets import Dataset, DatasetDict -from ais_bench.benchmark.datasets.base import BaseDataset +from ais_bench.benchmark.datasets.base import BaseDataset, BaseJDGDataset class DummyDataset(BaseDataset): @@ -35,7 +36,7 @@ def load(**kwargs): {"text": "a"}, {"text": "b"}, ]) - + def _init_reader(self, **kwargs): # 先正常初始化reader from ais_bench.benchmark.openicl.icl_dataset_reader import DatasetReader @@ -66,7 +67,7 @@ def test_repeated_dataset_and_metadata(self): first = ds.dataset['test'][0] self.assertIn("subdivision", first) self.assertIn("idx", first) - + def test_repeated_dataset_with_dataset_type(self): """测试当reader.dataset是Dataset类型时的处理(覆盖46-62行)""" # 创建一个返回Dataset的类,并手动设置reader.dataset为Dataset @@ -83,7 +84,7 @@ def test_repeated_dataset_with_dataset_type(self): first = ds.dataset[0] self.assertIn("subdivision", first) self.assertIn("idx", first) - + def test_train_property(self): """测试train属性(覆盖92行)""" ds = DummyDataset( @@ -94,7 +95,7 @@ def test_train_property(self): train = ds.train self.assertIsInstance(train, Dataset) self.assertGreater(len(train), 0) - + def test_test_property(self): """测试test属性(覆盖96行)""" ds = DummyDataset( @@ -105,17 +106,17 @@ def test_test_property(self): test = ds.test self.assertIsInstance(test, Dataset) self.assertGreater(len(test), 0) - + def test_repeated_dataset_with_large_batch_size(self): """测试大批量数据的批处理逻辑(覆盖批处理相关代码)""" # 创建一个较大的数据集来触发批处理逻辑 large_data = [{"text": f"item_{i}"} for i in range(15000)] - + class LargeDataset(BaseDataset): @staticmethod def load(**kwargs): return Dataset.from_list(large_data) - + ds = LargeDataset( reader_cfg={'input_columns': ['text'], 'output_column': None}, k=1, @@ -130,6 +131,173 @@ def load(**kwargs): self.assertIn("subdivision", first) self.assertIn("idx", first) + def test_init_with_task_state_manager(self): + """测试使用task_state_manager初始化""" + mock_task_state_manager = MagicMock() + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + task_state_manager=mock_task_state_manager + ) + self.assertEqual(ds.task_state_manager, mock_task_state_manager) + + def test_update_task_state_with_manager(self): + """测试使用task_state_manager更新状态""" + mock_task_state_manager = MagicMock() + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + task_state_manager=mock_task_state_manager + ) + state = {'status': 'processing'} + ds.update_task_state(state) + mock_task_state_manager.update_task_state.assert_called_once_with(state) + + def test_update_task_state_without_manager(self): + """测试没有task_state_manager时更新状态""" + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + ds.task_state_manager = None + ds.logger = MagicMock() + # 不应抛出异常 + ds.update_task_state({'status': 'processing'}) + + def test_init_with_abbr(self): + """测试使用abbr参数初始化""" + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + abbr='custom_abbr' + ) + self.assertEqual(ds.abbr, 'custom_abbr') + + def test_init_k_greater_than_n_raises_error(self): + """测试k > n时抛出异常""" + from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError + with self.assertRaises(ParameterValueError): + DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=5, + n=3 + ) + + def test_init_k_list_greater_than_n_raises_error(self): + """测试k为列表且最大值 > n时抛出异常""" + from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError + with self.assertRaises(ParameterValueError): + DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=[1, 2, 5], + n=3 + ) + + def test_repeated_dataset_with_dataset_dict(self): + """测试DatasetDict类型的repeated_dataset处理""" + ds = DummyDatasetDict( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=2 + ) + self.assertIsInstance(ds.dataset, DatasetDict) + # 验证每个split都被正确处理 + self.assertEqual(len(ds.dataset['train']), 4) # 2 * 2 + self.assertEqual(len(ds.dataset['test']), 2) # 1 * 2 + # 验证元数据 + first_train = ds.dataset['train'][0] + self.assertIn("subdivision", first_train) + self.assertIn("idx", first_train) + + +class TestBaseJDGDataset(unittest.TestCase): + def test_init_org_datasets_instance(self): + """测试_init_org_datasets_instance方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + self.assertIsNotNone(ds.dataset_instance) + + def test_process_single_item(self): + """测试_process_single_item方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + dataset_content = Dataset.from_list([ + {"text": "question1", "answer": "A"}, + {"text": "question2", "answer": "B"} + ]) + pred_item = {"id": 0, "prediction": "predicted_answer", "uuid": "test_uuid"} + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + result = ds._process_single_item(dataset_content, pred_item) + self.assertEqual(result["model_answer"], "predicted_answer") + self.assertEqual(result["model_pred_uuid"], "test_uuid") + + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + dataset_item = {"text": "question", "answer": "A"} + pred_item = {"prediction": "predicted_answer"} + ds._modify_dataset_item(dataset_item, pred_item) + self.assertEqual(dataset_item["model_answer"], "predicted_answer") + + def test_load_with_predictions(self): + """测试load方法处理predictions""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [{"id": 0, "prediction": "pred1", "uuid": "uuid1"}] + + with patch.object(DummyJDGDataset, '_process_predictions') as mock_process: + mock_process.return_value = Dataset.from_list([{"text": "result"}]) + + with patch.object(DummyJDGDataset, '__init__', lambda self, *args, **kwargs: None): + ds = DummyJDGDataset.__new__(DummyJDGDataset) + ds.dataset_instance = MagicMock() + ds.dataset_instance.dataset = {"test": Dataset.from_list([{"text": "test"}])} + ds.task_state_manager = None + ds.logger = MagicMock() + + result = ds.load(predictions_path="/test/predictions.jsonl") + self.assertIsInstance(result, Dataset) + if __name__ == "__main__": unittest.main() diff --git a/tests/UT/datasets/test_g_edit.py b/tests/UT/datasets/test_g_edit.py new file mode 100644 index 00000000..035600ce --- /dev/null +++ b/tests/UT/datasets/test_g_edit.py @@ -0,0 +1,136 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +from io import BytesIO +from PIL import Image +import base64 + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from datasets import Dataset + +from ais_bench.benchmark.datasets.g_edit import ( + GEditDataset, + GEditSCJDGDataset, + GEditPQJDGDataset, + GEditEvaluator +) + + +class TestGEditEvaluator: + def test_score(self): + """测试score方法""" + evaluator = GEditEvaluator() + predictions = ["pred1", "pred2", "pred3"] + references = ["ref1", "ref2", "ref3"] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert "details" in result + assert result["accuracy"] == 100.0 + assert len(result["details"]) == 3 + + def test_score_empty_predictions(self): + """测试score方法,空predictions""" + evaluator = GEditEvaluator() + predictions = [] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert "details" in result + assert result["accuracy"] == 0.0 + + +class TestGEditDataset: + def test_load_basic(self): + """测试基本load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ] * 2000) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path') + + assert isinstance(result, Dataset) + + def test_load_with_split(self): + """测试带数据集切分的load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ] * 10) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path', split_count=2, split_index=0) + + assert isinstance(result, Dataset) + + def test_load_use_raw(self): + """测试使用原始图片的load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ] * 2000) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path', use_raw=True) + + assert isinstance(result, Dataset) + + +class TestGEditSCJDGDataset: + def test_get_dataset_class(self): + """测试_get_dataset_class方法""" + ds = GEditSCJDGDataset.__new__(GEditSCJDGDataset) + result = ds._get_dataset_class() + assert result == GEditDataset + + +class TestGEditPQJDGDataset: + def test_get_dataset_class(self): + """测试_get_dataset_class方法""" + ds = GEditPQJDGDataset.__new__(GEditPQJDGDataset) + result = ds._get_dataset_class() + assert result == GEditDataset diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py new file mode 100644 index 00000000..8123a8c4 --- /dev/null +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -0,0 +1,144 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +import tempfile +import json + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.datasets.utils import llm_judge +from ais_bench.benchmark.datasets.utils.llm_judge import ( + get_a_or_b, + LLMJudgeDataset, + LLMJudgeCorrectEvaluator +) + + +class TestGetAOrB: + def test_get_a_or_b_with_a(self): + """测试提取A""" + result = get_a_or_b("The answer is A") + assert result == "A" + + def test_get_a_or_b_with_b(self): + """测试提取B""" + result = get_a_or_b("The answer is B") + assert result == "B" + + def test_get_a_or_b_no_match(self): + """测试没有匹配时返回B""" + result = get_a_or_b("The answer is C") + assert result == "B" + + def test_get_a_or_b_empty(self): + """测试空字符串""" + result = get_a_or_b("") + assert result == "B" + + def test_get_a_or_b_at_end(self): + """测试A在末尾""" + result = get_a_or_b("something A") + assert result == "A" + + +class TestLLMJudgeDataset: + def test_load_from_predictions_file_not_exists(self): + """测试文件不存在的情况""" + ds = LLMJudgeDataset.__new__(LLMJudgeDataset) + ds.logger = MagicMock() + + with patch('os.path.exists', return_value=False): + result = ds._load_from_predictions('/test/nonexistent.jsonl') + assert result == [] + + def test_load_from_predictions_success(self): + """测试成功加载predictions""" + mock_preds = [ + {"id": 1, "prediction": "pred1"}, + {"id": 0, "prediction": "pred2"} + ] + + ds = LLMJudgeDataset.__new__(LLMJudgeDataset) + ds.logger = MagicMock() + ds.task_state_manager = None + + with patch('os.path.exists', return_value=True): + # Use patch.object on the already imported module + with patch.object(llm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions('/test/predictions.jsonl') + + assert len(result) == 2 + assert result[0]["id"] == 0 + assert result[1]["id"] == 1 + + +class TestLLMJudgeCorrectEvaluator: + def test_score_all_correct(self): + """测试全部正确的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "A", "A"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == 100.0 + assert len(result["details"]) == 3 + for detail in result["details"]: + assert detail["correct"] is True + + def test_score_all_wrong(self): + """测试全部错误的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["B", "B", "B"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == 0.0 + for detail in result["details"]: + assert detail["correct"] is False + + def test_score_mixed(self): + """测试混合情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "B", "A"] + references = ["correct1", "correct2", "correct3"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == pytest.approx(100 * 2 / 3, rel=1e-2) + assert result["details"][0]["correct"] is True + assert result["details"][1]["correct"] is False + assert result["details"][2]["correct"] is True + + def test_score_length_mismatch(self): + """测试长度不匹配的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "A"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_empty_predictions(self): + """测试空predictions""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = [] + references = ["correct"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_empty_references(self): + """测试空references""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A"] + references = [] + + result = evaluator.score(predictions, references) + + assert "error" in result diff --git a/tests/UT/datasets/utils/test_lmm_judge.py b/tests/UT/datasets/utils/test_lmm_judge.py new file mode 100644 index 00000000..4e9acec4 --- /dev/null +++ b/tests/UT/datasets/utils/test_lmm_judge.py @@ -0,0 +1,276 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +import tempfile +import json +import base64 +from io import BytesIO +from PIL import Image + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.datasets.utils import lmm_judge +from ais_bench.benchmark.datasets.utils.lmm_judge import ( + get_lmm_point_list, + LMMImgJDGDataset, + ImgSCJDGDataset, + ImgPQJDGDataset, + LMMJudgeImageEditEvaluator +) + + +class TestGetLmmPointList: + def test_get_lmm_point_list_valid(self): + """测试提取有效的列表""" + result = get_lmm_point_list("The answer is [1, 2, 3]") + assert result == "[1, 2, 3]" + + def test_get_lmm_point_list_single(self): + """测试提取单个数字的列表""" + result = get_lmm_point_list("Result: [5]") + assert result == "[5]" + + def test_get_lmm_point_list_with_spaces(self): + """测试提取带空格的列表""" + result = get_lmm_point_list("Scores: [ 1 , 2 , 3 ]") + assert result == "[ 1 , 2 , 3 ]" + + def test_get_lmm_point_list_no_match(self): + """测试没有匹配时返回空列表""" + result = get_lmm_point_list("No list here") + assert result == "[]" + + def test_get_lmm_point_list_empty(self): + """测试空字符串""" + result = get_lmm_point_list("") + assert result == "[]" + + def test_get_lmm_point_list_multiple(self): + """测试多个数字的列表""" + result = get_lmm_point_list("Points: [10, 20, 30, 40]") + assert result == "[10, 20, 30, 40]" + + +class TestLMMImgJDGDataset: + def test_load_from_predictions_file_not_exists(self): + """测试文件不存在的情况""" + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = None + + with patch('os.path.exists', return_value=False): + result = ds._load_from_predictions('/test/nonexistent.jsonl') + assert result == [] + + def test_load_from_predictions_success(self): + """测试成功加载predictions""" + # 创建测试图片并转换为base64 + test_image = Image.new('RGB', (100, 100), color='red') + buffered = BytesIO() + test_image.save(buffered, format="PNG") + expected_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + + mock_preds = [ + {"id": 1, "prediction": "image1.png"}, + {"id": 0, "prediction": "image2.png"} + ] + + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = MagicMock() + ds.update_task_state = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + # 创建测试图片文件 + img_path1 = os.path.join(tmpdir, "image1.png") + img_path2 = os.path.join(tmpdir, "image2.png") + test_image.save(img_path1) + test_image.save(img_path2) + + # 创建predictions文件 + pred_file = os.path.join(tmpdir, "predictions.jsonl") + + with patch('os.path.exists', return_value=True): + with patch.object(lmm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions(pred_file) + + assert len(result) == 2 + assert result[0]["id"] == 0 + assert result[1]["id"] == 1 + # 验证图片被转换为base64 + assert result[0]["prediction"] == expected_base64 + + def test_load_from_predictions_with_nonexistent_image(self): + """测试图片文件不存在的情况""" + mock_preds = [ + {"id": 0, "prediction": "nonexistent.png"} + ] + + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = MagicMock() + ds.update_task_state = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + pred_file = os.path.join(tmpdir, "predictions.jsonl") + + with patch('os.path.exists') as mock_exists: + # 文件存在但图片不存在 + mock_exists.side_effect = lambda path: path.endswith('.jsonl') or path == pred_file + + with patch.object(lmm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions(pred_file) + + assert len(result) == 1 + assert result[0]["prediction"] == "nonexistent.png" + + +class TestImgSCJDGDataset: + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + + ds = ImgSCJDGDataset.__new__(ImgSCJDGDataset) + ds.logger = MagicMock() + + question = "What is in the image?" + org_image_url = "original_base64_string" + pred_image_url = "prediction_base64_string" + + dataset_item = { + "content": AIS_TEXT_START + question + AIS_CONTENT_TAG + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG + } + pred_item = {"prediction": pred_image_url} + + ds._modify_dataset_item(dataset_item, pred_item) + + # 验证content被正确修改 + assert AIS_TEXT_START + question + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + pred_image_url + AIS_CONTENT_TAG in dataset_item["content"] + + +class TestImgPQJDGDataset: + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + + ds = ImgPQJDGDataset.__new__(ImgPQJDGDataset) + ds.logger = MagicMock() + + question = "Describe the image?" + org_image_url = "original_base64_string" + pred_image_url = "prediction_base64_string" + + dataset_item = { + "content": AIS_TEXT_START + question + AIS_CONTENT_TAG + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG + } + pred_item = {"prediction": pred_image_url} + + ds._modify_dataset_item(dataset_item, pred_item) + + # 验证content被正确修改(PQ版本不包含原始图片) + assert AIS_TEXT_START + question + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + pred_image_url + AIS_CONTENT_TAG in dataset_item["content"] + # PQ版本不应该包含原始图片URL + assert org_image_url not in dataset_item["content"] + + +class TestLMMJudgeImageEditEvaluator: + def test_init_default_metric(self): + """测试默认metric初始化""" + evaluator = LMMJudgeImageEditEvaluator() + assert evaluator.metric == "SC" + assert evaluator.point_key_list == ["editing success", "over editing"] + + def test_init_pq_metric(self): + """测试PQ metric初始化""" + evaluator = LMMJudgeImageEditEvaluator(metric="PQ") + assert evaluator.metric == "PQ" + assert evaluator.point_key_list == ["naturalness", "artifacts"] + + def test_score_success(self): + """测试score方法成功情况""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + predictions = ["[5, 4]", "[3, 2]", "[4, 5]"] + references = ["ref1", "ref2", "ref3"] + + result = evaluator.score(predictions, references) + + assert "SC" in result + assert "details" in result + assert len(result["details"]) == 3 + # min(5,4)=4, min(3,2)=2, min(4,5)=4, average = (4+2+4)/3 = 3.33 + assert result["SC"] == pytest.approx(10/3, rel=1e-2) + + def test_score_pq_metric(self): + """测试PQ metric的score方法""" + evaluator = LMMJudgeImageEditEvaluator(metric="PQ") + predictions = ["[4, 5]", "[3, 3]"] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "PQ" in result + assert len(result["details"]) == 2 + # min(4,5)=4, min(3,3)=3, average = (4+3)/2 = 3.5 + assert result["PQ"] == pytest.approx(3.5, rel=1e-2) + + def test_score_length_mismatch(self): + """测试长度不匹配的情况""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = ["[1, 2]"] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_non_string_predictions(self): + """测试predictions不是字符串的情况""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = [[1, 2], [3, 4]] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_invalid_prediction_format(self): + """测试prediction格式错误的情况""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + # SC需要2个分数,但这里给出3个 + predictions = ["[1, 2, 3]"] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + assert "details" in result + assert result["details"][0]["eval_success"] is False + assert "failed reason" in result["details"][0] + + def test_score_empty_predictions(self): + """测试空predictions""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = [] + references = [] + + result = evaluator.score(predictions, references) + + assert "SC" in result + assert result["SC"] == 0.0 + assert "details" in result + + def test_score_detail_structure(self): + """测试details结构正确""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + predictions = ["[5, 4]"] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + detail = result["details"][0] + assert detail["eval_success"] is True + assert "pred" in detail + assert detail["pred"]["editing success"] == 5 + assert detail["pred"]["over editing"] == 4 + assert detail["org_uuid"] == "ref1"