diff --git a/tests/UT/cli/test_workers.py b/tests/UT/cli/test_workers.py index b8f3c8b1..ac883f4f 100644 --- a/tests/UT/cli/test_workers.py +++ b/tests/UT/cli/test_workers.py @@ -1,7 +1,7 @@ import sys import os import pytest -from unittest.mock import patch, MagicMock, call +from unittest.mock import patch, MagicMock, call, mock_open from collections import defaultdict # 添加项目根目录到Python路径 @@ -12,6 +12,7 @@ from ais_bench.benchmark.cli.workers import ( BaseWorker, Infer, + JudgeInfer, Eval, AccViz, PerfViz, @@ -664,11 +665,14 @@ def test_work_flow_dict(): assert 'viz' in WORK_FLOW assert 'perf' in WORK_FLOW assert 'perf_viz' in WORK_FLOW + assert 'judge' in WORK_FLOW + assert 'infer_judge' in WORK_FLOW # 验证工作流内容正确 assert Infer in WORK_FLOW['all'] assert Eval in WORK_FLOW['all'] assert AccViz in WORK_FLOW['all'] + assert JudgeInfer in WORK_FLOW['all'] assert Infer in WORK_FLOW['infer'] @@ -680,4 +684,220 @@ def test_work_flow_dict(): assert Infer in WORK_FLOW['perf'] assert PerfViz in WORK_FLOW['perf'] - assert PerfViz in WORK_FLOW['perf_viz'] \ No newline at end of file + assert PerfViz in WORK_FLOW['perf_viz'] + + assert JudgeInfer in WORK_FLOW['judge'] + assert Infer in WORK_FLOW['infer_judge'] + assert JudgeInfer in WORK_FLOW['infer_judge'] + + +class TestJudgeInfer: + def setup_method(self): + """设置测试环境""" + self.mock_args = MagicMock() + self.mock_args.max_num_workers = 4 + self.mock_args.max_workers_per_gpu = 2 + self.mock_args.debug = False + self.judge_infer_worker = JudgeInfer(self.mock_args) + + @patch('ais_bench.benchmark.cli.workers.get_config_type') + def test_update_cfg_service_model(self, mock_get_config_type): + """测试update_cfg方法,使用service模型""" + mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLApiInferTask', 'MockLocalRunner'] + + cfg = MockConfigDict({ + 'datasets': [{ + 'judge_infer_cfg': { + 'judge_model': {'attr': 'service', 'abbr': 'judge_model'} + } + }], + 'work_dir': '/test/workdir', + 'cli_args': MagicMock(debug=False) + }) + + with patch('os.path.join', return_value='/test/workdir/predictions/'): + result = self.judge_infer_worker.update_cfg(cfg) + + assert result == cfg + assert cfg['judge_infer']['partitioner']['type'] == 'MockNaivePartitioner' + assert cfg['judge_infer']['runner']['type'] == 'MockLocalRunner' + assert cfg['judge_infer']['runner']['task']['type'] == 'MockOpenICLApiInferTask' + + @patch('ais_bench.benchmark.cli.workers.get_config_type') + def test_update_cfg_local_model(self, mock_get_config_type): + """测试update_cfg方法,使用local模型""" + mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLInferTask', 'MockLocalRunner'] + + cfg = MockConfigDict({ + 'datasets': [{ + 'judge_infer_cfg': { + 'judge_model': {'attr': 'local', 'abbr': 'judge_model'} + } + }], + 'work_dir': '/test/workdir', + 'cli_args': MagicMock(debug=True) + }) + + with patch('os.path.join', return_value='/test/workdir/predictions/'): + self.judge_infer_worker.update_cfg(cfg) + + assert cfg['judge_infer']['runner']['task']['type'] == 'MockOpenICLInferTask' + assert cfg['judge_infer']['runner']['debug'] == True + + def test_cfg_pre_process(self): + """测试_cfg_pre_process方法""" + cfg = MockConfigDict({ + 'datasets': [ + { + 'abbr': 'test_dataset', + 'judge_infer_cfg': { + 'judge_model': {'abbr': 'judge_model'} + } + } + ] + }) + + self.judge_infer_worker._cfg_pre_process(cfg) + + assert cfg['datasets'][0]['abbr'] == 'test_dataset-judge_model' + assert 'test_dataset-judge_model' in self.judge_infer_worker.org_dataset_abbrs + + def test_cfg_pre_process_with_model_dataset_combinations(self): + """测试_cfg_pre_process方法,包含model_dataset_combinations""" + cfg = MockConfigDict({ + 'model_dataset_combinations': [ + { + 'datasets': [ + { + 'abbr': 'combo_dataset', + 'judge_infer_cfg': { + 'judge_model': {'abbr': 'judge_model'} + } + } + ] + } + ], + 'datasets': [] + }) + + self.judge_infer_worker._cfg_pre_process(cfg) + + assert cfg['model_dataset_combinations'][0]['datasets'][0]['abbr'] == 'combo_dataset-judge_model' + + def test_merge_datasets(self): + """测试_merge_datasets方法""" + task1 = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + task2 = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + task3 = { + 'models': [{'abbr': 'model2'}], + 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] + } + + result = self.judge_infer_worker._merge_datasets([task1, task2, task3]) + + assert len(result) == 2 + assert len(result[0]['datasets'][0]) == 2 + assert len(result[1]['datasets'][0]) == 1 + + @patch('ais_bench.benchmark.cli.workers.PARTITIONERS') + @patch('ais_bench.benchmark.cli.workers.RUNNERS') + @patch('ais_bench.benchmark.cli.workers.logger') + def test_do_work_no_tasks(self, mock_logger, mock_runners, mock_partitioners): + """测试do_work方法,没有有效任务的情况""" + mock_partitioner = MagicMock() + mock_partitioners.build.return_value = mock_partitioner + mock_tasks = [ + { + 'datasets': [[{}]] # 没有judge_infer_cfg + } + ] + mock_partitioner.return_value = mock_tasks + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': {}, + 'runner': {} + }, + 'datasets': [] + }) + + with patch.object(self.judge_infer_worker, '_cfg_pre_process'): + with patch.object(self.judge_infer_worker, '_update_tasks_cfg'): + self.judge_infer_worker.do_work(cfg) + + mock_runners.build.assert_not_called() + + @patch('ais_bench.benchmark.cli.workers.load_jsonl') + @patch('ais_bench.benchmark.cli.workers.dump_jsonl') + @patch('os.path.exists') + @patch('os.remove') + def test_result_post_process(self, mock_remove, mock_exists, mock_dump_jsonl, mock_load_jsonl): + """测试_result_post_process方法""" + mock_load_jsonl.side_effect = [ + [{'uuid': 'uuid1', 'id': 'id1'}], # model_preds + [{'gold': 'uuid1', 'prediction': 'pred1'}] # judge_preds + ] + mock_exists.return_value = True + + task = { + 'datasets': [[{ + 'predictions_path': '/test/model_pred.jsonl', + 'abbr': 'test_dataset-judge_model' + }]], + 'models': [{'abbr': 'model1'}] + } + tasks = [task] + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': { + 'out_dir': '/test/predictions' + } + } + }) + + self.judge_infer_worker.org_dataset_abbrs = {'test_dataset-judge_model': 'test_dataset'} + + self.judge_infer_worker._result_post_process(tasks, cfg) + + mock_remove.assert_called_once() + mock_dump_jsonl.assert_called_once() + + def test_update_tasks_cfg_with_judge_infer(self): + """测试_update_tasks_cfg方法,包含judge_infer_cfg""" + self.judge_infer_worker.org_dataset_abbrs = {'test_dataset-judge_model': 'test_dataset'} + + task = { + 'models': [{'abbr': 'model1'}], + 'datasets': [[{ + 'abbr': 'test_dataset-judge_model', + 'judge_infer_cfg': { + 'judge_model': {'type': 'judge_model_type'}, + 'judge_dataset_type': 'judge_dataset', + 'judge_reader_cfg': {'test': 'cfg'} + } + }]] + } + tasks = [task] + + cfg = MockConfigDict({ + 'judge_infer': { + 'partitioner': { + 'out_dir': '/test/predictions' + } + } + }) + + with patch('os.path.join', return_value='/test/predictions/model1/test_dataset.jsonl'): + with patch('os.path.exists', return_value=True): + self.judge_infer_worker._update_tasks_cfg(tasks, cfg) + + assert 'judge_infer_cfg' not in task['datasets'][0][0] + assert task['models'][0]['type'] == 'judge_model_type' + assert task['datasets'][0][0]['type'] == 'judge_dataset' \ No newline at end of file diff --git a/tests/UT/datasets/test_base.py b/tests/UT/datasets/test_base.py index a2bc8593..9f59d1db 100644 --- a/tests/UT/datasets/test_base.py +++ b/tests/UT/datasets/test_base.py @@ -1,9 +1,10 @@ import unittest from unittest.mock import patch, MagicMock +import pytest from datasets import Dataset, DatasetDict -from ais_bench.benchmark.datasets.base import BaseDataset +from ais_bench.benchmark.datasets.base import BaseDataset, BaseJDGDataset class DummyDataset(BaseDataset): @@ -35,7 +36,7 @@ def load(**kwargs): {"text": "a"}, {"text": "b"}, ]) - + def _init_reader(self, **kwargs): # 先正常初始化reader from ais_bench.benchmark.openicl.icl_dataset_reader import DatasetReader @@ -66,7 +67,7 @@ def test_repeated_dataset_and_metadata(self): first = ds.dataset['test'][0] self.assertIn("subdivision", first) self.assertIn("idx", first) - + def test_repeated_dataset_with_dataset_type(self): """测试当reader.dataset是Dataset类型时的处理(覆盖46-62行)""" # 创建一个返回Dataset的类,并手动设置reader.dataset为Dataset @@ -83,7 +84,7 @@ def test_repeated_dataset_with_dataset_type(self): first = ds.dataset[0] self.assertIn("subdivision", first) self.assertIn("idx", first) - + def test_train_property(self): """测试train属性(覆盖92行)""" ds = DummyDataset( @@ -94,7 +95,7 @@ def test_train_property(self): train = ds.train self.assertIsInstance(train, Dataset) self.assertGreater(len(train), 0) - + def test_test_property(self): """测试test属性(覆盖96行)""" ds = DummyDataset( @@ -105,17 +106,17 @@ def test_test_property(self): test = ds.test self.assertIsInstance(test, Dataset) self.assertGreater(len(test), 0) - + def test_repeated_dataset_with_large_batch_size(self): """测试大批量数据的批处理逻辑(覆盖批处理相关代码)""" # 创建一个较大的数据集来触发批处理逻辑 large_data = [{"text": f"item_{i}"} for i in range(15000)] - + class LargeDataset(BaseDataset): @staticmethod def load(**kwargs): return Dataset.from_list(large_data) - + ds = LargeDataset( reader_cfg={'input_columns': ['text'], 'output_column': None}, k=1, @@ -130,6 +131,173 @@ def load(**kwargs): self.assertIn("subdivision", first) self.assertIn("idx", first) + def test_init_with_task_state_manager(self): + """测试使用task_state_manager初始化""" + mock_task_state_manager = MagicMock() + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + task_state_manager=mock_task_state_manager + ) + self.assertEqual(ds.task_state_manager, mock_task_state_manager) + + def test_update_task_state_with_manager(self): + """测试使用task_state_manager更新状态""" + mock_task_state_manager = MagicMock() + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + task_state_manager=mock_task_state_manager + ) + state = {'status': 'processing'} + ds.update_task_state(state) + mock_task_state_manager.update_task_state.assert_called_once_with(state) + + def test_update_task_state_without_manager(self): + """测试没有task_state_manager时更新状态""" + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + ds.task_state_manager = None + ds.logger = MagicMock() + # 不应抛出异常 + ds.update_task_state({'status': 'processing'}) + + def test_init_with_abbr(self): + """测试使用abbr参数初始化""" + ds = DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1, + abbr='custom_abbr' + ) + self.assertEqual(ds.abbr, 'custom_abbr') + + def test_init_k_greater_than_n_raises_error(self): + """测试k > n时抛出异常""" + from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError + with self.assertRaises(ParameterValueError): + DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=5, + n=3 + ) + + def test_init_k_list_greater_than_n_raises_error(self): + """测试k为列表且最大值 > n时抛出异常""" + from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError + with self.assertRaises(ParameterValueError): + DummyDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=[1, 2, 5], + n=3 + ) + + def test_repeated_dataset_with_dataset_dict(self): + """测试DatasetDict类型的repeated_dataset处理""" + ds = DummyDatasetDict( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=2 + ) + self.assertIsInstance(ds.dataset, DatasetDict) + # 验证每个split都被正确处理 + self.assertEqual(len(ds.dataset['train']), 4) # 2 * 2 + self.assertEqual(len(ds.dataset['test']), 2) # 1 * 2 + # 验证元数据 + first_train = ds.dataset['train'][0] + self.assertIn("subdivision", first_train) + self.assertIn("idx", first_train) + + +class TestBaseJDGDataset(unittest.TestCase): + def test_init_org_datasets_instance(self): + """测试_init_org_datasets_instance方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + self.assertIsNotNone(ds.dataset_instance) + + def test_process_single_item(self): + """测试_process_single_item方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + dataset_content = Dataset.from_list([ + {"text": "question1", "answer": "A"}, + {"text": "question2", "answer": "B"} + ]) + pred_item = {"id": 0, "prediction": "predicted_answer", "uuid": "test_uuid"} + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + result = ds._process_single_item(dataset_content, pred_item) + self.assertEqual(result["model_answer"], "predicted_answer") + self.assertEqual(result["model_pred_uuid"], "test_uuid") + + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [] + + with patch.object(DummyJDGDataset, 'load') as mock_load: + mock_load.return_value = Dataset.from_list([{"text": "a"}]) + ds = DummyJDGDataset( + reader_cfg={'input_columns': ['text'], 'output_column': None}, + k=1, + n=1 + ) + dataset_item = {"text": "question", "answer": "A"} + pred_item = {"prediction": "predicted_answer"} + ds._modify_dataset_item(dataset_item, pred_item) + self.assertEqual(dataset_item["model_answer"], "predicted_answer") + + def test_load_with_predictions(self): + """测试load方法处理predictions""" + class DummyJDGDataset(BaseJDGDataset): + def _get_dataset_class(self): + return DummyDataset + def _load_from_predictions(self, prediction_path): + return [{"id": 0, "prediction": "pred1", "uuid": "uuid1"}] + + with patch.object(DummyJDGDataset, '_process_predictions') as mock_process: + mock_process.return_value = Dataset.from_list([{"text": "result"}]) + + with patch.object(DummyJDGDataset, '__init__', lambda self, *args, **kwargs: None): + ds = DummyJDGDataset.__new__(DummyJDGDataset) + ds.dataset_instance = MagicMock() + ds.dataset_instance.dataset = {"test": Dataset.from_list([{"text": "test"}])} + ds.task_state_manager = None + ds.logger = MagicMock() + + result = ds.load(predictions_path="/test/predictions.jsonl") + self.assertIsInstance(result, Dataset) + if __name__ == "__main__": unittest.main() diff --git a/tests/UT/datasets/test_g_edit.py b/tests/UT/datasets/test_g_edit.py new file mode 100644 index 00000000..035600ce --- /dev/null +++ b/tests/UT/datasets/test_g_edit.py @@ -0,0 +1,136 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +from io import BytesIO +from PIL import Image +import base64 + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from datasets import Dataset + +from ais_bench.benchmark.datasets.g_edit import ( + GEditDataset, + GEditSCJDGDataset, + GEditPQJDGDataset, + GEditEvaluator +) + + +class TestGEditEvaluator: + def test_score(self): + """测试score方法""" + evaluator = GEditEvaluator() + predictions = ["pred1", "pred2", "pred3"] + references = ["ref1", "ref2", "ref3"] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert "details" in result + assert result["accuracy"] == 100.0 + assert len(result["details"]) == 3 + + def test_score_empty_predictions(self): + """测试score方法,空predictions""" + evaluator = GEditEvaluator() + predictions = [] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + assert "accuracy" in result + assert "details" in result + assert result["accuracy"] == 0.0 + + +class TestGEditDataset: + def test_load_basic(self): + """测试基本load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ] * 2000) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path') + + assert isinstance(result, Dataset) + + def test_load_with_split(self): + """测试带数据集切分的load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ] * 10) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path', split_count=2, split_index=0) + + assert isinstance(result, Dataset) + + def test_load_use_raw(self): + """测试使用原始图片的load方法""" + mock_dataset = Dataset.from_list([ + { + "input_image": Image.new('RGB', (100, 100), color='red'), + "input_image_raw": Image.new('RGB', (100, 100), color='blue'), + "instruction": "test instruction" + } + ] * 2000) + + with patch('ais_bench.benchmark.datasets.g_edit.load_from_disk') as mock_load: + with patch('ais_bench.benchmark.datasets.g_edit.get_data_path') as mock_get_path: + mock_get_path.return_value = '/test/path' + mock_load.return_value = mock_dataset + + ds = GEditDataset.__new__(GEditDataset) + ds.task_state_manager = None + ds.logger = MagicMock() + ds.update_task_state = MagicMock() + + result = ds.load(path='/test/path', use_raw=True) + + assert isinstance(result, Dataset) + + +class TestGEditSCJDGDataset: + def test_get_dataset_class(self): + """测试_get_dataset_class方法""" + ds = GEditSCJDGDataset.__new__(GEditSCJDGDataset) + result = ds._get_dataset_class() + assert result == GEditDataset + + +class TestGEditPQJDGDataset: + def test_get_dataset_class(self): + """测试_get_dataset_class方法""" + ds = GEditPQJDGDataset.__new__(GEditPQJDGDataset) + result = ds._get_dataset_class() + assert result == GEditDataset diff --git a/tests/UT/datasets/utils/test_llm_judge.py b/tests/UT/datasets/utils/test_llm_judge.py new file mode 100644 index 00000000..8123a8c4 --- /dev/null +++ b/tests/UT/datasets/utils/test_llm_judge.py @@ -0,0 +1,144 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +import tempfile +import json + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.datasets.utils import llm_judge +from ais_bench.benchmark.datasets.utils.llm_judge import ( + get_a_or_b, + LLMJudgeDataset, + LLMJudgeCorrectEvaluator +) + + +class TestGetAOrB: + def test_get_a_or_b_with_a(self): + """测试提取A""" + result = get_a_or_b("The answer is A") + assert result == "A" + + def test_get_a_or_b_with_b(self): + """测试提取B""" + result = get_a_or_b("The answer is B") + assert result == "B" + + def test_get_a_or_b_no_match(self): + """测试没有匹配时返回B""" + result = get_a_or_b("The answer is C") + assert result == "B" + + def test_get_a_or_b_empty(self): + """测试空字符串""" + result = get_a_or_b("") + assert result == "B" + + def test_get_a_or_b_at_end(self): + """测试A在末尾""" + result = get_a_or_b("something A") + assert result == "A" + + +class TestLLMJudgeDataset: + def test_load_from_predictions_file_not_exists(self): + """测试文件不存在的情况""" + ds = LLMJudgeDataset.__new__(LLMJudgeDataset) + ds.logger = MagicMock() + + with patch('os.path.exists', return_value=False): + result = ds._load_from_predictions('/test/nonexistent.jsonl') + assert result == [] + + def test_load_from_predictions_success(self): + """测试成功加载predictions""" + mock_preds = [ + {"id": 1, "prediction": "pred1"}, + {"id": 0, "prediction": "pred2"} + ] + + ds = LLMJudgeDataset.__new__(LLMJudgeDataset) + ds.logger = MagicMock() + ds.task_state_manager = None + + with patch('os.path.exists', return_value=True): + # Use patch.object on the already imported module + with patch.object(llm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions('/test/predictions.jsonl') + + assert len(result) == 2 + assert result[0]["id"] == 0 + assert result[1]["id"] == 1 + + +class TestLLMJudgeCorrectEvaluator: + def test_score_all_correct(self): + """测试全部正确的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "A", "A"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == 100.0 + assert len(result["details"]) == 3 + for detail in result["details"]: + assert detail["correct"] is True + + def test_score_all_wrong(self): + """测试全部错误的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["B", "B", "B"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == 0.0 + for detail in result["details"]: + assert detail["correct"] is False + + def test_score_mixed(self): + """测试混合情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "B", "A"] + references = ["correct1", "correct2", "correct3"] + + result = evaluator.score(predictions, references) + + assert result["accuracy"] == pytest.approx(100 * 2 / 3, rel=1e-2) + assert result["details"][0]["correct"] is True + assert result["details"][1]["correct"] is False + assert result["details"][2]["correct"] is True + + def test_score_length_mismatch(self): + """测试长度不匹配的情况""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A", "A"] + references = ["correct", "correct", "correct"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_empty_predictions(self): + """测试空predictions""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = [] + references = ["correct"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_empty_references(self): + """测试空references""" + evaluator = LLMJudgeCorrectEvaluator() + predictions = ["A"] + references = [] + + result = evaluator.score(predictions, references) + + assert "error" in result diff --git a/tests/UT/datasets/utils/test_lmm_judge.py b/tests/UT/datasets/utils/test_lmm_judge.py new file mode 100644 index 00000000..4e9acec4 --- /dev/null +++ b/tests/UT/datasets/utils/test_lmm_judge.py @@ -0,0 +1,276 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, mock_open +import tempfile +import json +import base64 +from io import BytesIO +from PIL import Image + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.datasets.utils import lmm_judge +from ais_bench.benchmark.datasets.utils.lmm_judge import ( + get_lmm_point_list, + LMMImgJDGDataset, + ImgSCJDGDataset, + ImgPQJDGDataset, + LMMJudgeImageEditEvaluator +) + + +class TestGetLmmPointList: + def test_get_lmm_point_list_valid(self): + """测试提取有效的列表""" + result = get_lmm_point_list("The answer is [1, 2, 3]") + assert result == "[1, 2, 3]" + + def test_get_lmm_point_list_single(self): + """测试提取单个数字的列表""" + result = get_lmm_point_list("Result: [5]") + assert result == "[5]" + + def test_get_lmm_point_list_with_spaces(self): + """测试提取带空格的列表""" + result = get_lmm_point_list("Scores: [ 1 , 2 , 3 ]") + assert result == "[ 1 , 2 , 3 ]" + + def test_get_lmm_point_list_no_match(self): + """测试没有匹配时返回空列表""" + result = get_lmm_point_list("No list here") + assert result == "[]" + + def test_get_lmm_point_list_empty(self): + """测试空字符串""" + result = get_lmm_point_list("") + assert result == "[]" + + def test_get_lmm_point_list_multiple(self): + """测试多个数字的列表""" + result = get_lmm_point_list("Points: [10, 20, 30, 40]") + assert result == "[10, 20, 30, 40]" + + +class TestLMMImgJDGDataset: + def test_load_from_predictions_file_not_exists(self): + """测试文件不存在的情况""" + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = None + + with patch('os.path.exists', return_value=False): + result = ds._load_from_predictions('/test/nonexistent.jsonl') + assert result == [] + + def test_load_from_predictions_success(self): + """测试成功加载predictions""" + # 创建测试图片并转换为base64 + test_image = Image.new('RGB', (100, 100), color='red') + buffered = BytesIO() + test_image.save(buffered, format="PNG") + expected_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + + mock_preds = [ + {"id": 1, "prediction": "image1.png"}, + {"id": 0, "prediction": "image2.png"} + ] + + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = MagicMock() + ds.update_task_state = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + # 创建测试图片文件 + img_path1 = os.path.join(tmpdir, "image1.png") + img_path2 = os.path.join(tmpdir, "image2.png") + test_image.save(img_path1) + test_image.save(img_path2) + + # 创建predictions文件 + pred_file = os.path.join(tmpdir, "predictions.jsonl") + + with patch('os.path.exists', return_value=True): + with patch.object(lmm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions(pred_file) + + assert len(result) == 2 + assert result[0]["id"] == 0 + assert result[1]["id"] == 1 + # 验证图片被转换为base64 + assert result[0]["prediction"] == expected_base64 + + def test_load_from_predictions_with_nonexistent_image(self): + """测试图片文件不存在的情况""" + mock_preds = [ + {"id": 0, "prediction": "nonexistent.png"} + ] + + ds = LMMImgJDGDataset.__new__(LMMImgJDGDataset) + ds.task_state_manager = MagicMock() + ds.update_task_state = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + pred_file = os.path.join(tmpdir, "predictions.jsonl") + + with patch('os.path.exists') as mock_exists: + # 文件存在但图片不存在 + mock_exists.side_effect = lambda path: path.endswith('.jsonl') or path == pred_file + + with patch.object(lmm_judge, 'load_jsonl', return_value=mock_preds): + result = ds._load_from_predictions(pred_file) + + assert len(result) == 1 + assert result[0]["prediction"] == "nonexistent.png" + + +class TestImgSCJDGDataset: + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + + ds = ImgSCJDGDataset.__new__(ImgSCJDGDataset) + ds.logger = MagicMock() + + question = "What is in the image?" + org_image_url = "original_base64_string" + pred_image_url = "prediction_base64_string" + + dataset_item = { + "content": AIS_TEXT_START + question + AIS_CONTENT_TAG + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG + } + pred_item = {"prediction": pred_image_url} + + ds._modify_dataset_item(dataset_item, pred_item) + + # 验证content被正确修改 + assert AIS_TEXT_START + question + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + pred_image_url + AIS_CONTENT_TAG in dataset_item["content"] + + +class TestImgPQJDGDataset: + def test_modify_dataset_item(self): + """测试_modify_dataset_item方法""" + from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + + ds = ImgPQJDGDataset.__new__(ImgPQJDGDataset) + ds.logger = MagicMock() + + question = "Describe the image?" + org_image_url = "original_base64_string" + pred_image_url = "prediction_base64_string" + + dataset_item = { + "content": AIS_TEXT_START + question + AIS_CONTENT_TAG + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG + } + pred_item = {"prediction": pred_image_url} + + ds._modify_dataset_item(dataset_item, pred_item) + + # 验证content被正确修改(PQ版本不包含原始图片) + assert AIS_TEXT_START + question + AIS_CONTENT_TAG in dataset_item["content"] + assert AIS_IMAGE_START + pred_image_url + AIS_CONTENT_TAG in dataset_item["content"] + # PQ版本不应该包含原始图片URL + assert org_image_url not in dataset_item["content"] + + +class TestLMMJudgeImageEditEvaluator: + def test_init_default_metric(self): + """测试默认metric初始化""" + evaluator = LMMJudgeImageEditEvaluator() + assert evaluator.metric == "SC" + assert evaluator.point_key_list == ["editing success", "over editing"] + + def test_init_pq_metric(self): + """测试PQ metric初始化""" + evaluator = LMMJudgeImageEditEvaluator(metric="PQ") + assert evaluator.metric == "PQ" + assert evaluator.point_key_list == ["naturalness", "artifacts"] + + def test_score_success(self): + """测试score方法成功情况""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + predictions = ["[5, 4]", "[3, 2]", "[4, 5]"] + references = ["ref1", "ref2", "ref3"] + + result = evaluator.score(predictions, references) + + assert "SC" in result + assert "details" in result + assert len(result["details"]) == 3 + # min(5,4)=4, min(3,2)=2, min(4,5)=4, average = (4+2+4)/3 = 3.33 + assert result["SC"] == pytest.approx(10/3, rel=1e-2) + + def test_score_pq_metric(self): + """测试PQ metric的score方法""" + evaluator = LMMJudgeImageEditEvaluator(metric="PQ") + predictions = ["[4, 5]", "[3, 3]"] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "PQ" in result + assert len(result["details"]) == 2 + # min(4,5)=4, min(3,3)=3, average = (4+3)/2 = 3.5 + assert result["PQ"] == pytest.approx(3.5, rel=1e-2) + + def test_score_length_mismatch(self): + """测试长度不匹配的情况""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = ["[1, 2]"] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_non_string_predictions(self): + """测试predictions不是字符串的情况""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = [[1, 2], [3, 4]] + references = ["ref1", "ref2"] + + result = evaluator.score(predictions, references) + + assert "error" in result + + def test_score_invalid_prediction_format(self): + """测试prediction格式错误的情况""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + # SC需要2个分数,但这里给出3个 + predictions = ["[1, 2, 3]"] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + assert "details" in result + assert result["details"][0]["eval_success"] is False + assert "failed reason" in result["details"][0] + + def test_score_empty_predictions(self): + """测试空predictions""" + evaluator = LMMJudgeImageEditEvaluator() + predictions = [] + references = [] + + result = evaluator.score(predictions, references) + + assert "SC" in result + assert result["SC"] == 0.0 + assert "details" in result + + def test_score_detail_structure(self): + """测试details结构正确""" + evaluator = LMMJudgeImageEditEvaluator(metric="SC") + predictions = ["[5, 4]"] + references = ["ref1"] + + result = evaluator.score(predictions, references) + + detail = result["details"][0] + assert detail["eval_success"] is True + assert "pred" in detail + assert detail["pred"]["editing success"] == 5 + assert detail["pred"]["over editing"] == 4 + assert detail["org_uuid"] == "ref1" diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index e70f59ab..1f294802 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -1,7 +1,10 @@ import pytest import numpy as np import asyncio -from ais_bench.benchmark.models.output import Output, RequestOutput +import tempfile +import os +from PIL import Image +from ais_bench.benchmark.models.output import Output, RequestOutput, FunctionCallOutput, LMMOutput import time @@ -13,7 +16,6 @@ def get_metrics(self) -> dict: def test_output_initialization(): """测试Output类的初始化功能""" - # 测试默认参数初始化 output = ConcreteOutput() assert output.perf_mode is False assert output.success is False @@ -29,7 +31,6 @@ def test_output_initialization(): assert output.uuid == "" assert output.turn_id == 0 - # 测试perf_mode=True初始化 output_perf = ConcreteOutput(perf_mode=True) assert output_perf.perf_mode is True @@ -40,7 +41,13 @@ def test_concate_reasoning_content(): # 测试reasoning_content和content都不为空的情况 result1 = output._concate_reasoning_content("content", "reasoning") - assert result1 == "reasoningcontent" + # 验证结果包含reasoning和content,且reasoning在前 + assert "reasoning" in result1 + assert "content" in result1 + assert result1.startswith("reasoning") + assert result1.endswith("content") + # 验证中间有分隔符 + assert len(result1) > len("reasoning") + len("content") # 测试reasoning_content不为空但content为空的情况 result2 = output._concate_reasoning_content("", "reasoning") @@ -59,7 +66,6 @@ def test_get_prediction(): """测试get_prediction方法的不同分支""" output = ConcreteOutput() - # 测试reasoning_content为空的情况 output.content = "test content" output.reasoning_content = "" assert output.get_prediction() == "test content" @@ -67,16 +73,24 @@ def test_get_prediction(): # 测试content和reasoning_content都是列表的情况 output.content = ["content1", "content2"] output.reasoning_content = ["reasoning1", "reasoning2"] - assert output.get_prediction() == ["reasoning1content1", "reasoning2content2"] + result = output.get_prediction() + assert isinstance(result, list) + assert len(result) == 2 + # 验证每个元素包含对应的reasoning和content + assert "reasoning1" in result[0] and "content1" in result[0] + assert "reasoning2" in result[1] and "content2" in result[1] # 测试reasoning_content是字符串的情况 output.content = "content string" output.reasoning_content = "reasoning string" - assert output.get_prediction() == "reasoning stringcontent string" + result = output.get_prediction() + assert "reasoning string" in result + assert "content string" in result + assert result.startswith("reasoning string") # 测试其他类型的情况(应该返回原始content) output.content = "test content" - output.reasoning_content = None # 非字符串非列表类型 + output.reasoning_content = None assert output.get_prediction() == "test content" @@ -92,7 +106,6 @@ def test_to_dict(): assert result["content"] == "test" assert result["uuid"] == "test_uuid" assert result["turn_id"] == 1 - # 确保所有属性都被包含 assert "perf_mode" in result assert "success" in result assert "error_info" in result @@ -107,18 +120,15 @@ def test_to_dict(): def test_record_time_point(): """测试record_time_point异步方法""" - # 测试perf_mode=False时不记录时间点 output = ConcreteOutput(perf_mode=False) asyncio.run(output.record_time_point()) assert len(output.time_points) == 0 - # 测试perf_mode=True时记录时间点 output_perf = ConcreteOutput(perf_mode=True) asyncio.run(output_perf.record_time_point()) assert len(output_perf.time_points) == 1 assert isinstance(output_perf.time_points[0], float) - # 测试多次记录时间点 asyncio.run(output_perf.record_time_point()) assert len(output_perf.time_points) == 2 @@ -136,7 +146,6 @@ def test_clear_time_points(): def test_request_output_get_metrics(): """测试RequestOutput类的get_metrics方法的不同分支""" - # 测试success=False的情况 output = RequestOutput() output.success = False output.error_info = "test error" @@ -148,25 +157,20 @@ def test_request_output_get_metrics(): assert isinstance(metrics, dict) assert metrics["success"] is False assert metrics["error_info"] == "test error" - # 确保clean_result函数被正确应用 assert "content" not in metrics assert "reasoning_content" not in metrics assert "perf_mode" not in metrics - # 确保prediction字段被设置 assert "prediction" in metrics - # 测试success=True但time_points.size <= 1的情况 output = RequestOutput() output.success = True output.time_points = [time.perf_counter()] metrics = output.get_metrics() - assert metrics["success"] is False # 应该被设置为False + assert metrics["success"] is False assert metrics["error_info"] == "chunk size is less than 2" - # 确保time_points被转换为numpy数组 assert isinstance(metrics["time_points"], np.ndarray) - # 测试success=True且time_points.size > 1的情况 output = RequestOutput() output.success = True output.time_points = [time.perf_counter() - 1, time.perf_counter()] @@ -183,7 +187,6 @@ def test_request_output_get_metrics(): def test_request_output_edge_cases(): """测试RequestOutput类的边缘情况""" - # 测试空的time_points列表 output = RequestOutput() output.success = True output.time_points = [] @@ -192,7 +195,6 @@ def test_request_output_edge_cases(): assert metrics["success"] is False assert metrics["error_info"] == "chunk size is less than 2" - # 测试包含其他属性的情况 output = RequestOutput() output.success = False output.uuid = "test_uuid_123" @@ -202,4 +204,186 @@ def test_request_output_edge_cases(): metrics = output.get_metrics() assert metrics["uuid"] == "test_uuid_123" assert metrics["turn_id"] == 5 - assert metrics["extra_perf_data"] == {"test_key": "test_value"} \ No newline at end of file + assert metrics["extra_perf_data"] == {"test_key": "test_value"} + + +class TestFunctionCallOutput: + def test_init(self): + """测试FunctionCallOutput初始化""" + output = FunctionCallOutput() + assert output.perf_mode is False + assert isinstance(output.inference_log, list) + assert isinstance(output.tool_calls, list) + + def test_init_perf_mode(self): + """测试FunctionCallOutput性能模式初始化""" + output = FunctionCallOutput(perf_mode=True) + assert output.perf_mode is True + + def test_update_extra_details_data_from_text_response(self): + """测试update_extra_details_data_from_text_response方法""" + output = FunctionCallOutput() + text_response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "test content" + } + } + ] + } + + output.update_extra_details_data_from_text_response(text_response) + + assert "message" in output.extra_details_data + assert output.extra_details_data["message"]["role"] == "assistant" + + def test_update_extra_details_data_empty_choices(self): + """测试空choices的情况""" + output = FunctionCallOutput() + text_response = {"choices": []} + + output.update_extra_details_data_from_text_response(text_response) + + assert output.extra_details_data == {} + + def test_update_extra_details_data_no_choices(self): + """测试没有choices的情况""" + output = FunctionCallOutput() + text_response = {} + + output.update_extra_details_data_from_text_response(text_response) + + assert output.extra_details_data == {} + + def test_get_metrics_inherited(self): + """测试get_metrics方法(继承自Output抽象基类,返回None)""" + output = FunctionCallOutput() + output.success = True + output.uuid = "test_uuid" + output.tool_calls = [{"function": "test_func"}] + + metrics = output.get_metrics() + assert metrics is None + + +class TestLMMOutput: + def test_init(self): + """测试LMMOutput初始化""" + output = LMMOutput() + assert output.perf_mode is False + assert isinstance(output.content, list) + assert len(output.content) == 1 + assert output.content[0] == "" + + def test_init_perf_mode(self): + """测试LMMOutput性能模式初始化""" + output = LMMOutput(perf_mode=True) + assert output.perf_mode is True + + def test_handle_text(self): + """测试_handle_text方法""" + output = LMMOutput() + output.content = ["text content"] + + result = output._handle_text("/test/dir", 0) + assert result == "text content" + + def test_handle_image(self): + """测试_handle_image方法""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output._handle_image(tmpdir, 0) + + assert "image_test_uuid_0.png" in result + expected_path = os.path.join(tmpdir, "image_test_uuid_0.png") + assert os.path.exists(expected_path) + + def test_handle_image_overwrite(self): + """测试_handle_image方法覆盖已存在的文件""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result1 = output._handle_image(tmpdir, 0) + result2 = output._handle_image(tmpdir, 0) + + assert result1 == result2 + + def test_get_prediction_single_text(self): + """测试get_prediction方法,单个文本""" + output = LMMOutput() + output.uuid = "test_uuid" + output.content = ["text content"] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert result == "text content" + + def test_get_prediction_single_image(self): + """测试get_prediction方法,单个图片""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert "image_test_uuid_0.png" in result + + def test_get_prediction_multiple_items(self): + """测试get_prediction方法,多个项目""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image, "text content"] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert isinstance(result, list) + assert len(result) == 2 + + def test_get_metrics_inherited(self): + """测试get_metrics方法(继承自Output抽象基类,返回None)""" + output = LMMOutput() + output.success = True + output.uuid = "test_uuid" + output.content = ["test"] + + metrics = output.get_metrics() + assert metrics is None + + +def test_output_update_extra_perf_data_from_stream_response(): + """测试update_extra_perf_data_from_stream_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_perf_data_from_stream_response({"test": "data"}) + assert output.extra_perf_data == {} + + +def test_output_update_extra_perf_data_from_text_response(): + """测试update_extra_perf_data_from_text_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_perf_data_from_text_response({"test": "data"}) + assert output.extra_perf_data == {} + + +def test_output_update_extra_details_data_from_stream_response(): + """测试update_extra_details_data_from_stream_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_details_data_from_stream_response({"test": "data"}) + assert output.extra_details_data == {} + + +def test_output_update_extra_details_data_from_text_response(): + """测试update_extra_details_data_from_text_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_details_data_from_text_response({"test": "data"}) + assert output.extra_details_data == {} diff --git a/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py new file mode 100644 index 00000000..f7b8d439 --- /dev/null +++ b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py @@ -0,0 +1,302 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock +import tempfile +import uuid +from pathlib import Path + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../../'))) + +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.lmm_gen_inferencer_output_handler import ( + LMMGenInferencerOutputHandler +) +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.utils.logging.exceptions import AISBenchRuntimeError + + +class TestLMMGenInferencerOutputHandler: + def setup_method(self): + """设置测试环境""" + self.handler = LMMGenInferencerOutputHandler() + self.handler.output_path = None + + def test_init(self): + """测试初始化""" + handler = LMMGenInferencerOutputHandler(perf_mode=True, save_every=50) + handler.output_path = None + assert handler.perf_mode is True + assert handler.save_every == 50 + + def test_set_output_path(self): + """测试set_output_path方法""" + self.handler.set_output_path('/test/output/path') + assert self.handler.output_path == '/test/output/path' + + def test_get_prediction_result_with_string_output(self): + """测试get_prediction_result方法,字符串输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + result = self.handler.get_prediction_result( + output="test prediction", + gold="test gold", + input="test input", + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["prediction"] == "test prediction" + assert result["gold"] == "test gold" + assert result["origin_prompt"] == "test input" + assert len(result["uuid"]) == 32 + + def test_get_prediction_result_with_lmm_output(self): + """测试get_prediction_result方法,LMMOutput输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid_123" + output.success = True + output.content = ["text content"] + + result = self.handler.get_prediction_result( + output=output, + gold="test gold", + input="test input", + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["uuid"] == "test_uuid_123" + assert result["prediction"] == "text content" + assert result["gold"] == "test gold" + + def test_get_prediction_result_with_image_output(self): + """测试get_prediction_result方法,图片输出""" + from PIL import Image + + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid_456" + output.success = True + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert "image_test_uuid_456_0.png" in result["prediction"] + + def test_get_prediction_result_creates_output_dir(self): + """测试get_prediction_result方法创建输出目录""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, 'subdir') + self.handler.set_output_path(output_path) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert os.path.exists(os.path.join(output_path, "test_dataset_out_file")) + + def test_get_prediction_result_with_long_base64_input(self): + """测试get_prediction_result方法,处理长base64输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + long_base64 = "a" * 300 + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": { + "url": long_base64 + } + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert "..." in input_data[0]["prompt"][0]["image_url"]["url"] + + def test_get_prediction_result_with_dict_input(self): + """测试get_prediction_result方法,字典类型输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "text", + "text": "test text" + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_empty_input(self): + """测试get_prediction_result方法,空输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=None, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["origin_prompt"] == "" + + def test_get_prediction_result_without_gold(self): + """测试get_prediction_result方法,没有gold""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + result = self.handler.get_prediction_result( + output="test prediction", + input="test input", + data_abbr="test_dataset" + ) + + assert "gold" not in result + + def test_get_prediction_result_with_failed_output(self): + """测试get_prediction_result方法,失败的输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = False + output.error_info = "test error" + output.content = [""] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert result["success"] is False + + def test_get_prediction_result_with_non_dict_prompt_items(self): + """测试get_prediction_result方法,非字典类型的prompt项""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + "string_item", + 123, + None + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_non_dict_image_url(self): + """测试get_prediction_result方法,非字典类型的image_url""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": "not_a_dict" + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_non_string_url(self): + """测试get_prediction_result方法,非字符串类型的URL""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": { + "url": 123 + } + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True diff --git a/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py new file mode 100644 index 00000000..99c9970d --- /dev/null +++ b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py @@ -0,0 +1,177 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, call +import uuid + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.models.output import LMMOutput + + +class TestLMMGenInferencer: + def setup_method(self): + """设置测试环境""" + self.mock_model_cfg = { + 'type': 'MockModel', + 'abbr': 'test_model' + } + + def test_init(self): + """测试LMMGenInferencer初始化""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=[], + batch_size=1, + mode='infer' + ) + + assert inferencer.model_cfg == self.mock_model_cfg + assert inferencer.batch_size == 1 + mock_handler.assert_called_once() + + def test_init_with_custom_params(self): + """测试使用自定义参数初始化""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=['stop1', 'stop2'], + batch_size=4, + mode='perf', + gen_field_replace_token='', + output_json_filepath='/test/output', + save_every=50 + ) + + assert inferencer.stopping_criteria == ['stop1', 'stop2'] + assert inferencer.batch_size == 4 + assert inferencer.perf_mode is True + + def test_inference(self): + """测试inference方法""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + mock_retriever = MagicMock() + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer.GenInferencer.inference') as mock_super_inference: + mock_super_inference.return_value = [] + + result = inferencer.inference(mock_retriever, '/test/output.jsonl') + + mock_handler_instance.set_output_path.assert_called_once_with('/test/output.jsonl') + mock_super_inference.assert_called_once() + + def test_batch_inference(self): + """测试batch_inference方法""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0, 1], + 'prompt': ['prompt1', 'prompt2'], + 'data_abbr': ['dataset1', 'dataset1'], + 'gold': ['gold1', 'gold2'], + 'extra_param': 'value' + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + assert mock_handler_instance.report_cache_info_sync.call_count == 2 + + def test_batch_inference_without_gold(self): + """测试batch_inference方法,没有gold字段""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + mock_handler_instance.report_cache_info_sync.assert_called_once() + + def test_batch_inference_uuid_generation(self): + """测试batch_inference方法中UUID生成""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + call_args = mock_handler_instance.report_cache_info_sync.call_args + output = call_args[0][2] + assert isinstance(output, LMMOutput) + assert len(output.uuid) == 32