Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 206 additions & 22 deletions tests/UT/models/test_output.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest
import numpy as np
import asyncio
from ais_bench.benchmark.models.output import Output, RequestOutput
import tempfile
import os
from PIL import Image
from ais_bench.benchmark.models.output import Output, RequestOutput, FunctionCallOutput, LMMOutput
import time


Expand All @@ -13,7 +16,6 @@ def get_metrics(self) -> dict:

def test_output_initialization():
"""测试Output类的初始化功能"""
# 测试默认参数初始化
output = ConcreteOutput()
assert output.perf_mode is False
assert output.success is False
Expand All @@ -29,7 +31,6 @@ def test_output_initialization():
assert output.uuid == ""
assert output.turn_id == 0

# 测试perf_mode=True初始化
output_perf = ConcreteOutput(perf_mode=True)
assert output_perf.perf_mode is True

Expand All @@ -40,7 +41,13 @@ def test_concate_reasoning_content():

# 测试reasoning_content和content都不为空的情况
result1 = output._concate_reasoning_content("content", "reasoning")
assert result1 == "reasoning</think>content"
# 验证结果包含reasoning和content,且reasoning在前
assert "reasoning" in result1
assert "content" in result1
assert result1.startswith("reasoning")
assert result1.endswith("content")
# 验证中间有分隔符
assert len(result1) > len("reasoning") + len("content")
Comment on lines +44 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

此处的测试断言被改得不够具体。原来的断言 assert result1 == "reasoning</think>content" 更好,因为它验证了确切的输出格式,包括分隔符 </think>。当前的断言过于通用,如果分隔符格式发生变化,可能无法捕获到回归问题。建议测试具体的预期输出,以确保代码的健壮性。

Suggested change
# 验证结果包含reasoning和content,且reasoning在前
assert "reasoning" in result1
assert "content" in result1
assert result1.startswith("reasoning")
assert result1.endswith("content")
# 验证中间有分隔符
assert len(result1) > len("reasoning") + len("content")
assert result1 == "reasoning</think>content"


# 测试reasoning_content不为空但content为空的情况
result2 = output._concate_reasoning_content("", "reasoning")
Expand All @@ -59,24 +66,31 @@ def test_get_prediction():
"""测试get_prediction方法的不同分支"""
output = ConcreteOutput()

# 测试reasoning_content为空的情况
output.content = "test content"
output.reasoning_content = ""
assert output.get_prediction() == "test content"

# 测试content和reasoning_content都是列表的情况
output.content = ["content1", "content2"]
output.reasoning_content = ["reasoning1", "reasoning2"]
assert output.get_prediction() == ["reasoning1</think>content1", "reasoning2</think>content2"]
result = output.get_prediction()
assert isinstance(result, list)
assert len(result) == 2
# 验证每个元素包含对应的reasoning和content
assert "reasoning1" in result[0] and "content1" in result[0]
assert "reasoning2" in result[1] and "content2" in result[1]
Comment on lines +76 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

与前一个评论类似,这里的断言也被削弱了。原来的测试 assert output.get_prediction() == ["reasoning1</think>content1", "reasoning2</think>content2"] 更为精确,能够确保列表中拼接后字符串的正确格式。当前的测试只检查子字符串是否存在,不够健壮。

Suggested change
result = output.get_prediction()
assert isinstance(result, list)
assert len(result) == 2
# 验证每个元素包含对应的reasoning和content
assert "reasoning1" in result[0] and "content1" in result[0]
assert "reasoning2" in result[1] and "content2" in result[1]
assert output.get_prediction() == ["reasoning1</think>content1", "reasoning2</think>content2"]


# 测试reasoning_content是字符串的情况
output.content = "content string"
output.reasoning_content = "reasoning string"
assert output.get_prediction() == "reasoning string</think>content string"
result = output.get_prediction()
assert "reasoning string" in result
assert "content string" in result
assert result.startswith("reasoning string")
Comment on lines +86 to +89
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

字符串情况的断言也被削弱了。最好是检查确切的预期字符串,以确保 _concate_reasoning_content 方法在 get_prediction 中按预期工作。

Suggested change
result = output.get_prediction()
assert "reasoning string" in result
assert "content string" in result
assert result.startswith("reasoning string")
assert output.get_prediction() == "reasoning string</think>content string"


# 测试其他类型的情况(应该返回原始content)
output.content = "test content"
output.reasoning_content = None # 非字符串非列表类型
output.reasoning_content = None
assert output.get_prediction() == "test content"


Expand All @@ -92,7 +106,6 @@ def test_to_dict():
assert result["content"] == "test"
assert result["uuid"] == "test_uuid"
assert result["turn_id"] == 1
# 确保所有属性都被包含
assert "perf_mode" in result
assert "success" in result
assert "error_info" in result
Expand All @@ -107,18 +120,15 @@ def test_to_dict():

def test_record_time_point():
    """Exercise the async record_time_point method in both perf modes."""
    # With perf_mode off, the test expects no timestamp to be stored.
    plain = ConcreteOutput(perf_mode=False)
    asyncio.run(plain.record_time_point())
    assert len(plain.time_points) == 0

    # With perf_mode on, the first call is expected to store one float.
    timed = ConcreteOutput(perf_mode=True)
    asyncio.run(timed.record_time_point())
    assert len(timed.time_points) == 1
    assert isinstance(timed.time_points[0], float)

    # A second call is expected to add a second entry.
    asyncio.run(timed.record_time_point())
    assert len(timed.time_points) == 2

Expand All @@ -136,7 +146,6 @@ def test_clear_time_points():

def test_request_output_get_metrics():
"""测试RequestOutput类的get_metrics方法的不同分支"""
# 测试success=False的情况
output = RequestOutput()
output.success = False
output.error_info = "test error"
Expand All @@ -148,25 +157,20 @@ def test_request_output_get_metrics():
assert isinstance(metrics, dict)
assert metrics["success"] is False
assert metrics["error_info"] == "test error"
# 确保clean_result函数被正确应用
assert "content" not in metrics
assert "reasoning_content" not in metrics
assert "perf_mode" not in metrics
# 确保prediction字段被设置
assert "prediction" in metrics

# 测试success=True但time_points.size <= 1的情况
output = RequestOutput()
output.success = True
output.time_points = [time.perf_counter()]

metrics = output.get_metrics()
assert metrics["success"] is False # 应该被设置为False
assert metrics["success"] is False
assert metrics["error_info"] == "chunk size is less than 2"
# 确保time_points被转换为numpy数组
assert isinstance(metrics["time_points"], np.ndarray)

# 测试success=True且time_points.size > 1的情况
output = RequestOutput()
output.success = True
output.time_points = [time.perf_counter() - 1, time.perf_counter()]
Expand All @@ -183,7 +187,6 @@ def test_request_output_get_metrics():

def test_request_output_edge_cases():
"""测试RequestOutput类的边缘情况"""
# 测试空的time_points列表
output = RequestOutput()
output.success = True
output.time_points = []
Expand All @@ -192,7 +195,6 @@ def test_request_output_edge_cases():
assert metrics["success"] is False
assert metrics["error_info"] == "chunk size is less than 2"

# 测试包含其他属性的情况
output = RequestOutput()
output.success = False
output.uuid = "test_uuid_123"
Expand All @@ -202,4 +204,186 @@ def test_request_output_edge_cases():
metrics = output.get_metrics()
assert metrics["uuid"] == "test_uuid_123"
assert metrics["turn_id"] == 5
assert metrics["extra_perf_data"] == {"test_key": "test_value"}
assert metrics["extra_perf_data"] == {"test_key": "test_value"}


class TestFunctionCallOutput:
    """Tests for FunctionCallOutput: construction and extra-details updates."""

    def test_init(self):
        """Default construction: perf mode off, list-typed log and tool calls."""
        fc_output = FunctionCallOutput()
        assert fc_output.perf_mode is False
        assert isinstance(fc_output.inference_log, list)
        assert isinstance(fc_output.tool_calls, list)

    def test_init_perf_mode(self):
        """Construction with perf_mode=True keeps the flag set."""
        fc_output = FunctionCallOutput(perf_mode=True)
        assert fc_output.perf_mode is True

    def test_update_extra_details_data_from_text_response(self):
        """A response with one choice should expose its message in the details."""
        assistant_message = {
            "role": "assistant",
            "content": "test content",
        }
        response = {"choices": [{"message": assistant_message}]}
        fc_output = FunctionCallOutput()

        fc_output.update_extra_details_data_from_text_response(response)

        assert "message" in fc_output.extra_details_data
        assert fc_output.extra_details_data["message"]["role"] == "assistant"

    def test_update_extra_details_data_empty_choices(self):
        """An empty "choices" list should leave the details dict empty."""
        response = {"choices": []}
        fc_output = FunctionCallOutput()

        fc_output.update_extra_details_data_from_text_response(response)

        assert fc_output.extra_details_data == {}

    def test_update_extra_details_data_no_choices(self):
        """A response without a "choices" key should leave the details dict empty."""
        response = {}
        fc_output = FunctionCallOutput()

        fc_output.update_extra_details_data_from_text_response(response)

        assert fc_output.extra_details_data == {}

    def test_get_metrics_inherited(self):
        """get_metrics is expected to return None.

        Per the original note, the method is inherited from the abstract
        Output base class.
        """
        fc_output = FunctionCallOutput()
        fc_output.success = True
        fc_output.uuid = "test_uuid"
        fc_output.tool_calls = [{"function": "test_func"}]

        assert fc_output.get_metrics() is None


class TestLMMOutput:
    """Tests for LMMOutput: construction, content handlers and get_prediction."""

    @staticmethod
    def _red_square():
        """Build the 100x100 solid red RGB image used by the image tests."""
        return Image.new('RGB', (100, 100), color='red')

    def test_init(self):
        """Default construction: perf mode off, content is a list with one empty string."""
        lmm_output = LMMOutput()
        assert lmm_output.perf_mode is False
        assert isinstance(lmm_output.content, list)
        assert len(lmm_output.content) == 1
        assert lmm_output.content[0] == ""

    def test_init_perf_mode(self):
        """Construction with perf_mode=True keeps the flag set."""
        lmm_output = LMMOutput(perf_mode=True)
        assert lmm_output.perf_mode is True

    def test_handle_text(self):
        """_handle_text should return the text stored at the given index."""
        lmm_output = LMMOutput()
        lmm_output.content = ["text content"]

        handled = lmm_output._handle_text("/test/dir", 0)
        assert handled == "text content"

    def test_handle_image(self):
        """_handle_image should write image_<uuid>_<index>.png into the target dir."""
        lmm_output = LMMOutput()
        lmm_output.uuid = "test_uuid"
        lmm_output.content = [self._red_square()]

        with tempfile.TemporaryDirectory() as out_dir:
            saved = lmm_output._handle_image(out_dir, 0)
            expected_path = os.path.join(out_dir, "image_test_uuid_0.png")

            # The existence check must run before the directory is removed.
            assert "image_test_uuid_0.png" in saved
            assert os.path.exists(expected_path)

    def test_handle_image_overwrite(self):
        """Handling the same image twice should yield the same result both times."""
        lmm_output = LMMOutput()
        lmm_output.uuid = "test_uuid"
        lmm_output.content = [self._red_square()]

        with tempfile.TemporaryDirectory() as out_dir:
            first = lmm_output._handle_image(out_dir, 0)
            second = lmm_output._handle_image(out_dir, 0)

            assert first == second

    def test_get_prediction_single_text(self):
        """get_prediction with one text item should return that text."""
        lmm_output = LMMOutput()
        lmm_output.uuid = "test_uuid"
        lmm_output.content = ["text content"]

        with tempfile.TemporaryDirectory() as out_dir:
            prediction = lmm_output.get_prediction(out_dir)
            assert prediction == "text content"

    def test_get_prediction_single_image(self):
        """get_prediction with one image should mention the saved file name."""
        lmm_output = LMMOutput()
        lmm_output.uuid = "test_uuid"
        lmm_output.content = [self._red_square()]

        with tempfile.TemporaryDirectory() as out_dir:
            prediction = lmm_output.get_prediction(out_dir)
            assert "image_test_uuid_0.png" in prediction

    def test_get_prediction_multiple_items(self):
        """get_prediction with two items should return a two-element list."""
        lmm_output = LMMOutput()
        lmm_output.uuid = "test_uuid"
        lmm_output.content = [self._red_square(), "text content"]

        with tempfile.TemporaryDirectory() as out_dir:
            prediction = lmm_output.get_prediction(out_dir)
            assert isinstance(prediction, list)
            assert len(prediction) == 2

    def test_get_metrics_inherited(self):
        """get_metrics is expected to return None.

        Per the original note, the method is inherited from the abstract
        Output base class.
        """
        lmm_output = LMMOutput()
        lmm_output.success = True
        lmm_output.uuid = "test_uuid"
        lmm_output.content = ["test"]

        assert lmm_output.get_metrics() is None


def test_output_update_extra_perf_data_from_stream_response():
    """The default stream-response perf hook should leave extra_perf_data empty."""
    base_output = ConcreteOutput()
    stream_payload = {"test": "data"}

    base_output.update_extra_perf_data_from_stream_response(stream_payload)

    assert base_output.extra_perf_data == {}


def test_output_update_extra_perf_data_from_text_response():
    """The default text-response perf hook should leave extra_perf_data empty."""
    base_output = ConcreteOutput()
    text_payload = {"test": "data"}

    base_output.update_extra_perf_data_from_text_response(text_payload)

    assert base_output.extra_perf_data == {}


def test_output_update_extra_details_data_from_stream_response():
    """The default stream-response details hook should leave extra_details_data empty."""
    base_output = ConcreteOutput()
    stream_payload = {"test": "data"}

    base_output.update_extra_details_data_from_stream_response(stream_payload)

    assert base_output.extra_details_data == {}


def test_output_update_extra_details_data_from_text_response():
    """The default text-response details hook should leave extra_details_data empty."""
    base_output = ConcreteOutput()
    text_payload = {"test": "data"}

    base_output.update_extra_details_data_from_text_response(text_payload)

    assert base_output.extra_details_data == {}
Loading