Fix function tools (#57)

* Fix function tools

* lint
This commit is contained in:
Jack Gerrits 2024-06-06 21:58:11 -04:00 committed by GitHub
parent 06ba5d3ca8
commit c6360feeb6
4 changed files with 211 additions and 33 deletions

View File

@ -16,6 +16,7 @@ from typing import (
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
@ -67,7 +68,8 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return_annotation = get_typed_annotation(signature.return_annotation, globalns)
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
return typed_signature
@ -313,7 +315,7 @@ def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
fields: List[tuple[str, Any]] = []
fields: Dict[str, tuple[Type[Any], Any]] = {}
for name, param in sig.parameters.items():
# This is handled externally
if name == "cancellation_token":
@ -326,24 +328,6 @@ def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[Ba
description = type2description(name, param.annotation)
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
fields.append((name, (type, Field(default=default_value, description=description))))
fields[name] = (type, Field(default=default_value, description=description))
return create_model(name, *fields)
def return_value_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
if issubclass(BaseModel, sig.return_annotation):
return sig.return_annotation # type: ignore
fields: List[tuple[str, Any]] = []
for name, param in sig.return_annotation:
if param.annotation is inspect.Parameter.empty:
raise ValueError("No annotation")
type = normalize_annotated_type(param.annotation)
description = type2description(name, param.annotation)
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
fields.append((name, (type, Field(default=default_value, description=description))))
return create_model(name, *fields)
return cast(BaseModel, create_model(name, **fields)) # type: ignore

View File

@ -1,3 +1,4 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar
@ -20,11 +21,13 @@ class Tool(Protocol):
def args_type(self) -> Type[BaseModel]: ...
def return_type(self) -> Type[BaseModel]: ...
def return_type(self) -> Type[Any]: ...
def state_type(self) -> Type[BaseModel] | None: ...
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> BaseModel: ...
def return_value_as_string(self, value: Any) -> str: ...
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any: ...
def save_state_json(self) -> Mapping[str, Any]: ...
@ -63,16 +66,25 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
def args_type(self) -> Type[BaseModel]:
return self._args_type
def return_type(self) -> Type[BaseModel]:
def return_type(self) -> Type[Any]:
return self._return_type
def state_type(self) -> Type[BaseModel] | None:
return None
def return_value_as_string(self, value: Any) -> str:
if isinstance(value, BaseModel):
dumped = value.model_dump()
if isinstance(dumped, dict):
return json.dumps(dumped)
return str(dumped)
return str(value)
@abstractmethod
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> BaseModel:
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any:
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
return return_value

View File

@ -8,7 +8,6 @@ from ...core import CancellationToken
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
return_value_base_model_from_signature,
)
from ._base import BaseTool
@ -19,12 +18,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
signature = get_typed_signature(func)
func_name = name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
return_model = return_value_base_model_from_signature(func_name + "return", signature)
return_type = signature.return_annotation
self._has_cancellation_support = "cancellation_token" in signature.parameters
super().__init__(args_model, return_model, func_name, description)
super().__init__(args_model, return_type, func_name, description)
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> BaseModel:
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
if asyncio.iscoroutinefunction(self._func):
if self._has_cancellation_support:
result = await self._func(**args.model_dump(), cancellation_token=cancellation_token)
@ -42,5 +41,5 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
cancellation_token.link_future(future)
result = await future
assert isinstance(result, BaseModel)
assert isinstance(result, self.return_type())
return result

View File

@ -1,8 +1,13 @@
import inspect
from typing import Annotated
import pytest
from agnext.components.tools import BaseTool
from agnext.components._function_utils import get_typed_signature
from agnext.components.tools import BaseTool, FunctionTool
from agnext.core import CancellationToken
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_serializer
from pydantic_core import PydanticUndefined
class MyArgs(BaseModel):
@ -58,3 +63,181 @@ def test_tool_properties()-> None:
assert tool.args_type() == MyArgs
assert tool.return_type() == MyResult
assert tool.state_type() is None
def test_get_typed_signature()-> None:
def my_function() -> str:
return "result"
sig = get_typed_signature(my_function)
assert isinstance(sig, inspect.Signature)
assert len(sig.parameters) == 0
assert sig.return_annotation == str
def test_get_typed_signature_annotated()-> None:
def my_function() -> Annotated[str, "The return type"]:
return "result"
sig = get_typed_signature(my_function)
assert isinstance(sig, inspect.Signature)
assert len(sig.parameters) == 0
assert sig.return_annotation == Annotated[str, "The return type"]
def test_get_typed_signature_string()-> None:
def my_function() -> "str":
return "result"
sig = get_typed_signature(my_function)
assert isinstance(sig, inspect.Signature)
assert len(sig.parameters) == 0
assert sig.return_annotation == str
def test_func_tool()-> None:
def my_function() -> str:
return "result"
tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert issubclass(tool.return_type(), str)
assert tool.state_type() is None
def test_func_tool_annotated_arg()-> None:
def my_function(my_arg: Annotated[str, "test description"]) -> str:
return "result"
tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert issubclass(tool.return_type(), str)
assert tool.args_type().model_fields["my_arg"].description == "test description"
assert tool.args_type().model_fields["my_arg"].annotation == str
assert tool.args_type().model_fields["my_arg"].is_required() is True
assert tool.args_type().model_fields["my_arg"].default is PydanticUndefined
assert len(tool.args_type().model_fields) == 1
assert tool.return_type() == str
assert tool.state_type() is None
def test_func_tool_return_annotated()-> None:
def my_function() -> Annotated[str, "test description"]:
return "result"
tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert tool.return_type() == Annotated[str, "test description"]
assert tool.state_type() is None
def test_func_tool_no_args()-> None:
def my_function() -> str:
return "result"
tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert len(tool.args_type().model_fields) == 0
assert tool.return_type() == str
assert tool.state_type() is None
def test_func_tool_return_none()-> None:
def my_function() -> None:
return None
tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert tool.return_type() is None
assert tool.state_type() is None
def test_func_tool_return_base_model()-> None:
def my_function() -> MyResult:
return MyResult(result="value")
tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert tool.return_type() is MyResult
assert tool.state_type() is None
@pytest.mark.asyncio
async def test_func_call_tool()-> None:
def my_function() -> str:
return "result"
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({}, CancellationToken())
assert result == "result"
@pytest.mark.asyncio
async def test_func_call_tool_base_model()-> None:
def my_function() -> MyResult:
return MyResult(result="value")
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({}, CancellationToken())
assert isinstance(result, MyResult)
assert result.result == "value"
@pytest.mark.asyncio
async def test_func_call_tool_with_arg_base_model()-> None:
def my_function(arg: str) -> MyResult:
return MyResult(result="value")
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert isinstance(result, MyResult)
assert result.result == "value"
@pytest.mark.asyncio
async def test_func_str_res()-> None:
def my_function(arg: str) -> str:
return "test"
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert tool.return_value_as_string(result) == "test"
@pytest.mark.asyncio
async def test_func_base_model_res()-> None:
def my_function(arg: str) -> MyResult:
return MyResult(result="test")
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert tool.return_value_as_string(result) == '{"result": "test"}'
@pytest.mark.asyncio
async def test_func_base_model_custom_dump_res()-> None:
class MyResultCustomDump(BaseModel):
result: str = Field(description="The other description.")
@model_serializer
def ser_model(self) -> str:
return "custom: " + self.result
def my_function(arg: str) -> MyResultCustomDump:
return MyResultCustomDump(result="test")
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert tool.return_value_as_string(result) == "custom: test"
@pytest.mark.asyncio
async def test_func_int_res()-> None:
def my_function(arg: int) -> int:
return arg
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": 5}, CancellationToken())
assert tool.return_value_as_string(result) == "5"