mirror of https://github.com/microsoft/autogen.git
parent
06ba5d3ca8
commit
c6360feeb6
|
@ -16,6 +16,7 @@ from typing import (
|
|||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
@ -67,7 +68,8 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
|||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
typed_signature = inspect.Signature(typed_params)
|
||||
return_annotation = get_typed_annotation(signature.return_annotation, globalns)
|
||||
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
|
||||
return typed_signature
|
||||
|
||||
|
||||
|
@ -313,7 +315,7 @@ def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
|
|||
|
||||
|
||||
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
|
||||
fields: List[tuple[str, Any]] = []
|
||||
fields: Dict[str, tuple[Type[Any], Any]] = {}
|
||||
for name, param in sig.parameters.items():
|
||||
# This is handled externally
|
||||
if name == "cancellation_token":
|
||||
|
@ -326,24 +328,6 @@ def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[Ba
|
|||
description = type2description(name, param.annotation)
|
||||
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
|
||||
|
||||
fields.append((name, (type, Field(default=default_value, description=description))))
|
||||
fields[name] = (type, Field(default=default_value, description=description))
|
||||
|
||||
return create_model(name, *fields)
|
||||
|
||||
|
||||
def return_value_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
|
||||
if issubclass(BaseModel, sig.return_annotation):
|
||||
return sig.return_annotation # type: ignore
|
||||
|
||||
fields: List[tuple[str, Any]] = []
|
||||
for name, param in sig.return_annotation:
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
raise ValueError("No annotation")
|
||||
|
||||
type = normalize_annotated_type(param.annotation)
|
||||
description = type2description(name, param.annotation)
|
||||
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
|
||||
|
||||
fields.append((name, (type, Field(default=default_value, description=description))))
|
||||
|
||||
return create_model(name, *fields)
|
||||
return cast(BaseModel, create_model(name, **fields)) # type: ignore
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar
|
||||
|
||||
|
@ -20,11 +21,13 @@ class Tool(Protocol):
|
|||
|
||||
def args_type(self) -> Type[BaseModel]: ...
|
||||
|
||||
def return_type(self) -> Type[BaseModel]: ...
|
||||
def return_type(self) -> Type[Any]: ...
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None: ...
|
||||
|
||||
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> BaseModel: ...
|
||||
def return_value_as_string(self, value: Any) -> str: ...
|
||||
|
||||
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any: ...
|
||||
|
||||
def save_state_json(self) -> Mapping[str, Any]: ...
|
||||
|
||||
|
@ -63,16 +66,25 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
|
|||
def args_type(self) -> Type[BaseModel]:
|
||||
return self._args_type
|
||||
|
||||
def return_type(self) -> Type[BaseModel]:
|
||||
def return_type(self) -> Type[Any]:
|
||||
return self._return_type
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None:
|
||||
return None
|
||||
|
||||
def return_value_as_string(self, value: Any) -> str:
|
||||
if isinstance(value, BaseModel):
|
||||
dumped = value.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return json.dumps(dumped)
|
||||
return str(dumped)
|
||||
|
||||
return str(value)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...
|
||||
|
||||
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> BaseModel:
|
||||
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any:
|
||||
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
|
||||
return return_value
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from ...core import CancellationToken
|
|||
from .._function_utils import (
|
||||
args_base_model_from_signature,
|
||||
get_typed_signature,
|
||||
return_value_base_model_from_signature,
|
||||
)
|
||||
from ._base import BaseTool
|
||||
|
||||
|
@ -19,12 +18,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
|
|||
signature = get_typed_signature(func)
|
||||
func_name = name or func.__name__
|
||||
args_model = args_base_model_from_signature(func_name + "args", signature)
|
||||
return_model = return_value_base_model_from_signature(func_name + "return", signature)
|
||||
return_type = signature.return_annotation
|
||||
self._has_cancellation_support = "cancellation_token" in signature.parameters
|
||||
|
||||
super().__init__(args_model, return_model, func_name, description)
|
||||
super().__init__(args_model, return_type, func_name, description)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> BaseModel:
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
if asyncio.iscoroutinefunction(self._func):
|
||||
if self._has_cancellation_support:
|
||||
result = await self._func(**args.model_dump(), cancellation_token=cancellation_token)
|
||||
|
@ -42,5 +41,5 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
|
|||
cancellation_token.link_future(future)
|
||||
result = await future
|
||||
|
||||
assert isinstance(result, BaseModel)
|
||||
assert isinstance(result, self.return_type())
|
||||
return result
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
|
||||
import inspect
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from agnext.components.tools import BaseTool
|
||||
from agnext.components._function_utils import get_typed_signature
|
||||
from agnext.components.tools import BaseTool, FunctionTool
|
||||
from agnext.core import CancellationToken
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
|
||||
class MyArgs(BaseModel):
|
||||
|
@ -58,3 +63,181 @@ def test_tool_properties()-> None:
|
|||
assert tool.args_type() == MyArgs
|
||||
assert tool.return_type() == MyResult
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_get_typed_signature()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 0
|
||||
assert sig.return_annotation == str
|
||||
|
||||
def test_get_typed_signature_annotated()-> None:
|
||||
def my_function() -> Annotated[str, "The return type"]:
|
||||
return "result"
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 0
|
||||
assert sig.return_annotation == Annotated[str, "The return type"]
|
||||
|
||||
def test_get_typed_signature_string()-> None:
|
||||
def my_function() -> "str":
|
||||
return "result"
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 0
|
||||
assert sig.return_annotation == str
|
||||
|
||||
|
||||
def test_func_tool()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert issubclass(tool.return_type(), str)
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_annotated_arg()-> None:
|
||||
def my_function(my_arg: Annotated[str, "test description"]) -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert issubclass(tool.return_type(), str)
|
||||
assert tool.args_type().model_fields["my_arg"].description == "test description"
|
||||
assert tool.args_type().model_fields["my_arg"].annotation == str
|
||||
assert tool.args_type().model_fields["my_arg"].is_required() is True
|
||||
assert tool.args_type().model_fields["my_arg"].default is PydanticUndefined
|
||||
assert len(tool.args_type().model_fields) == 1
|
||||
assert tool.return_type() == str
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_return_annotated()-> None:
|
||||
def my_function() -> Annotated[str, "test description"]:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() == Annotated[str, "test description"]
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_no_args()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert len(tool.args_type().model_fields) == 0
|
||||
assert tool.return_type() == str
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_return_none()-> None:
|
||||
def my_function() -> None:
|
||||
return None
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() is None
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_return_base_model()-> None:
|
||||
def my_function() -> MyResult:
|
||||
return MyResult(result="value")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() is MyResult
|
||||
assert tool.state_type() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_call_tool()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({}, CancellationToken())
|
||||
assert result == "result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_call_tool_base_model()-> None:
|
||||
def my_function() -> MyResult:
|
||||
return MyResult(result="value")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({}, CancellationToken())
|
||||
assert isinstance(result, MyResult)
|
||||
assert result.result == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_call_tool_with_arg_base_model()-> None:
|
||||
def my_function(arg: str) -> MyResult:
|
||||
return MyResult(result="value")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert isinstance(result, MyResult)
|
||||
assert result.result == "value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_str_res()-> None:
|
||||
def my_function(arg: str) -> str:
|
||||
return "test"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_base_model_res()-> None:
|
||||
|
||||
|
||||
def my_function(arg: str) -> MyResult:
|
||||
return MyResult(result="test")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == '{"result": "test"}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_base_model_custom_dump_res()-> None:
|
||||
|
||||
class MyResultCustomDump(BaseModel):
|
||||
result: str = Field(description="The other description.")
|
||||
|
||||
@model_serializer
|
||||
def ser_model(self) -> str:
|
||||
return "custom: " + self.result
|
||||
|
||||
|
||||
def my_function(arg: str) -> MyResultCustomDump:
|
||||
return MyResultCustomDump(result="test")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == "custom: test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_int_res()-> None:
|
||||
def my_function(arg: int) -> int:
|
||||
return arg
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": 5}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == "5"
|
||||
|
|
Loading…
Reference in New Issue