Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,8 @@ def _parse_annotation(annotation: Any) -> Any:
If the second annotation (after the type) is a string, then we convert that to a Pydantic Field description.
The rest are returned as-is, allowing for multiple annotations.

Literal types are returned as-is to preserve their enum-like values.

Args:
annotation: The type annotation to parse.

Expand All @@ -894,6 +896,12 @@ def _parse_annotation(annotation: Any) -> Any:
"""
origin = get_origin(annotation)
if origin is not None:
# Literal types should be returned as-is - their args are the allowed values,
# not type annotations to be parsed. For example, Literal["Data", "Security"]
# has args ("Data", "Security") which are the valid string values.
if origin is Literal:
return annotation

args = get_args(annotation)
# For other generics, return the origin type (e.g., list for List[int])
if len(args) > 1 and isinstance(args[1], str):
Expand Down
160 changes: 158 additions & 2 deletions python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Any
from typing import Annotated, Any, Literal
from unittest.mock import Mock

import pytest
Expand All @@ -14,7 +14,7 @@
ToolProtocol,
ai_function,
)
from agent_framework._tools import _parse_inputs
from agent_framework._tools import _parse_annotation, _parse_inputs
from agent_framework.exceptions import ToolException
from agent_framework.observability import OtelAttr

Expand Down Expand Up @@ -128,6 +128,95 @@ def test_tool(self, x: int, y: int) -> int:
assert test_tool(1, 2) == 3


def test_ai_function_with_literal_type_parameter():
"""Test ai_function decorator with Literal type parameter (issue #2891)."""

@ai_function
def search_flows(category: Literal["Data", "Security", "Network"], issue: str) -> str:
"""Search flows by category."""
return f"{category}: {issue}"

assert isinstance(search_flows, AIFunction)
schema = search_flows.parameters()
assert schema == {
"properties": {
"category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"},
"issue": {"title": "Issue", "type": "string"},
},
"required": ["category", "issue"],
"title": "search_flows_input",
"type": "object",
}
# Verify invocation works
assert search_flows("Data", "test issue") == "Data: test issue"


def test_ai_function_with_literal_type_in_class_method():
"""Test ai_function decorator with Literal type parameter in a class method (issue #2891)."""

class MyTools:
@ai_function
def search_flows(self, category: Literal["Data", "Security", "Network"], issue: str) -> str:
"""Search flows by category."""
return f"{category}: {issue}"

tools = MyTools()
search_tool = tools.search_flows
assert isinstance(search_tool, AIFunction)
schema = search_tool.parameters()
assert schema == {
"properties": {
"category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"},
"issue": {"title": "Issue", "type": "string"},
},
"required": ["category", "issue"],
"title": "search_flows_input",
"type": "object",
}
# Verify invocation works
assert search_tool("Security", "test issue") == "Security: test issue"


def test_ai_function_with_literal_int_type():
"""Test ai_function decorator with Literal int type parameter."""

@ai_function
def set_priority(priority: Literal[1, 2, 3], task: str) -> str:
"""Set priority for a task."""
return f"Priority {priority}: {task}"

assert isinstance(set_priority, AIFunction)
schema = set_priority.parameters()
assert schema == {
"properties": {
"priority": {"enum": [1, 2, 3], "title": "Priority", "type": "integer"},
"task": {"title": "Task", "type": "string"},
},
"required": ["priority", "task"],
"title": "set_priority_input",
"type": "object",
}
assert set_priority(1, "important task") == "Priority 1: important task"


def test_ai_function_with_literal_and_annotated():
"""Test ai_function decorator with Literal type combined with Annotated for description."""

@ai_function
def categorize(
category: Annotated[Literal["A", "B", "C"], "The category to assign"],
name: str,
) -> str:
"""Categorize an item."""
return f"{category}: {name}"

assert isinstance(categorize, AIFunction)
schema = categorize.parameters()
# Literal type inside Annotated should preserve enum values
assert schema["properties"]["category"]["enum"] == ["A", "B", "C"]
assert categorize("A", "test") == "A: test"


async def test_ai_function_decorator_shared_state():
"""Test that decorated methods maintain shared state across multiple calls and tool usage."""

Expand Down Expand Up @@ -1368,3 +1457,70 @@ def tool_with_kwargs(x: int, **kwargs: Any) -> str:
arguments=tool_with_kwargs.input_model(x=10),
)
assert result_default == "x=10, user=unknown"


# region _parse_annotation tests


def test_parse_annotation_with_literal_type():
"""Test that _parse_annotation returns Literal types unchanged (issue #2891)."""
from typing import get_args, get_origin

# Literal with string values
literal_annotation = Literal["Data", "Security", "Network"]
result = _parse_annotation(literal_annotation)
assert result is literal_annotation
assert get_origin(result) is Literal
assert get_args(result) == ("Data", "Security", "Network")


def test_parse_annotation_with_literal_int_type():
"""Test that _parse_annotation returns Literal int types unchanged."""
from typing import get_args, get_origin

literal_annotation = Literal[1, 2, 3]
result = _parse_annotation(literal_annotation)
assert result is literal_annotation
assert get_origin(result) is Literal
assert get_args(result) == (1, 2, 3)


def test_parse_annotation_with_literal_bool_type():
"""Test that _parse_annotation returns Literal bool types unchanged."""
from typing import get_args, get_origin

literal_annotation = Literal[True, False]
result = _parse_annotation(literal_annotation)
assert result is literal_annotation
assert get_origin(result) is Literal
assert get_args(result) == (True, False)


def test_parse_annotation_with_simple_types():
"""Test that _parse_annotation returns simple types unchanged."""
assert _parse_annotation(str) is str
assert _parse_annotation(int) is int
assert _parse_annotation(float) is float
assert _parse_annotation(bool) is bool


def test_parse_annotation_with_annotated_and_literal():
"""Test that Annotated[Literal[...], description] works correctly."""
from typing import get_args, get_origin

# When Literal is inside Annotated, it should still be preserved
annotated_literal = Annotated[Literal["A", "B", "C"], "The category"]
result = _parse_annotation(annotated_literal)

# The Annotated type should be preserved
origin = get_origin(result)
assert origin is Annotated

args = get_args(result)
# First arg is the Literal type
literal_type = args[0]
assert get_origin(literal_type) is Literal
assert get_args(literal_type) == ("A", "B", "C")


# endregion
Loading