From 926adb3a6c9eff695e8ad16408c4f8508b21dae0 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Thu, 5 Mar 2026 08:55:27 +0800 Subject: [PATCH] ut part 2 --- tests/UT/models/test_output.py | 228 +++++++++++-- .../test_lmm_gen_inferencer_output_handler.py | 302 ++++++++++++++++++ .../test_icl_lmm_gen_inferencer.py | 177 ++++++++++ 3 files changed, 685 insertions(+), 22 deletions(-) create mode 100644 tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py create mode 100644 tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py diff --git a/tests/UT/models/test_output.py b/tests/UT/models/test_output.py index e70f59ab..1f294802 100644 --- a/tests/UT/models/test_output.py +++ b/tests/UT/models/test_output.py @@ -1,7 +1,10 @@ import pytest import numpy as np import asyncio -from ais_bench.benchmark.models.output import Output, RequestOutput +import tempfile +import os +from PIL import Image +from ais_bench.benchmark.models.output import Output, RequestOutput, FunctionCallOutput, LMMOutput import time @@ -13,7 +16,6 @@ def get_metrics(self) -> dict: def test_output_initialization(): """测试Output类的初始化功能""" - # 测试默认参数初始化 output = ConcreteOutput() assert output.perf_mode is False assert output.success is False @@ -29,7 +31,6 @@ def test_output_initialization(): assert output.uuid == "" assert output.turn_id == 0 - # 测试perf_mode=True初始化 output_perf = ConcreteOutput(perf_mode=True) assert output_perf.perf_mode is True @@ -40,7 +41,13 @@ def test_concate_reasoning_content(): # 测试reasoning_content和content都不为空的情况 result1 = output._concate_reasoning_content("content", "reasoning") - assert result1 == "reasoningcontent" + # 验证结果包含reasoning和content,且reasoning在前 + assert "reasoning" in result1 + assert "content" in result1 + assert result1.startswith("reasoning") + assert result1.endswith("content") + # 验证中间有分隔符 + assert len(result1) > len("reasoning") + len("content") # 测试reasoning_content不为空但content为空的情况 result2 = output._concate_reasoning_content("", "reasoning") @@ -59,7 +66,6 @@ def test_get_prediction(): """测试get_prediction方法的不同分支""" output = ConcreteOutput() - # 测试reasoning_content为空的情况 output.content = "test content" output.reasoning_content = "" assert output.get_prediction() == "test content" @@ -67,16 +73,24 @@ def test_get_prediction(): # 测试content和reasoning_content都是列表的情况 output.content = ["content1", "content2"] output.reasoning_content = ["reasoning1", "reasoning2"] - assert output.get_prediction() == ["reasoning1content1", "reasoning2content2"] + result = output.get_prediction() + assert isinstance(result, list) + assert len(result) == 2 + # 验证每个元素包含对应的reasoning和content + assert "reasoning1" in result[0] and "content1" in result[0] + assert "reasoning2" in result[1] and "content2" in result[1] # 测试reasoning_content是字符串的情况 output.content = "content string" output.reasoning_content = "reasoning string" - assert output.get_prediction() == "reasoning stringcontent string" + result = output.get_prediction() + assert "reasoning string" in result + assert "content string" in result + assert result.startswith("reasoning string") # 测试其他类型的情况(应该返回原始content) output.content = "test content" - output.reasoning_content = None # 非字符串非列表类型 + output.reasoning_content = None assert output.get_prediction() == "test content" @@ -92,7 +106,6 @@ def test_to_dict(): assert result["content"] == "test" assert result["uuid"] == "test_uuid" assert result["turn_id"] == 1 - # 确保所有属性都被包含 assert "perf_mode" in result assert "success" in result assert "error_info" in result @@ -107,18 +120,15 @@ def test_to_dict(): def test_record_time_point(): """测试record_time_point异步方法""" - # 测试perf_mode=False时不记录时间点 output = ConcreteOutput(perf_mode=False) asyncio.run(output.record_time_point()) assert len(output.time_points) == 0 - # 测试perf_mode=True时记录时间点 output_perf = ConcreteOutput(perf_mode=True) asyncio.run(output_perf.record_time_point()) assert len(output_perf.time_points) == 1 assert isinstance(output_perf.time_points[0], float) - # 测试多次记录时间点 asyncio.run(output_perf.record_time_point()) assert len(output_perf.time_points) == 2 @@ -136,7 +146,6 @@ def test_clear_time_points(): def test_request_output_get_metrics(): """测试RequestOutput类的get_metrics方法的不同分支""" - # 测试success=False的情况 output = RequestOutput() output.success = False output.error_info = "test error" @@ -148,25 +157,20 @@ def test_request_output_get_metrics(): assert isinstance(metrics, dict) assert metrics["success"] is False assert metrics["error_info"] == "test error" - # 确保clean_result函数被正确应用 assert "content" not in metrics assert "reasoning_content" not in metrics assert "perf_mode" not in metrics - # 确保prediction字段被设置 assert "prediction" in metrics - # 测试success=True但time_points.size <= 1的情况 output = RequestOutput() output.success = True output.time_points = [time.perf_counter()] metrics = output.get_metrics() - assert metrics["success"] is False # 应该被设置为False + assert metrics["success"] is False assert metrics["error_info"] == "chunk size is less than 2" - # 确保time_points被转换为numpy数组 assert isinstance(metrics["time_points"], np.ndarray) - # 测试success=True且time_points.size > 1的情况 output = RequestOutput() output.success = True output.time_points = [time.perf_counter() - 1, time.perf_counter()] @@ -183,7 +187,6 @@ def test_request_output_get_metrics(): def test_request_output_edge_cases(): """测试RequestOutput类的边缘情况""" - # 测试空的time_points列表 output = RequestOutput() output.success = True output.time_points = [] @@ -192,7 +195,6 @@ def test_request_output_edge_cases(): assert metrics["success"] is False assert metrics["error_info"] == "chunk size is less than 2" - # 测试包含其他属性的情况 output = RequestOutput() output.success = False output.uuid = "test_uuid_123" @@ -202,4 +204,186 @@ def test_request_output_edge_cases(): metrics = output.get_metrics() assert metrics["uuid"] == "test_uuid_123" assert metrics["turn_id"] == 5 - assert metrics["extra_perf_data"] == {"test_key": "test_value"} \ No newline at end of file + assert metrics["extra_perf_data"] == {"test_key": "test_value"} + + +class TestFunctionCallOutput: + def test_init(self): + """测试FunctionCallOutput初始化""" + output = FunctionCallOutput() + assert output.perf_mode is False + assert isinstance(output.inference_log, list) + assert isinstance(output.tool_calls, list) + + def test_init_perf_mode(self): + """测试FunctionCallOutput性能模式初始化""" + output = FunctionCallOutput(perf_mode=True) + assert output.perf_mode is True + + def test_update_extra_details_data_from_text_response(self): + """测试update_extra_details_data_from_text_response方法""" + output = FunctionCallOutput() + text_response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "test content" + } + } + ] + } + + output.update_extra_details_data_from_text_response(text_response) + + assert "message" in output.extra_details_data + assert output.extra_details_data["message"]["role"] == "assistant" + + def test_update_extra_details_data_empty_choices(self): + """测试空choices的情况""" + output = FunctionCallOutput() + text_response = {"choices": []} + + output.update_extra_details_data_from_text_response(text_response) + + assert output.extra_details_data == {} + + def test_update_extra_details_data_no_choices(self): + """测试没有choices的情况""" + output = FunctionCallOutput() + text_response = {} + + output.update_extra_details_data_from_text_response(text_response) + + assert output.extra_details_data == {} + + def test_get_metrics_inherited(self): + """测试get_metrics方法(继承自Output抽象基类,返回None)""" + output = FunctionCallOutput() + output.success = True + output.uuid = "test_uuid" + output.tool_calls = [{"function": "test_func"}] + + metrics = output.get_metrics() + assert metrics is None + + +class TestLMMOutput: + def test_init(self): + """测试LMMOutput初始化""" + output = LMMOutput() + assert output.perf_mode is False + assert isinstance(output.content, list) + assert len(output.content) == 1 + assert output.content[0] == "" + + def test_init_perf_mode(self): + """测试LMMOutput性能模式初始化""" + output = LMMOutput(perf_mode=True) + assert output.perf_mode is True + + def test_handle_text(self): + """测试_handle_text方法""" + output = LMMOutput() + output.content = ["text content"] + + result = output._handle_text("/test/dir", 0) + assert result == "text content" + + def test_handle_image(self): + """测试_handle_image方法""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output._handle_image(tmpdir, 0) + + assert "image_test_uuid_0.png" in result + expected_path = os.path.join(tmpdir, "image_test_uuid_0.png") + assert os.path.exists(expected_path) + + def test_handle_image_overwrite(self): + """测试_handle_image方法覆盖已存在的文件""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result1 = output._handle_image(tmpdir, 0) + result2 = output._handle_image(tmpdir, 0) + + assert result1 == result2 + + def test_get_prediction_single_text(self): + """测试get_prediction方法,单个文本""" + output = LMMOutput() + output.uuid = "test_uuid" + output.content = ["text content"] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert result == "text content" + + def test_get_prediction_single_image(self): + """测试get_prediction方法,单个图片""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert "image_test_uuid_0.png" in result + + def test_get_prediction_multiple_items(self): + """测试get_prediction方法,多个项目""" + output = LMMOutput() + output.uuid = "test_uuid" + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image, "text content"] + + with tempfile.TemporaryDirectory() as tmpdir: + result = output.get_prediction(tmpdir) + assert isinstance(result, list) + assert len(result) == 2 + + def test_get_metrics_inherited(self): + """测试get_metrics方法(继承自Output抽象基类,返回None)""" + output = LMMOutput() + output.success = True + output.uuid = "test_uuid" + output.content = ["test"] + + metrics = output.get_metrics() + assert metrics is None + + +def test_output_update_extra_perf_data_from_stream_response(): + """测试update_extra_perf_data_from_stream_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_perf_data_from_stream_response({"test": "data"}) + assert output.extra_perf_data == {} + + +def test_output_update_extra_perf_data_from_text_response(): + """测试update_extra_perf_data_from_text_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_perf_data_from_text_response({"test": "data"}) + assert output.extra_perf_data == {} + + +def test_output_update_extra_details_data_from_stream_response(): + """测试update_extra_details_data_from_stream_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_details_data_from_stream_response({"test": "data"}) + assert output.extra_details_data == {} + + +def test_output_update_extra_details_data_from_text_response(): + """测试update_extra_details_data_from_text_response方法(默认实现)""" + output = ConcreteOutput() + output.update_extra_details_data_from_text_response({"test": "data"}) + assert output.extra_details_data == {} diff --git a/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py new file mode 100644 index 00000000..f7b8d439 --- /dev/null +++ b/tests/UT/openicl/icl_inferencer/output_handler/test_lmm_gen_inferencer_output_handler.py @@ -0,0 +1,302 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock +import tempfile +import uuid +from pathlib import Path + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../../'))) + +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.lmm_gen_inferencer_output_handler import ( + LMMGenInferencerOutputHandler +) +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.utils.logging.exceptions import AISBenchRuntimeError + + +class TestLMMGenInferencerOutputHandler: + def setup_method(self): + """设置测试环境""" + self.handler = LMMGenInferencerOutputHandler() + self.handler.output_path = None + + def test_init(self): + """测试初始化""" + handler = LMMGenInferencerOutputHandler(perf_mode=True, save_every=50) + handler.output_path = None + assert handler.perf_mode is True + assert handler.save_every == 50 + + def test_set_output_path(self): + """测试set_output_path方法""" + self.handler.set_output_path('/test/output/path') + assert self.handler.output_path == '/test/output/path' + + def test_get_prediction_result_with_string_output(self): + """测试get_prediction_result方法,字符串输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + result = self.handler.get_prediction_result( + output="test prediction", + gold="test gold", + input="test input", + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["prediction"] == "test prediction" + assert result["gold"] == "test gold" + assert result["origin_prompt"] == "test input" + assert len(result["uuid"]) == 32 + + def test_get_prediction_result_with_lmm_output(self): + """测试get_prediction_result方法,LMMOutput输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid_123" + output.success = True + output.content = ["text content"] + + result = self.handler.get_prediction_result( + output=output, + gold="test gold", + input="test input", + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["uuid"] == "test_uuid_123" + assert result["prediction"] == "text content" + assert result["gold"] == "test gold" + + def test_get_prediction_result_with_image_output(self): + """测试get_prediction_result方法,图片输出""" + from PIL import Image + + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid_456" + output.success = True + test_image = Image.new('RGB', (100, 100), color='red') + output.content = [test_image] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert "image_test_uuid_456_0.png" in result["prediction"] + + def test_get_prediction_result_creates_output_dir(self): + """测试get_prediction_result方法创建输出目录""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, 'subdir') + self.handler.set_output_path(output_path) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert os.path.exists(os.path.join(output_path, "test_dataset_out_file")) + + def test_get_prediction_result_with_long_base64_input(self): + """测试get_prediction_result方法,处理长base64输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + long_base64 = "a" * 300 + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": { + "url": long_base64 + } + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert "..." in input_data[0]["prompt"][0]["image_url"]["url"] + + def test_get_prediction_result_with_dict_input(self): + """测试get_prediction_result方法,字典类型输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "text", + "text": "test text" + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_empty_input(self): + """测试get_prediction_result方法,空输入""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=None, + data_abbr="test_dataset" + ) + + assert result["success"] is True + assert result["origin_prompt"] == "" + + def test_get_prediction_result_without_gold(self): + """测试get_prediction_result方法,没有gold""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + result = self.handler.get_prediction_result( + output="test prediction", + input="test input", + data_abbr="test_dataset" + ) + + assert "gold" not in result + + def test_get_prediction_result_with_failed_output(self): + """测试get_prediction_result方法,失败的输出""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = False + output.error_info = "test error" + output.content = [""] + + result = self.handler.get_prediction_result( + output=output, + data_abbr="test_dataset" + ) + + assert result["success"] is False + + def test_get_prediction_result_with_non_dict_prompt_items(self): + """测试get_prediction_result方法,非字典类型的prompt项""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + "string_item", + 123, + None + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_non_dict_image_url(self): + """测试get_prediction_result方法,非字典类型的image_url""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": "not_a_dict" + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True + + def test_get_prediction_result_with_non_string_url(self): + """测试get_prediction_result方法,非字符串类型的URL""" + with tempfile.TemporaryDirectory() as tmpdir: + self.handler.set_output_path(tmpdir) + + input_data = [{ + "prompt": [ + { + "type": "image_url", + "image_url": { + "url": 123 + } + } + ] + }] + + output = LMMOutput() + output.uuid = "test_uuid" + output.success = True + output.content = ["text"] + + result = self.handler.get_prediction_result( + output=output, + input=input_data, + data_abbr="test_dataset" + ) + + assert result["success"] is True diff --git a/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py new file mode 100644 index 00000000..99c9970d --- /dev/null +++ b/tests/UT/openicl/icl_inferencer/test_icl_lmm_gen_inferencer.py @@ -0,0 +1,177 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock, call +import uuid + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) + +from ais_bench.benchmark.models.output import LMMOutput + + +class TestLMMGenInferencer: + def setup_method(self): + """设置测试环境""" + self.mock_model_cfg = { + 'type': 'MockModel', + 'abbr': 'test_model' + } + + def test_init(self): + """测试LMMGenInferencer初始化""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=[], + batch_size=1, + mode='infer' + ) + + assert inferencer.model_cfg == self.mock_model_cfg + assert inferencer.batch_size == 1 + mock_handler.assert_called_once() + + def test_init_with_custom_params(self): + """测试使用自定义参数初始化""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + stopping_criteria=['stop1', 'stop2'], + batch_size=4, + mode='perf', + gen_field_replace_token='', + output_json_filepath='/test/output', + save_every=50 + ) + + assert inferencer.stopping_criteria == ['stop1', 'stop2'] + assert inferencer.batch_size == 4 + assert inferencer.perf_mode is True + + def test_inference(self): + """测试inference方法""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + mock_retriever = MagicMock() + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer.GenInferencer.inference') as mock_super_inference: + mock_super_inference.return_value = [] + + result = inferencer.inference(mock_retriever, '/test/output.jsonl') + + mock_handler_instance.set_output_path.assert_called_once_with('/test/output.jsonl') + mock_super_inference.assert_called_once() + + def test_batch_inference(self): + """测试batch_inference方法""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0, 1], + 'prompt': ['prompt1', 'prompt2'], + 'data_abbr': ['dataset1', 'dataset1'], + 'gold': ['gold1', 'gold2'], + 'extra_param': 'value' + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + assert mock_handler_instance.report_cache_info_sync.call_count == 2 + + def test_batch_inference_without_gold(self): + """测试batch_inference方法,没有gold字段""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + inferencer.model.generate.assert_called_once() + mock_handler_instance.report_cache_info_sync.assert_called_once() + + def test_batch_inference_uuid_generation(self): + """测试batch_inference方法中UUID生成""" + mock_model = MagicMock() + + with patch('ais_bench.benchmark.utils.config.build.MODELS.build', return_value=mock_model): + from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer + + with patch('ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer.LMMGenInferencerOutputHandler') as mock_handler: + mock_handler_instance = MagicMock() + mock_handler.return_value = mock_handler_instance + + inferencer = LMMGenInferencer( + model_cfg=self.mock_model_cfg, + batch_size=1 + ) + + inferencer.model = MagicMock() + + datum = { + 'index': [0], + 'prompt': ['prompt1'], + 'data_abbr': ['dataset1'] + } + + inferencer.batch_inference(datum) + + call_args = mock_handler_instance.report_cache_info_sync.call_args + output = call_args[0][2] + assert isinstance(output, LMMOutput) + assert len(output.uuid) == 32