Fix types in tool tests (#2285)

* fixed types related to function calling

* polishing

* fixed types in tests
This commit is contained in:
Davor Runje 2024-04-05 17:51:49 +02:00 committed by GitHub
parent 0e0895fe18
commit 0c0f953df3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 21 deletions

View File

@ -13,7 +13,7 @@ if not PYDANTIC_V1:
from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
from pydantic.json_schema import JsonSchemaValue from pydantic.json_schema import JsonSchemaValue
def type2schema(t: Optional[Type[Any]]) -> JsonSchemaValue: def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema """Convert a type to a JSON schema
Args: Args:
@ -55,7 +55,7 @@ else: # pragma: no cover
JsonSchemaValue = Dict[str, Any] # type: ignore[misc] JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
def type2schema(t: Optional[Type[Any]]) -> JsonSchemaValue: def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema """Convert a type to a JSON schema
Args: Args:

View File

@ -110,9 +110,7 @@ class ToolFunction(BaseModel):
function: Annotated[Function, Field(description="Function under tool")] function: Annotated[Function, Field(description="Function under tool")]
def get_parameter_json_schema( def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> JsonSchemaValue:
k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API """Get a JSON schema for a parameter as defined by the OpenAI API
Args: Args:

View File

@ -72,6 +72,8 @@ files = [
"autogen/_pydantic.py", "autogen/_pydantic.py",
"autogen/function_utils.py", "autogen/function_utils.py",
"autogen/io", "autogen/io",
"test/test_pydantic.py",
"test/test_function_utils.py",
"test/io", "test/io",
] ]

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import inspect import inspect
import unittest.mock import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple from typing import Any, Dict, List, Literal, Optional, Tuple
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -25,11 +25,11 @@ from autogen.function_utils import (
) )
def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1, *, d): def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1, *, d): # type: ignore[no-untyped-def]
pass pass
def g( def g( # type: ignore[empty-body]
a: Annotated[str, "Parameter a"], a: Annotated[str, "Parameter a"],
b: int = 2, b: int = 2,
c: Annotated[float, "Parameter c"] = 0.1, c: Annotated[float, "Parameter c"] = 0.1,
@ -39,7 +39,7 @@ def g(
pass pass
async def a_g( async def a_g( # type: ignore[empty-body]
a: Annotated[str, "Parameter a"], a: Annotated[str, "Parameter a"],
b: int = 2, b: int = 2,
c: Annotated[float, "Parameter c"] = 0.1, c: Annotated[float, "Parameter c"] = 0.1,
@ -83,7 +83,7 @@ def test_get_parameter_json_schema() -> None:
b: float b: float
c: str c: str
expected = { expected: Dict[str, Any] = {
"description": "b", "description": "b",
"properties": {"b": {"title": "B", "type": "number"}, "c": {"title": "C", "type": "string"}}, "properties": {"b": {"title": "B", "type": "number"}, "c": {"title": "C", "type": "string"}},
"required": ["b", "c"], "required": ["b", "c"],
@ -107,7 +107,7 @@ def test_get_default_values() -> None:
def test_get_param_annotations() -> None: def test_get_param_annotations() -> None:
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): # type: ignore[no-untyped-def]
pass pass
expected = {"a": Annotated[str, "Parameter a"], "c": Annotated[float, "Parameter c"]} expected = {"a": Annotated[str, "Parameter a"], "c": Annotated[float, "Parameter c"]}
@ -119,14 +119,14 @@ def test_get_param_annotations() -> None:
def test_get_missing_annotations() -> None: def test_get_missing_annotations() -> None:
def _f1(a: str, b=2): def _f1(a: str, b=2): # type: ignore[no-untyped-def]
pass pass
missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f1), ["a"]) missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f1), ["a"])
assert missing == set() assert missing == set()
assert unannotated_with_default == {"b"} assert unannotated_with_default == {"b"}
def _f2(a: str, b) -> str: def _f2(a: str, b) -> str: # type: ignore[empty-body,no-untyped-def]
"ok" "ok"
missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f2), ["a", "b"]) missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f2), ["a", "b"])
@ -142,7 +142,7 @@ def test_get_missing_annotations() -> None:
def test_get_parameters() -> None: def test_get_parameters() -> None:
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): # type: ignore[no-untyped-def]
pass pass
typed_signature = get_typed_signature(f) typed_signature = get_typed_signature(f)
@ -165,7 +165,7 @@ def test_get_parameters() -> None:
def test_get_function_schema_no_return_type() -> None: def test_get_function_schema_no_return_type() -> None:
def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1): def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1): # type: ignore[no-untyped-def]
pass pass
expected = ( expected = (
@ -182,7 +182,7 @@ def test_get_function_schema_no_return_type() -> None:
def test_get_function_schema_unannotated_with_default() -> None: def test_get_function_schema_unannotated_with_default() -> None:
with unittest.mock.patch("autogen.function_utils.logger.warning") as mock_logger_warning: with unittest.mock.patch("autogen.function_utils.logger.warning") as mock_logger_warning:
def f( def f( # type: ignore[no-untyped-def]
a: Annotated[str, "Parameter a"], b=2, c: Annotated[float, "Parameter c"] = 0.1, d="whatever", e=None a: Annotated[str, "Parameter a"], b=2, c: Annotated[float, "Parameter c"] = 0.1, d="whatever", e=None
) -> str: ) -> str:
return "ok" return "ok"
@ -195,7 +195,7 @@ def test_get_function_schema_unannotated_with_default() -> None:
def test_get_function_schema_missing() -> None: def test_get_function_schema_missing() -> None:
def f(a: Annotated[str, "Parameter a"], b, c: Annotated[float, "Parameter c"] = 0.1) -> float: def f(a: Annotated[str, "Parameter a"], b, c: Annotated[float, "Parameter c"] = 0.1) -> float: # type: ignore[no-untyped-def, empty-body]
pass pass
expected = ( expected = (
@ -291,7 +291,7 @@ class Currency(BaseModel):
def test_get_function_schema_pydantic() -> None: def test_get_function_schema_pydantic() -> None:
def currency_calculator( def currency_calculator( # type: ignore[empty-body]
base: Annotated[Currency, "Base currency: amount and currency symbol"], base: Annotated[Currency, "Base currency: amount and currency symbol"],
quote_currency: Annotated[CurrencySymbol, "Quote currency symbol (default: 'EUR')"] = "EUR", quote_currency: Annotated[CurrencySymbol, "Quote currency symbol (default: 'EUR')"] = "EUR",
) -> Currency: ) -> Currency:
@ -346,12 +346,12 @@ def test_get_function_schema_pydantic() -> None:
def test_get_load_param_if_needed_function() -> None: def test_get_load_param_if_needed_function() -> None:
assert get_load_param_if_needed_function(CurrencySymbol) is None assert get_load_param_if_needed_function(CurrencySymbol) is None
assert get_load_param_if_needed_function(Currency)({"currency": "USD", "amount": 123.45}, Currency) == Currency( assert get_load_param_if_needed_function(Currency)({"currency": "USD", "amount": 123.45}, Currency) == Currency( # type: ignore[misc]
currency="USD", amount=123.45 currency="USD", amount=123.45
) )
f = get_load_param_if_needed_function(Annotated[Currency, "amount and a symbol of a currency"]) f = get_load_param_if_needed_function(Annotated[Currency, "amount and a symbol of a currency"])
actual = f({"currency": "USD", "amount": 123.45}, Currency) actual = f({"currency": "USD", "amount": 123.45}, Currency) # type: ignore[misc]
expected = Currency(currency="USD", amount=123.45) expected = Currency(currency="USD", amount=123.45)
assert actual == expected, actual assert actual == expected, actual
@ -391,7 +391,7 @@ async def test_load_basemodels_if_needed_async() -> None:
assert actual[1] == "EUR" assert actual[1] == "EUR"
def test_serialize_to_json(): def test_serialize_to_json() -> None:
assert serialize_to_str("abc") == "abc" assert serialize_to_str("abc") == "abc"
assert serialize_to_str(123) == "123" assert serialize_to_str(123) == "123"
assert serialize_to_str([123, 456]) == "[123, 456]" assert serialize_to_str([123, 456]) == "[123, 456]"