mirror of https://github.com/microsoft/autogen.git
Add decorator for function calling (#1018)
* add function decorator to converasble agent * polishing * polishing * added function decorator to the notebook with async function calls * added support for return type hint and JSON encoding of returned value if needed * polishing * polishing * refactored async case * Python 3.8 support added * polishing * polishing * missing docs added * refacotring and changes as requested * getLogger * documentation added * test fix * test fix * added testing of agentchat_function_call_currency_calculator.ipynb to test_notebook.py * added support for Pydantic parameters in function decorator * polishing * Update website/docs/Use-Cases/agent_chat.md Co-authored-by: Li Jiang <bnujli@gmail.com> * Update website/docs/Use-Cases/agent_chat.md Co-authored-by: Li Jiang <bnujli@gmail.com> * fixes problem with logprob parameter in openai.types.chat.chat_completion.Choice added by openai version 1.5.0 * get 100% code coverage on code added * updated docs * default values added to JSON schema * serialization using json.dump() add for values not string or BaseModel * added limit to openai version because of breaking changes in 1.5.0 * added line-by-line comments in docs to explain the process * polishing --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in: parent b1adac5159, commit 4b5ec5a52f
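In short, the commit lets a plain annotated Python function be exposed to one agent's LLM and executed by another, via two new decorators. A minimal sketch of the new usage before the diffs below (the agent names, config file, and example function are illustrative, not taken from the diff):

```python
from typing_extensions import Annotated

import autogen

# illustrative config; any valid OAI_CONFIG_LIST entry works
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
llm_config = {"config_list": config_list}

assistant = autogen.AssistantAgent(name="assistant", llm_config=llm_config)
user_proxy = autogen.UserProxyAgent(name="user_proxy", human_input_mode="NEVER")


# register_for_llm advertises the JSON schema to the LLM;
# register_for_execution maps the suggested call to the implementation
@user_proxy.register_for_execution()
@assistant.register_for_llm(description="Add two numbers.")
def add(a: Annotated[int, "first addend"], b: Annotated[int, "second addend"]) -> str:
    return str(a + b)
```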
@ -8,7 +8,7 @@ node_modules/
 *.log

 # Python virtualenv
-.venv
+.venv*

 # Byte-compiled / optimized / DLL files
 __pycache__/
@ -0,0 +1,110 @@ (new file: autogen/_pydantic.py)
from typing import Any, Dict, Optional, Tuple, Type, Union, get_args

from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin

__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")

PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

if not PYDANTIC_V1:
    from pydantic import TypeAdapter
    from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
    from pydantic.json_schema import JsonSchemaValue

    def type2schema(t: Optional[Type]) -> JsonSchemaValue:
        """Convert a type to a JSON schema

        Args:
            t (Type): The type to convert

        Returns:
            JsonSchemaValue: The JSON schema
        """
        return TypeAdapter(t).json_schema()

    def model_dump(model: BaseModel) -> Dict[str, Any]:
        """Convert a pydantic model to a dict

        Args:
            model (BaseModel): The model to convert

        Returns:
            Dict[str, Any]: The dict representation of the model
        """
        return model.model_dump()

    def model_dump_json(model: BaseModel) -> str:
        """Convert a pydantic model to a JSON string

        Args:
            model (BaseModel): The model to convert

        Returns:
            str: The JSON string representation of the model
        """
        return model.model_dump_json()


# Remove this once we drop support for pydantic 1.x
else:  # pragma: no cover
    from pydantic import schema_of
    from pydantic.typing import evaluate_forwardref as evaluate_forwardref

    JsonSchemaValue = Dict[str, Any]

    def type2schema(t: Optional[Type]) -> JsonSchemaValue:
        """Convert a type to a JSON schema

        Args:
            t (Type): The type to convert

        Returns:
            JsonSchemaValue: The JSON schema
        """
        if PYDANTIC_V1:
            if t is None:
                return {"type": "null"}
            elif get_origin(t) is Union:
                return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
            elif get_origin(t) in [Tuple, tuple]:
                prefixItems = [type2schema(tt) for tt in get_args(t)]
                return {
                    "maxItems": len(prefixItems),
                    "minItems": len(prefixItems),
                    "prefixItems": prefixItems,
                    "type": "array",
                }

        d = schema_of(t)
        if "title" in d:
            d.pop("title")
        if "description" in d:
            d.pop("description")

        return d

    def model_dump(model: BaseModel) -> Dict[str, Any]:
        """Convert a pydantic model to a dict

        Args:
            model (BaseModel): The model to convert

        Returns:
            Dict[str, Any]: The dict representation of the model
        """
        return model.dict()

    def model_dump_json(model: BaseModel) -> str:
        """Convert a pydantic model to a JSON string

        Args:
            model (BaseModel): The model to convert

        Returns:
            str: The JSON string representation of the model
        """
        return model.json()
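A quick sketch of how these compatibility helpers behave (expected output shown for pydantic v2; exact schema keys vary by pydantic version, so treat the comments as illustrative):

```python
from typing import Optional, Tuple

from pydantic import BaseModel

from autogen._pydantic import model_dump, model_dump_json, type2schema


class Point(BaseModel):
    x: int
    y: int = 0


# primitive and composite types map to JSON schema fragments
print(type2schema(int))              # {'type': 'integer'}
print(type2schema(Optional[int]))    # {'anyOf': [{'type': 'integer'}, {'type': 'null'}]}
print(type2schema(Tuple[int, str]))  # an array schema with prefixItems

# model_dump/model_dump_json paper over the pydantic v1/v2 API split
p = Point(x=1)
print(model_dump(p))       # {'x': 1, 'y': 0}
print(model_dump_json(p))  # '{"x":1,"y":0}'
```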
@ -4,6 +4,7 @@ from pydantic import BaseModel, Extra, root_validator
 from typing import Any, Callable, Dict, List, Optional, Union, Tuple
 from time import sleep

+from autogen._pydantic import PYDANTIC_V1
 from autogen.agentchat import Agent, UserProxyAgent
 from autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
 from autogen.math_utils import get_answer

@ -384,7 +385,8 @@ class WolframAlphaAPIWrapper(BaseModel):
     class Config:
         """Configuration for this pydantic object."""

-        extra = Extra.forbid
+        if PYDANTIC_V1:
+            extra = Extra.forbid

     @root_validator(skip_on_failure=True)
     def validate_environment(cls, values: Dict) -> Dict:
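The `PYDANTIC_V1` guard exists because the v1-style inner `Config` with `Extra` changed in pydantic v2, which favors `model_config`. A hedged sketch of a fully version-gated equivalent (the `ConfigDict` branch is my assumption about the v2 spelling, not part of this diff):

```python
from pydantic import BaseModel

from autogen._pydantic import PYDANTIC_V1

if PYDANTIC_V1:
    from pydantic import Extra

    class StrictModel(BaseModel):
        class Config:
            extra = Extra.forbid  # v1 spelling: reject unknown fields

else:
    from pydantic import ConfigDict

    class StrictModel(BaseModel):
        model_config = ConfigDict(extra="forbid")  # v2 spelling of the same rule
```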
@ -1,14 +1,15 @@
 import asyncio
 import copy
 import functools
 import inspect
 import json
 import logging
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
-
-from autogen import OpenAIWrapper
-from autogen.code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
+from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
+
+from .. import OpenAIWrapper
+from ..code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
+from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
 from .agent import Agent

 try:

@ -19,8 +20,12 @@ except ImportError:
        return x


__all__ = ("ConversableAgent",)

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


class ConversableAgent(Agent):
    """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
@ -1330,3 +1335,157 @@ class ConversableAgent(Agent):
    def function_map(self) -> Dict[str, Callable]:
        """Return the function map."""
        return self._function_map

    def _wrap_function(self, func: F) -> F:
        """Wrap the function to dump the return value to json.

        Handles both sync and async functions.

        Args:
            func: the function to be wrapped.

        Returns:
            The wrapped function.
        """

        @load_basemodels_if_needed
        @functools.wraps(func)
        def _wrapped_func(*args, **kwargs):
            retval = func(*args, **kwargs)
            return serialize_to_str(retval)

        @load_basemodels_if_needed
        @functools.wraps(func)
        async def _a_wrapped_func(*args, **kwargs):
            retval = await func(*args, **kwargs)
            return serialize_to_str(retval)

        wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func

        # needed for testing
        wrapped_func._origin = func

        return wrapped_func
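A hedged sketch of what the wrapper does at call time, using a hypothetical Pydantic parameter and return type (the `Currency` class is illustrative, not from the diff; the exact float text in the output comment may vary):

```python
from pydantic import BaseModel
from typing_extensions import Annotated

from autogen.function_utils import load_basemodels_if_needed, serialize_to_str


class Currency(BaseModel):  # hypothetical parameter/return type
    currency: str
    amount: float


# mirror what ConversableAgent._wrap_function produces for a sync function
@load_basemodels_if_needed
def to_usd(base: Annotated[Currency, "Base currency"]) -> Currency:
    return Currency(currency="USD", amount=base.amount * 1.1)


# the OpenAI API delivers arguments as plain JSON objects; the decorator
# re-hydrates them into the annotated BaseModel before the call
result = to_usd(base={"currency": "EUR", "amount": 100.0})
print(serialize_to_str(result))  # '{"currency":"USD","amount":110.00000000000001}'
print(serialize_to_str("done"))  # strings pass through untouched
print(serialize_to_str([1, 2]))  # '[1, 2]', anything else goes through json.dumps
```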
    def register_for_llm(
        self,
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Callable[[F], F]:
        """Decorator factory for registering a function to be used by an agent.

        Its return value is used to decorate a function to be registered to the agent. The function uses type hints to
        specify the arguments and return type. The function name is used as the default name for the function,
        but a custom name can be provided. The function description is used to describe the function in the
        agent's configuration.

        Args:
            name (optional(str)): name of the function. If None, the function name will be used (default: None).
            description (optional(str)): description of the function (default: None). It is mandatory
                for the initial decorator, but the following ones can omit it.

        Returns:
            The decorator for registering a function to be used by an agent.

        Examples:
            ```
            @user_proxy.register_for_execution()
            @agent2.register_for_llm()
            @agent1.register_for_llm(description="This is a very useful function")
            def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int = 1, c: float = 3.14) -> str:
                return a + str(b * c)
            ```

        """

        def _decorator(func: F) -> F:
            """Decorator for registering a function to be used by an agent.

            Args:
                func: the function to be registered.

            Returns:
                The function to be registered, with the _description attribute set to the function description.

            Raises:
                ValueError: if the function description is not provided and not propagated by a previous decorator.
                RuntimeError: if the LLM config is not set up before registering a function.

            """
            # name can be overwritten by the parameter; by default it is the same as the function name
            if name:
                func._name = name
            elif not hasattr(func, "_name"):
                func._name = func.__name__

            # description is propagated from the previous decorator, but it is mandatory for the first one
            if description:
                func._description = description
            else:
                if not hasattr(func, "_description"):
                    raise ValueError("Function description is required, none found.")

            # get JSON schema for the function
            f = get_function_schema(func, name=func._name, description=func._description)

            # register the function to the agent if there is an LLM config, raise an exception otherwise
            if self.llm_config is None:
                raise RuntimeError("LLM config must be set up before registering a function for LLM.")

            self.update_function_signature(f, is_remove=False)

            return func

        return _decorator

    def register_for_execution(
        self,
        name: Optional[str] = None,
    ) -> Callable[[F], F]:
        """Decorator factory for registering a function to be executed by an agent.

        Its return value is used to decorate a function to be registered to the agent.

        Args:
            name (optional(str)): name of the function. If None, the function name will be used (default: None).

        Returns:
            The decorator for registering a function to be used by an agent.

        Examples:
            ```
            @user_proxy.register_for_execution()
            @agent2.register_for_llm()
            @agent1.register_for_llm(description="This is a very useful function")
            def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int = 1, c: float = 3.14):
                return a + str(b * c)
            ```

        """

        def _decorator(func: F) -> F:
            """Decorator for registering a function to be used by an agent.

            Args:
                func: the function to be registered.

            Returns:
                The function to be registered, with the _description attribute set to the function description.

            Raises:
                ValueError: if the function description is not provided and not propagated by a previous decorator.

            """
            # name can be overwritten by the parameter; by default it is the same as the function name
            if name:
                func._name = name
            elif not hasattr(func, "_name"):
                func._name = func.__name__

            self.register_function({func._name: self._wrap_function(func)})

            return func

        return _decorator
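The two decorator factories are designed to stack, passing state through `_name` and `_description` attributes on the function. A sketch of overriding the exposed name (reusing the `assistant` and `user_proxy` agents assumed in the earlier sketch):

```python
from typing_extensions import Annotated


@user_proxy.register_for_execution(name="sqrt")
@assistant.register_for_llm(name="sqrt", description="Compute a square root.")
def _compute_sqrt(x: Annotated[float, "a non-negative number"]) -> str:
    return str(x**0.5)


# the LLM is offered a tool named "sqrt", and user_proxy executes the same
# implementation; the custom name travels on the function's _name attribute
assert "sqrt" in user_proxy.function_map
assert _compute_sqrt._name == "sqrt"
```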
@ -0,0 +1,330 @@ (new file: autogen/function_utils.py)
import functools
import inspect
import json
from logging import getLogger
from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated, Literal, get_args, get_origin

from ._pydantic import JsonSchemaValue, evaluate_forwardref, model_dump, model_dump_json, type2schema

logger = getLogger(__name__)

T = TypeVar("T")


def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Get the type annotation of a parameter.

    Args:
        annotation: The annotation of the parameter
        globalns: The global namespace of the function

    Returns:
        The type annotation of the parameter
    """
    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)
        annotation = evaluate_forwardref(annotation, globalns, globalns)
    return annotation


def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Get the signature of a function with type annotations.

    Args:
        call: The function to get the signature for

    Returns:
        The signature of the function with type annotations
    """
    signature = inspect.signature(call)
    globalns = getattr(call, "__globals__", {})
    typed_params = [
        inspect.Parameter(
            name=param.name,
            kind=param.kind,
            default=param.default,
            annotation=get_typed_annotation(param.annotation, globalns),
        )
        for param in signature.parameters.values()
    ]
    typed_signature = inspect.Signature(typed_params)
    return typed_signature


def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Get the return annotation of a function.

    Args:
        call: The function to get the return annotation for

    Returns:
        The return annotation of the function
    """
    signature = inspect.signature(call)
    annotation = signature.return_annotation

    if annotation is inspect.Signature.empty:
        return None

    globalns = getattr(call, "__globals__", {})
    return get_typed_annotation(annotation, globalns)


def get_param_annotations(typed_signature: inspect.Signature) -> Dict[str, Union[Annotated[Type, str], Type]]:
    """Get the type annotations of the parameters of a function

    Args:
        typed_signature: The signature of the function with type annotations

    Returns:
        A dictionary of the type annotations of the parameters of the function
    """
    return {
        k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty
    }


class Parameters(BaseModel):
    """Parameters of a function as defined by the OpenAI API"""

    type: Literal["object"] = "object"
    properties: Dict[str, JsonSchemaValue]
    required: List[str]


class Function(BaseModel):
    """A function as defined by the OpenAI API"""

    description: Annotated[str, Field(description="Description of the function")]
    name: Annotated[str, Field(description="Name of the function")]
    parameters: Annotated[Parameters, Field(description="Parameters of the function")]


def get_parameter_json_schema(
    k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
) -> JsonSchemaValue:
    """Get a JSON schema for a parameter as defined by the OpenAI API

    Args:
        k: The name of the parameter
        v: The type of the parameter
        default_values: The default values of the parameters of the function

    Returns:
        The JSON schema for the parameter
    """

    def type2description(k: str, v: Union[Annotated[Type, str], Type]) -> str:
        # handles Annotated
        if hasattr(v, "__metadata__"):
            return v.__metadata__[0]
        else:
            return k

    schema = type2schema(v)
    if k in default_values:
        dv = default_values[k]
        schema["default"] = dv

    schema["description"] = type2description(k, v)

    return schema
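A small sketch of these signature helpers in action (the function `f` is illustrative; the printed reprs are what pydantic v2 and CPython produce and may differ slightly across versions):

```python
from typing_extensions import Annotated

from autogen.function_utils import (
    get_default_values,
    get_param_annotations,
    get_parameter_json_schema,
    get_typed_signature,
)


def f(a: Annotated[str, "Parameter a"], b: int = 2) -> None:  # illustrative
    pass


sig = get_typed_signature(f)
print(get_param_annotations(sig))  # {'a': Annotated[str, 'Parameter a'], 'b': <class 'int'>}
print(get_default_values(sig))     # {'b': 2}

# per-parameter schema: Annotated metadata becomes the description,
# and defaults are folded into the schema; unannotated descriptions
# fall back to the parameter name
print(get_parameter_json_schema("b", int, {"b": 2}))
# {'type': 'integer', 'default': 2, 'description': 'b'}
```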
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
    """Get the required parameters of a function

    Args:
        typed_signature: The signature of the function as returned by inspect.signature

    Returns:
        A list of the required parameters of the function
    """
    return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty]


def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
    """Get default values of parameters of a function

    Args:
        typed_signature: The signature of the function as returned by inspect.signature

    Returns:
        A dictionary of the default values of the parameters of the function
    """
    return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty}


def get_parameters(
    required: List[str], param_annotations: Dict[str, Union[Annotated[Type, str], Type]], default_values: Dict[str, Any]
) -> Parameters:
    """Get the parameters of a function as defined by the OpenAI API

    Args:
        required: The required parameters of the function
        param_annotations: The type annotations of the parameters of the function
        default_values: The default values of the parameters of the function

    Returns:
        A Pydantic model for the parameters of the function
    """
    return Parameters(
        properties={
            k: get_parameter_json_schema(k, v, default_values)
            for k, v in param_annotations.items()
            if v is not inspect.Signature.empty
        },
        required=required,
    )


def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]:
    """Get the missing annotations of a function

    Ignores the parameters with default values as they are not required to be annotated, but logs a warning.

    Args:
        typed_signature: The signature of the function with type annotations
        required: The required parameters of the function

    Returns:
        A set of the required parameters that are missing annotations, and a set of
        unannotated parameters that have default values
    """
    all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
    missing = all_missing.intersection(set(required))
    unannotated_with_default = all_missing.difference(missing)
    return missing, unannotated_with_default


def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]:
    """Get a JSON schema for a function as defined by the OpenAI API

    Args:
        f: The function to get the JSON schema for
        name: The name of the function
        description: The description of the function

    Returns:
        A JSON schema for the function

    Raises:
        TypeError: If the function is not annotated

    Examples:
        ```
        def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None:
            pass

        get_function_schema(f, description="function f")

        # {'description': 'function f',
        #  'name': 'f',
        #  'parameters': {'type': 'object',
        #                 'properties': {'a': {'type': 'string', 'description': 'Parameter a'},
        #                                'b': {'type': 'integer', 'default': 2, 'description': 'b'},
        #                                'c': {'type': 'number', 'default': 0.1, 'description': 'Parameter c'}},
        #                 'required': ['a']}}
        ```

    """
    typed_signature = get_typed_signature(f)
    required = get_required_params(typed_signature)
    default_values = get_default_values(typed_signature)
    param_annotations = get_param_annotations(typed_signature)
    return_annotation = get_typed_return_annotation(f)
    missing, unannotated_with_default = get_missing_annotations(typed_signature, required)

    if return_annotation is None:
        logger.warning(
            f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
            + "optional, the function should return either a string or a subclass of 'pydantic.BaseModel'."
        )

    if unannotated_with_default != set():
        unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
        logger.warning(
            f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
            + f"{', '.join(unannotated_with_default_s)}."
        )

    if missing != set():
        missing_s = [f"'{k}'" for k in sorted(missing)]
        raise TypeError(
            f"All parameters of the function '{f.__name__}' without default values must be annotated. "
            + f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
        )

    fname = name if name else f.__name__

    parameters = get_parameters(required, param_annotations, default_values=default_values)

    function = Function(
        description=description,
        name=fname,
        parameters=parameters,
    )

    return model_dump(function)


def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type], BaseModel]]:
    """Get a function to load a parameter if it is a Pydantic model

    Args:
        t: The type annotation of the parameter

    Returns:
        A function to load the parameter if it is a Pydantic model, otherwise None

    """
    if get_origin(t) is Annotated:
        return get_load_param_if_needed_function(get_args(t)[0])

    def load_base_model(v: Dict[str, Any], t: Type[BaseModel]) -> BaseModel:
        return t(**v)

    return load_base_model if isinstance(t, type) and issubclass(t, BaseModel) else None


def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
    """A decorator to load the parameters of a function if they are Pydantic models

    Args:
        func: The function with annotated parameters

    Returns:
        A function that loads the parameters before calling the original function

    """
    # get the type annotations of the parameters
    typed_signature = get_typed_signature(func)
    param_annotations = get_param_annotations(typed_signature)

    # get functions for loading BaseModels when needed based on the type annotations
    kwargs_mapping = {k: get_load_param_if_needed_function(t) for k, t in param_annotations.items()}

    # remove the None values
    kwargs_mapping = {k: f for k, f in kwargs_mapping.items() if f is not None}

    # a function that loads the parameters before calling the original function
    @functools.wraps(func)
    def load_parameters_if_needed(*args, **kwargs):
        # load the BaseModels if needed
        for k, f in kwargs_mapping.items():
            kwargs[k] = f(kwargs[k], param_annotations[k])

        # call the original function
        return func(*args, **kwargs)

    return load_parameters_if_needed


def serialize_to_str(x: Any) -> str:
    """Serialize a return value to a string: strings pass through, BaseModels become JSON,
    and everything else goes through json.dumps()."""
    if isinstance(x, str):
        return x
    elif isinstance(x, BaseModel):
        return model_dump_json(x)
    else:
        return json.dumps(x)
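Putting the pieces together, a sketch of the schema produced for an annotated function (the function `exchange` is illustrative; the expected output follows the pattern of the notebook results later in this commit, though key ordering may differ):

```python
from typing_extensions import Annotated

from autogen.function_utils import get_function_schema


def exchange(amount: Annotated[float, "Amount to convert"], quote: str = "EUR") -> str:  # illustrative
    return f"{amount} {quote}"


print(get_function_schema(exchange, description="Currency exchange calculator."))
# {'description': 'Currency exchange calculator.',
#  'name': 'exchange',
#  'parameters': {'type': 'object',
#                 'properties': {'amount': {'type': 'number', 'description': 'Amount to convert'},
#                                'quote': {'type': 'string', 'default': 'EUR', 'description': 'quote'}},
#                 'required': ['amount']}}
```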
@ -6,6 +6,7 @@ from typing import List, Optional, Dict, Callable, Union
 import logging
 import inspect
 from flaml.automl.logger import logger_formatter
+from pydantic import ValidationError

 from autogen.oai.openai_utils import get_key, oai_price1k
 from autogen.token_count_utils import count_token
@ -329,15 +330,27 @@ class OpenAIWrapper:
                     ),
                 )
                 for i in range(len(response_contents)):
-                    response.choices.append(
-                        Choice(
-                            index=i,
-                            finish_reason=finish_reasons[i],
-                            message=ChatCompletionMessage(
-                                role="assistant", content=response_contents[i], function_call=None
-                            ),
-                        )
-                    )
+                    try:
+                        # OpenAI versions 1.5.0 and above
+                        choice = Choice(
+                            index=i,
+                            finish_reason=finish_reasons[i],
+                            message=ChatCompletionMessage(
+                                role="assistant", content=response_contents[i], function_call=None
+                            ),
+                            logprobs=None,
+                        )
+                    except ValidationError:
+                        # OpenAI versions below 1.5.0 (no logprobs field)
+                        choice = Choice(
+                            index=i,
+                            finish_reason=finish_reasons[i],
+                            message=ChatCompletionMessage(
+                                role="assistant", content=response_contents[i], function_call=None
+                            ),
+                        )
+
+                    response.choices.append(choice)
             else:
                 # If streaming is not enabled or using functions, send a regular chat completion request
                 # Functions are not supported, so ensure streaming is disabled
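The fallback works because pydantic models can reject unknown keyword arguments with a `ValidationError`. A standalone sketch of the same compatibility pattern (the `Choice`-like model is illustrative, not the OpenAI SDK class; pydantic v2 syntax assumed):

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class ChoiceV1(BaseModel):  # stand-in for an older Choice model without logprobs
    model_config = ConfigDict(extra="forbid")  # unknown fields raise ValidationError

    index: int
    finish_reason: str


def make_choice(model_cls, i: int) -> BaseModel:
    try:
        # newer SDKs accept (or require) the extra field
        return model_cls(index=i, finish_reason="stop", logprobs=None)
    except ValidationError:
        # older models reject it, so fall back to the smaller constructor
        return model_cls(index=i, finish_reason="stop")


print(make_choice(ChoiceV1, 0))  # falls back: index=0 finish_reason='stop'
```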
File diff suppressed because one or more lines are too long
@ -115,7 +115,7 @@
   },
   {
    "cell_type": "code",
    "execution_count": 3,
    "execution_count": 4,
    "id": "9fb85afb",
    "metadata": {},
    "outputs": [

@ -132,9 +132,7 @@
     "\n",
     "\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
     "Arguments: \n",
     "{\n",
     "  \"num_seconds\": \"5\"\n",
     "}\n",
     "{\"num_seconds\":\"5\"}\n",
     "\u001b[32m******************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",

@ -151,9 +149,7 @@
     "\n",
     "\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
     "Arguments: \n",
     "{\n",
     "  \"num_seconds\": \"5\"\n",
     "}\n",
     "{\"num_seconds\":\"5\"}\n",
     "\u001b[32m**********************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",

@ -178,52 +174,10 @@
     "# define functions according to the function description\n",
     "import time\n",
     "\n",
     "# An example async function\n",
     "async def timer(num_seconds):\n",
     "    for i in range(int(num_seconds)):\n",
     "        time.sleep(1)\n",
     "        # should print to stdout\n",
     "    return \"Timer is done!\"\n",
     "\n",
     "# An example sync function \n",
     "def stopwatch(num_seconds):\n",
     "    for i in range(int(num_seconds)):\n",
     "        time.sleep(1)\n",
     "    return \"Stopwatch is done!\"\n",
     "\n",
     "llm_config = {\n",
     "    \"functions\": [\n",
     "        {\n",
     "            \"name\": \"timer\",\n",
     "            \"description\": \"create a timer for N seconds\",\n",
     "            \"parameters\": {\n",
     "                \"type\": \"object\",\n",
     "                \"properties\": {\n",
     "                    \"num_seconds\": {\n",
     "                        \"type\": \"string\",\n",
     "                        \"description\": \"Number of seconds in the timer.\",\n",
     "                    }\n",
     "                },\n",
     "                \"required\": [\"num_seconds\"],\n",
     "            },\n",
     "        },\n",
     "        {\n",
     "            \"name\": \"stopwatch\",\n",
     "            \"description\": \"create a stopwatch for N seconds\",\n",
     "            \"parameters\": {\n",
     "                \"type\": \"object\",\n",
     "                \"properties\": {\n",
     "                    \"num_seconds\": {\n",
     "                        \"type\": \"string\",\n",
     "                        \"description\": \"Number of seconds in the stopwatch.\",\n",
     "                    }\n",
     "                },\n",
     "                \"required\": [\"num_seconds\"],\n",
     "            },\n",
     "        },\n",
     "    ],\n",
     "    \"config_list\": config_list,\n",
     "}\n",
     "\n",
     "coder = autogen.AssistantAgent(\n",
     "    name=\"chatbot\",\n",
     "    system_message=\"For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.\",\n",

@ -240,21 +194,35 @@
     "    code_execution_config={\"work_dir\": \"coding\"},\n",
     ")\n",
     "\n",
     "# register the functions\n",
     "user_proxy.register_function(\n",
     "    function_map={\n",
     "        \"timer\": timer,\n",
     "        \"stopwatch\": stopwatch,\n",
     "    }\n",
     ")\n",
     "from typing_extensions import Annotated\n",
     "\n",
     "# An example async function\n",
     "@user_proxy.register_for_execution()\n",
     "@coder.register_for_llm(description=\"create a timer for N seconds\")\n",
     "async def timer(num_seconds: Annotated[str, \"Number of seconds in the timer.\"]) -> str:\n",
     "    for i in range(int(num_seconds)):\n",
     "        time.sleep(1)\n",
     "        # should print to stdout\n",
     "    return \"Timer is done!\"\n",
     "\n",
     "\n",
     "# An example sync function\n",
     "@user_proxy.register_for_execution()\n",
     "@coder.register_for_llm(description=\"create a stopwatch for N seconds\")\n",
     "def stopwatch(num_seconds: Annotated[str, \"Number of seconds in the stopwatch.\"]) -> str:\n",
     "    for i in range(int(num_seconds)):\n",
     "        time.sleep(1)\n",
     "    return \"Stopwatch is done!\"\n",
     "\n",
     "\n",
     "# start the conversation\n",
     "# 'await' is used to pause and resume code execution for async IO operations.\n",
     "# Without 'await', an async function returns a coroutine object but doesn't execute the function.\n",
     "# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.\n",
     "await user_proxy.a_initiate_chat(\n",
     "    coder,\n",
     "    message=\"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\",\n",
     ")\n"
     ")"
    ]
   },
   {

@ -268,62 +236,36 @@
   },
   {
    "cell_type": "code",
    "execution_count": 4,
    "execution_count": 5,
    "id": "2472f95c",
    "metadata": {},
    "outputs": [],
    "source": [
     "\n",
     "\n",
     "# Add a function for robust group chat termination\n",
     "def terminate_group_chat(message):\n",
     "    return f\"[GROUPCHAT_TERMINATE] {message}\"\n",
     "\n",
     "# update LLM config\n",
     "llm_config[\"functions\"].append(\n",
     "    {\n",
     "        \"name\": \"terminate_group_chat\",\n",
     "        \"description\": \"terminate the group chat\",\n",
     "        \"parameters\": {\n",
     "            \"type\": \"object\",\n",
     "            \"properties\": {\n",
     "                \"message\": {\n",
     "                    \"type\": \"string\",\n",
     "                    \"description\": \"Message to be sent to the group chat.\",\n",
     "                }\n",
     "            },\n",
     "            \"required\": [\"message\"],\n",
     "        },\n",
     "    }\n",
     ")\n",
     "\n",
     "# redefine the coder agent so that it uses the new llm_config\n",
     "coder = autogen.AssistantAgent(\n",
     "    name=\"chatbot\",\n",
     "    system_message=\"For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.\",\n",
     "    llm_config=llm_config,\n",
     ")\n",
     "\n",
     "# register the new function with user proxy agent\n",
     "user_proxy.register_function(\n",
     "    function_map={\n",
     "        \"terminate_group_chat\": terminate_group_chat,\n",
     "    }\n",
     ")\n",
     "markdownagent = autogen.AssistantAgent(\n",
     "    name=\"Markdown_agent\",\n",
     "    system_message=\"Respond in markdown only\",\n",
     "    llm_config=llm_config,\n",
     ")\n",
     "\n",
     "# Add a function for robust group chat termination\n",
     "@user_proxy.register_for_execution()\n",
     "@markdownagent.register_for_llm()\n",
     "@coder.register_for_llm(description=\"terminate the group chat\")\n",
     "def terminate_group_chat(message: Annotated[str, \"Message to be sent to the group chat.\"]) -> str:\n",
     "    return f\"[GROUPCHAT_TERMINATE] {message}\"\n",
     "\n",
     "\n",
     "groupchat = autogen.GroupChat(agents=[user_proxy, coder, markdownagent], messages=[], max_round=12)\n",
     "manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config,\n",
     "    is_termination_msg=lambda x: \"GROUPCHAT_TERMINATE\" in x.get(\"content\", \"\"),\n",
     "    )"
     "manager = autogen.GroupChatManager(\n",
     "    groupchat=groupchat,\n",
     "    llm_config=llm_config,\n",
     "    is_termination_msg=lambda x: \"GROUPCHAT_TERMINATE\" in x.get(\"content\", \"\"),\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 5,
    "execution_count": 6,
    "id": "e2c9267a",
    "metadata": {},
    "outputs": [

@ -340,25 +282,21 @@
     "4) when 1-3 are done, terminate the group chat\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
     "\n",
     "\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
     "\n",
     "\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
     "Arguments: \n",
     "\n",
     "{\n",
     "  \"num_seconds\": \"5\"\n",
     "}\n",
     "{\"num_seconds\":\"5\"}\n",
     "\u001b[32m******************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[35m\n",
     ">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
     ">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
     "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
     "\n",
     "\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",

@ -366,14 +304,16 @@
     "\u001b[32m**************************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
     "\n",
     "\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
     "\n",
     "\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
     "Arguments: \n",
     "\n",
     "{\n",
     "  \"num_seconds\": \"5\"\n",
     "}\n",
     "{\"num_seconds\":\"5\"}\n",
     "\u001b[32m**********************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",

@ -388,19 +328,18 @@
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
     "\n",
     "```markdown\n",
     "# Results \n",
     "The results are as follows:\n",
     "\n",
     "1. Timer: The timer for 5 seconds has completed.\n",
     "2. Stopwatch: The stopwatch for 5 seconds has completed.\n",
     "```\n",
     "By the way, step 3 is done now. Moving on to step 4.\n",
     "- Timer: Completed after `5 seconds`.\n",
     "- Stopwatch: Recorded time of `5 seconds`.\n",
     "\n",
     "**Timer and Stopwatch Summary:**\n",
     "Both the timer and stopwatch were set for `5 seconds` and have now concluded successfully. \n",
     "\n",
     "Now, let's proceed to terminate the group chat as requested.\n",
     "\u001b[32m***** Suggested function Call: terminate_group_chat *****\u001b[0m\n",
     "Arguments: \n",
     "\n",
     "{\n",
     "  \"message\": \"The tasks have been completed. Terminating the group chat now.\"\n",
     "}\n",
     "{\"message\":\"All tasks have been completed. The group chat will now be terminated. Goodbye!\"}\n",
     "\u001b[32m*********************************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",

@ -409,7 +348,7 @@
     "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
     "\n",
     "\u001b[32m***** Response from calling function \"terminate_group_chat\" *****\u001b[0m\n",
     "[GROUPCHAT_TERMINATE] The tasks have been completed. Terminating the group chat now.\n",
     "[GROUPCHAT_TERMINATE] All tasks have been completed. The group chat will now be terminated. Goodbye!\n",
     "\u001b[32m*****************************************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n"

@ -417,13 +356,23 @@
    }
   ],
   "source": [
    "await user_proxy.a_initiate_chat(manager,\n",
    "    message=\"\"\"\n",
    "await user_proxy.a_initiate_chat(\n",
    "    manager,\n",
    "    message=\"\"\"\n",
    "1) Create a timer for 5 seconds.\n",
    "2) a stopwatch for 5 seconds.\n",
    "3) Pretty print the result as md.\n",
    "4) when 1-3 are done, terminate the group chat\"\"\")\n"
    "4) when 1-3 are done, terminate the group chat\"\"\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d074e51",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {

@ -442,7 +391,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
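The notebook's key change condensed into a standalone sketch of registering an async tool (it uses `asyncio.sleep` instead of the notebook's `time.sleep`, purely to keep the event loop free; the `coder` and `user_proxy` agents are assumed to be constructed as in the notebook):

```python
import asyncio

from typing_extensions import Annotated


@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a timer for N seconds")
async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
    await asyncio.sleep(int(num_seconds))  # non-blocking, unlike time.sleep
    return "Timer is done!"


# async tools require the async entry point:
# await user_proxy.a_initiate_chat(coder, message="Create a timer for 5 seconds.")
```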
@ -0,0 +1,551 @@ (new file: notebook/agentchat_function_call_currency_calculator.ipynb)
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ae1f50ec",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/microsoft/autogen/blob/main/notebook/agentchat_function_call.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9a71fa36",
   "metadata": {},
   "source": [
    "# Auto Generated Agent Chat: Task Solving with Provided Tools as Functions\n",
    "\n",
    "AutoGen offers conversable agents powered by LLM, tool, or human, which can be used to perform tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation. Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n",
    "\n",
    "In this notebook, we demonstrate how to use `AssistantAgent` and `UserProxyAgent` to make function calls with the new feature of OpenAI models (in model version 0613). A specified prompt and function configs must be passed to `AssistantAgent` to initialize the agent. The corresponding functions must be passed to `UserProxyAgent`, which will execute any function calls made by `AssistantAgent`. Besides this requirement of matching descriptions with functions, we recommend checking the system message in the `AssistantAgent` to ensure the instructions align with the function call descriptions.\n",
    "\n",
    "## Requirements\n",
    "\n",
    "AutoGen requires `Python>=3.8`. To run this notebook example, please install `pyautogen`:\n",
    "```bash\n",
    "pip install pyautogen\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2b803c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install \"pyautogen~=0.2.2\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5ebd2397",
   "metadata": {},
   "source": [
    "## Set your API Endpoint\n",
    "\n",
    "The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dca301a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import autogen\n",
    "\n",
    "config_list = autogen.config_list_from_json(\n",
    "    \"OAI_CONFIG_LIST\",\n",
    "    filter_dict={\n",
    "        \"model\": [\"gpt-4\", \"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"],\n",
    "    },\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "92fde41f",
   "metadata": {},
   "source": [
    "It first looks for environment variable \"OAI_CONFIG_LIST\" which needs to be a valid json string. If that variable is not found, it then looks for a json file named \"OAI_CONFIG_LIST\". It filters the configs by models (you can filter by other keys as well). Only the models with matching names are kept in the list based on the filter condition.\n",
    "\n",
    "The config list looks like the following:\n",
    "```python\n",
    "config_list = [\n",
    "    {\n",
    "        'model': 'gpt-4',\n",
    "        'api_key': '<your OpenAI API key here>',\n",
    "    },\n",
    "    {\n",
    "        'model': 'gpt-3.5-turbo',\n",
    "        'api_key': '<your Azure OpenAI API key here>',\n",
    "        'base_url': '<your Azure OpenAI API base here>',\n",
    "        'api_type': 'azure',\n",
    "        'api_version': '2023-08-01-preview',\n",
    "    },\n",
    "    {\n",
    "        'model': 'gpt-3.5-turbo-16k',\n",
    "        'api_key': '<your Azure OpenAI API key here>',\n",
    "        'base_url': '<your Azure OpenAI API base here>',\n",
    "        'api_type': 'azure',\n",
    "        'api_version': '2023-08-01-preview',\n",
    "    },\n",
    "]\n",
    "```\n",
    "\n",
    "You can set the value of config_list in any way you prefer. Please refer to this [notebook](https://github.com/microsoft/autogen/blob/main/notebook/oai_openai_utils.ipynb) for full code examples of the different methods."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2b9526e7",
   "metadata": {},
   "source": [
    "## Making Function Calls\n",
    "\n",
    "In this example, we demonstrate function call execution with `AssistantAgent` and `UserProxyAgent`. With the default system prompt of `AssistantAgent`, we allow the LLM assistant to perform tasks with code, and the `UserProxyAgent` would extract code blocks from the LLM response and execute them. With the new \"function_call\" feature, we define functions and specify the description of the function in the OpenAI config for the `AssistantAgent`. Then we register the functions in `UserProxyAgent`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9fb85afb",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_config = {\n",
    "    \"config_list\": config_list,\n",
    "    \"timeout\": 120,\n",
    "}\n",
    "\n",
    "chatbot = autogen.AssistantAgent(\n",
    "    name=\"chatbot\",\n",
    "    system_message=\"For currency exchange tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.\",\n",
    "    llm_config=llm_config,\n",
    ")\n",
    "\n",
    "# create a UserProxyAgent instance named \"user_proxy\"\n",
    "user_proxy = autogen.UserProxyAgent(\n",
    "    name=\"user_proxy\",\n",
    "    is_termination_msg=lambda x: x.get(\"content\", \"\") and x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
    "    human_input_mode=\"NEVER\",\n",
    "    max_consecutive_auto_reply=10,\n",
    ")\n",
    "\n",
    "from typing import Literal\n",
    "\n",
    "from typing_extensions import Annotated\n",
    "\n",
    "CurrencySymbol = Literal[\"USD\", \"EUR\"]\n",
    "\n",
    "\n",
    "def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n",
    "    if base_currency == quote_currency:\n",
    "        return 1.0\n",
    "    elif base_currency == \"USD\" and quote_currency == \"EUR\":\n",
    "        return 1 / 1.1\n",
    "    elif base_currency == \"EUR\" and quote_currency == \"USD\":\n",
    "        return 1.1\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n",
    "\n",
    "\n",
    "@user_proxy.register_for_execution()\n",
    "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
    "def currency_calculator(\n",
    "    base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n",
    "    base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n",
    "    quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n",
    ") -> str:\n",
    "    quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n",
    "    return f\"{quote_amount} {quote_currency}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39464dc3",
   "metadata": {},
   "source": [
    "The decorator `@chatbot.register_for_llm()` reads the annotated signature of the function `currency_calculator` and generates the following JSON schema used by OpenAI API to suggest calling the function. We can check the JSON schema generated as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3e52bbfe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'description': 'Currency exchange calculator.',\n",
       "  'name': 'currency_calculator',\n",
       "  'parameters': {'type': 'object',\n",
       "   'properties': {'base_amount': {'type': 'number',\n",
       "     'description': 'Amount of currency in base_currency'},\n",
       "    'base_currency': {'enum': ['USD', 'EUR'],\n",
       "     'type': 'string',\n",
       "     'default': 'USD',\n",
       "     'description': 'Base currency'},\n",
       "    'quote_currency': {'enum': ['USD', 'EUR'],\n",
       "     'type': 'string',\n",
       "     'default': 'EUR',\n",
       "     'description': 'Quote currency'}},\n",
       "   'required': ['base_amount']}}]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chatbot.llm_config[\"functions\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "662bd12a",
   "metadata": {},
   "source": [
    "The decorator `@user_proxy.register_for_execution()` maps the name of the function to be proposed by OpenAI API to the actual implementation. The mapped function is wrapped, since we also automatically handle serialization of the output of the function as follows:\n",
    "\n",
    "- strings are untouched, and\n",
    "\n",
    "- objects of the Pydantic BaseModel type are serialized to JSON.\n",
    "\n",
    "We can check the correctness of the function map by using the `._origin` property of the wrapped function as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bd943369",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert user_proxy.function_map[\"currency_calculator\"]._origin == currency_calculator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a3a09c9",
   "metadata": {},
   "source": [
    "Finally, we can use this function to accurately calculate exchange amounts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d5518947",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "How much is 123.45 USD in EUR?\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
      "Arguments: \n",
      "{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
      "\u001b[32m********************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[35m\n",
      ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "\u001b[32m***** Response from calling function \"currency_calculator\" *****\u001b[0m\n",
      "112.22727272727272 EUR\n",
      "\u001b[32m****************************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "123.45 USD is equivalent to approximately 112.23 EUR.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "TERMINATE\n",
      "\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# start the conversation\n",
    "user_proxy.initiate_chat(\n",
    "    chatbot,\n",
    "    message=\"How much is 123.45 USD in EUR?\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd9d61cf",
   "metadata": {},
   "source": [
    "### Pydantic models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d79fec0",
   "metadata": {},
   "source": [
    "We can also use Pydantic Base models to rewrite the function as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7b3d8b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_config = {\n",
    "    \"config_list\": config_list,\n",
    "    \"timeout\": 120,\n",
    "}\n",
    "\n",
    "chatbot = autogen.AssistantAgent(\n",
    "    name=\"chatbot\",\n",
    "    system_message=\"For currency exchange tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.\",\n",
    "    llm_config=llm_config,\n",
    ")\n",
    "\n",
    "# create a UserProxyAgent instance named \"user_proxy\"\n",
    "user_proxy = autogen.UserProxyAgent(\n",
    "    name=\"user_proxy\",\n",
    "    is_termination_msg=lambda x: x.get(\"content\", \"\") and x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
    "    human_input_mode=\"NEVER\",\n",
    "    max_consecutive_auto_reply=10,\n",
    ")\n",
    "\n",
    "from typing import Literal\n",
    "\n",
    "from pydantic import BaseModel, Field\n",
    "from typing_extensions import Annotated\n",
    "\n",
    "class Currency(BaseModel):\n",
    "    currency: Annotated[CurrencySymbol, Field(..., description=\"Currency symbol\")]\n",
    "    amount: Annotated[float, Field(0, description=\"Amount of currency\", ge=0)]\n",
    "\n",
    "@user_proxy.register_for_execution()\n",
    "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
    "def currency_calculator(\n",
    "    base: Annotated[Currency, \"Base currency: amount and currency symbol\"],\n",
    "    quote_currency: Annotated[CurrencySymbol, \"Quote currency symbol\"] = \"USD\",\n",
    ") -> Currency:\n",
    "    quote_amount = exchange_rate(base.currency, quote_currency) * base.amount\n",
    "    return Currency(amount=quote_amount, currency=quote_currency)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "971ed0d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'description': 'Currency exchange calculator.',\n",
       "  'name': 'currency_calculator',\n",
       "  'parameters': {'type': 'object',\n",
       "   'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',\n",
       "      'enum': ['USD', 'EUR'],\n",
       "      'title': 'Currency',\n",
       "      'type': 'string'},\n",
       "     'amount': {'default': 0,\n",
       "      'description': 'Amount of currency',\n",
       "      'minimum': 0.0,\n",
       "      'title': 'Amount',\n",
       "      'type': 'number'}},\n",
       "    'required': ['currency'],\n",
       "    'title': 'Currency',\n",
       "    'type': 'object',\n",
       "    'description': 'Base currency: amount and currency symbol'},\n",
       "   'quote_currency': {'enum': ['USD', 'EUR'],\n",
       "     'type': 'string',\n",
       "     'default': 'USD',\n",
       "     'description': 'Quote currency symbol'}},\n",
       "  'required': ['base']}}]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chatbot.llm_config[\"functions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ab081090",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "How much is 112.23 Euros in US Dollars?\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
      "Arguments: \n",
      "{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
      "\u001b[32m********************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[35m\n",
      ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "\u001b[32m***** Response from calling function \"currency_calculator\" *****\u001b[0m\n",
      "{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
      "\u001b[32m****************************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "TERMINATE\n",
      "\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# start the conversation\n",
    "user_proxy.initiate_chat(\n",
    "    chatbot,\n",
    "    message=\"How much is 112.23 Euros in US Dollars?\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0064d9cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "How much is 123.45 US Dollars in Euros?\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
      "Arguments: \n",
      "{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
      "\u001b[32m********************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[35m\n",
      ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "\u001b[32m***** Response from calling function \"currency_calculator\" *****\u001b[0m\n",
      "{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
      "\u001b[32m****************************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "123.45 US Dollars is approximately 112.23 Euros.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
      "\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
      "\n",
      "TERMINATE\n",
      "\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# start the conversation\n",
    "user_proxy.initiate_chat(\n",
    "    chatbot,\n",
    "    message=\"How much is 123.45 US Dollars in Euros?\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06137f23",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "flaml_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
|
@ -59,7 +59,7 @@
"config_list = autogen.config_list_from_json(\n",
"    \"OAI_CONFIG_LIST\",\n",
"    filter_dict={\n",
"        \"model\": [\"gpt-3.5-turbo\"],\n",
"        \"model\": [\"gpt-3.5-turbo\", \"gpt-35-turbo\"],\n",
"    },\n",
")"
]
setup.py
@ -14,12 +14,13 @@ with open(os.path.join(here, "autogen/version.py")) as fp:
__version__ = version["__version__"]

install_requires = [
    "openai~=1.3",
    "openai>=1,<1.5",  # a temporary fix for breaking changes in 1.5
    "diskcache",
    "termcolor",
    "flaml",
    "python-dotenv",
    "tiktoken",
    "pydantic>=1.10,<3",  # could be both V1 and V2
]

setuptools.setup(
@ -68,6 +68,7 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
        filter_dict={
            "model": {
                "gpt-3.5-turbo",
                "gpt-35-turbo",
                "gpt-3.5-turbo-16k",
                "gpt-3.5-turbo-16k-0613",
                "gpt-3.5-turbo-0301",
@ -1,5 +1,11 @@
import copy
from typing import Any, Callable, Dict, Literal

import pytest
from autogen.agentchat import ConversableAgent
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from autogen.agentchat import ConversableAgent, UserProxyAgent


@pytest.fixture
@ -331,6 +337,278 @@ async def test_a_generate_reply_raises_on_messages_and_sender_none(conversable_a
        await conversable_agent.a_generate_reply(messages=None, sender=None)


def test_update_function_signature_and_register_functions() -> None:
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv("OPENAI_API_KEY", "mock")
        agent = ConversableAgent(name="agent", llm_config={})

        def exec_python(cell: str) -> None:
            pass

        def exec_sh(script: str) -> None:
            pass

        agent.update_function_signature(
            {
                "name": "python",
                "description": "run cell in ipython and return the execution result.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "cell": {
                            "type": "string",
                            "description": "Valid Python cell to execute.",
                        }
                    },
                    "required": ["cell"],
                },
            },
            is_remove=False,
        )

        functions = agent.llm_config["functions"]
        assert {f["name"] for f in functions} == {"python"}

        agent.update_function_signature(
            {
                "name": "sh",
                "description": "run a shell script and return the execution result.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "script": {
                            "type": "string",
                            "description": "Valid shell script to execute.",
                        }
                    },
                    "required": ["script"],
                },
            },
            is_remove=False,
        )

        functions = agent.llm_config["functions"]
        assert {f["name"] for f in functions} == {"python", "sh"}

        # register the functions
        agent.register_function(
            function_map={
                "python": exec_python,
                "sh": exec_sh,
            }
        )
        assert set(agent.function_map.keys()) == {"python", "sh"}
        assert agent.function_map["python"] == exec_python
        assert agent.function_map["sh"] == exec_sh


def test__wrap_function_sync():
    CurrencySymbol = Literal["USD", "EUR"]

    class Currency(BaseModel):
        currency: Annotated[CurrencySymbol, Field(..., description="Currency code")]
        amount: Annotated[float, Field(100.0, description="Amount of money in the currency")]

    Currency(currency="USD", amount=100.0)

    def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:
        if base_currency == quote_currency:
            return 1.0
        elif base_currency == "USD" and quote_currency == "EUR":
            return 1 / 1.1
        elif base_currency == "EUR" and quote_currency == "USD":
            return 1.1
        else:
            raise ValueError(f"Unknown currencies {base_currency}, {quote_currency}")

    agent = ConversableAgent(name="agent", llm_config={})

    @agent._wrap_function
    def currency_calculator(
        base: Annotated[Currency, "Base currency"],
        quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
    ) -> Currency:
        quote_amount = exchange_rate(base.currency, quote_currency) * base.amount
        return Currency(amount=quote_amount, currency=quote_currency)

    assert (
        currency_calculator(base={"currency": "USD", "amount": 110.11}, quote_currency="EUR")
        == '{"currency":"EUR","amount":100.1}'
    )


@pytest.mark.asyncio
async def test__wrap_function_async():
    CurrencySymbol = Literal["USD", "EUR"]

    class Currency(BaseModel):
        currency: Annotated[CurrencySymbol, Field(..., description="Currency code")]
        amount: Annotated[float, Field(100.0, description="Amount of money in the currency")]

    Currency(currency="USD", amount=100.0)

    def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:
        if base_currency == quote_currency:
            return 1.0
        elif base_currency == "USD" and quote_currency == "EUR":
            return 1 / 1.1
        elif base_currency == "EUR" and quote_currency == "USD":
            return 1.1
        else:
            raise ValueError(f"Unknown currencies {base_currency}, {quote_currency}")

    agent = ConversableAgent(name="agent", llm_config={})

    @agent._wrap_function
    async def currency_calculator(
        base: Annotated[Currency, "Base currency"],
        quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
    ) -> Currency:
        quote_amount = exchange_rate(base.currency, quote_currency) * base.amount
        return Currency(amount=quote_amount, currency=quote_currency)

    assert (
        await currency_calculator(base={"currency": "USD", "amount": 110.11}, quote_currency="EUR")
        == '{"currency":"EUR","amount":100.1}'
    )


def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
    return {k: v._origin for k, v in d.items()}


def test_register_for_llm():
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv("OPENAI_API_KEY", "mock")
        agent3 = ConversableAgent(name="agent3", llm_config={})
        agent2 = ConversableAgent(name="agent2", llm_config={})
        agent1 = ConversableAgent(name="agent1", llm_config={})

        @agent3.register_for_llm()
        @agent2.register_for_llm(name="python")
        @agent1.register_for_llm(description="run cell in ipython and return the execution result.")
        def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
            pass

        expected1 = [
            {
                "description": "run cell in ipython and return the execution result.",
                "name": "exec_python",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "cell": {
                            "type": "string",
                            "description": "Valid Python cell to execute.",
                        }
                    },
                    "required": ["cell"],
                },
            }
        ]
        expected2 = copy.deepcopy(expected1)
        expected2[0]["name"] = "python"
        expected3 = expected2

        assert agent1.llm_config["functions"] == expected1
        assert agent2.llm_config["functions"] == expected2
        assert agent3.llm_config["functions"] == expected3

        @agent3.register_for_llm()
        @agent2.register_for_llm()
        @agent1.register_for_llm(name="sh", description="run a shell script and return the execution result.")
        async def exec_sh(script: Annotated[str, "Valid shell script to execute."]) -> str:
            pass

        expected1 = expected1 + [
            {
                "name": "sh",
                "description": "run a shell script and return the execution result.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "script": {
                            "type": "string",
                            "description": "Valid shell script to execute.",
                        }
                    },
                    "required": ["script"],
                },
            }
        ]
        expected2 = expected2 + [expected1[1]]
        expected3 = expected3 + [expected1[1]]

        assert agent1.llm_config["functions"] == expected1
        assert agent2.llm_config["functions"] == expected2
        assert agent3.llm_config["functions"] == expected3


def test_register_for_llm_without_description():
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv("OPENAI_API_KEY", "mock")
        agent = ConversableAgent(name="agent", llm_config={})

        with pytest.raises(ValueError) as e:

            @agent.register_for_llm()
            def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
                pass

        assert e.value.args[0] == "Function description is required, none found."


def test_register_for_llm_without_LLM():
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv("OPENAI_API_KEY", "mock")
        agent = ConversableAgent(name="agent", llm_config=None)
        agent.llm_config = None
        assert agent.llm_config is None

        with pytest.raises(RuntimeError) as e:

            @agent.register_for_llm(description="run cell in ipython and return the execution result.")
            def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
                pass

        assert e.value.args[0] == "LLM config must be setup before registering a function for LLM."


def test_register_for_execution():
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv("OPENAI_API_KEY", "mock")
        agent = ConversableAgent(name="agent", llm_config={})
        user_proxy_1 = UserProxyAgent(name="user_proxy_1")
        user_proxy_2 = UserProxyAgent(name="user_proxy_2")

        @user_proxy_2.register_for_execution(name="python")
        @agent.register_for_execution()
        @agent.register_for_llm(description="run cell in ipython and return the execution result.")
        @user_proxy_1.register_for_execution()
        def exec_python(cell: Annotated[str, "Valid Python cell to execute."]):
            pass

        expected_function_map_1 = {"exec_python": exec_python}
        assert get_origin(agent.function_map) == expected_function_map_1
        assert get_origin(user_proxy_1.function_map) == expected_function_map_1

        expected_function_map_2 = {"python": exec_python}
        assert get_origin(user_proxy_2.function_map) == expected_function_map_2

        @agent.register_for_execution()
        @agent.register_for_llm(description="run a shell script and return the execution result.")
        @user_proxy_1.register_for_execution(name="sh")
        async def exec_sh(script: Annotated[str, "Valid shell script to execute."]):
            pass

        expected_function_map = {
            "exec_python": exec_python,
            "sh": exec_sh,
        }
        assert get_origin(agent.function_map) == expected_function_map
        assert get_origin(user_proxy_1.function_map) == expected_function_map


if __name__ == "__main__":
    # test_trigger()
    # test_context()
@ -21,7 +21,7 @@ def test_aoai_chat_completion():
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    client = OpenAIWrapper(config_list=config_list)
    # for config in config_list:

@ -38,7 +38,7 @@ def test_oai_tool_calling_extraction():
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(
@ -15,7 +15,7 @@ def test_aoai_chat_completion_stream():
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)

@ -28,7 +28,7 @@ def test_chat_completion_stream():
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"model": ["gpt-3.5-turbo"]},
        filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)

@ -41,7 +41,7 @@ def test_chat_functions_stream():
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"model": ["gpt-3.5-turbo"]},
        filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    functions = [
        {
@ -0,0 +1,375 @@
import inspect
import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple

import pytest
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from autogen._pydantic import PYDANTIC_V1, model_dump
from autogen.function_utils import (
    get_default_values,
    get_function_schema,
    get_load_param_if_needed_function,
    get_missing_annotations,
    get_param_annotations,
    get_parameter_json_schema,
    get_parameters,
    get_required_params,
    get_typed_annotation,
    get_typed_return_annotation,
    get_typed_signature,
    load_basemodels_if_needed,
    serialize_to_str,
)


def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1, *, d):
    pass


def g(
    a: Annotated[str, "Parameter a"],
    b: int = 2,
    c: Annotated[float, "Parameter c"] = 0.1,
    *,
    d: Dict[str, Tuple[Optional[int], List[float]]],
) -> str:
    pass


async def a_g(
    a: Annotated[str, "Parameter a"],
    b: int = 2,
    c: Annotated[float, "Parameter c"] = 0.1,
    *,
    d: Dict[str, Tuple[Optional[int], List[float]]],
) -> str:
    pass


def test_get_typed_annotation() -> None:
    globalns = getattr(f, "__globals__", {})
    assert get_typed_annotation(str, globalns) == str
    assert get_typed_annotation("float", globalns) == float


def test_get_typed_signature() -> None:
    assert get_typed_signature(f).parameters == inspect.signature(f).parameters
    assert get_typed_signature(g).parameters == inspect.signature(g).parameters


def test_get_typed_return_annotation() -> None:
    assert get_typed_return_annotation(f) is None
    assert get_typed_return_annotation(g) == str


def test_get_parameter_json_schema() -> None:
    assert get_parameter_json_schema("c", str, {}) == {"type": "string", "description": "c"}
    assert get_parameter_json_schema("c", str, {"c": "ccc"}) == {"type": "string", "description": "c", "default": "ccc"}

    assert get_parameter_json_schema("a", Annotated[str, "parameter a"], {}) == {
        "type": "string",
        "description": "parameter a",
    }
    assert get_parameter_json_schema("a", Annotated[str, "parameter a"], {"a": "3.14"}) == {
        "type": "string",
        "description": "parameter a",
        "default": "3.14",
    }

    class B(BaseModel):
        b: float
        c: str

    expected = {
        "description": "b",
        "properties": {"b": {"title": "B", "type": "number"}, "c": {"title": "C", "type": "string"}},
        "required": ["b", "c"],
        "title": "B",
        "type": "object",
    }
    assert get_parameter_json_schema("b", B, {}) == expected

    expected["default"] = B(b=1.2, c="3.4")
    assert get_parameter_json_schema("b", B, {"b": B(b=1.2, c="3.4")}) == expected


def test_get_required_params() -> None:
    assert get_required_params(inspect.signature(f)) == ["a", "d"]
    assert get_required_params(inspect.signature(g)) == ["a", "d"]


def test_get_default_values() -> None:
    assert get_default_values(inspect.signature(f)) == {"b": 2, "c": 0.1}
    assert get_default_values(inspect.signature(g)) == {"b": 2, "c": 0.1}


def test_get_param_annotations() -> None:
    def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0):
        pass

    expected = {"a": Annotated[str, "Parameter a"], "c": Annotated[float, "Parameter c"]}

    typed_signature = get_typed_signature(f)
    param_annotations = get_param_annotations(typed_signature)

    assert param_annotations == expected, param_annotations


def test_get_missing_annotations() -> None:
    def _f1(a: str, b=2):
        pass

    missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f1), ["a"])
    assert missing == set()
    assert unannotated_with_default == {"b"}

    def _f2(a: str, b) -> str:
        "ok"

    missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f2), ["a", "b"])
    assert missing == {"b"}
    assert unannotated_with_default == set()

    def _f3() -> None:
        pass

    missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f3), [])
    assert missing == set()
    assert unannotated_with_default == set()


def test_get_parameters() -> None:
    def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0):
        pass

    typed_signature = get_typed_signature(f)
    param_annotations = get_param_annotations(typed_signature)
    required = get_required_params(typed_signature)
    default_values = get_default_values(typed_signature)

    expected = {
        "type": "object",
        "properties": {
            "a": {"type": "string", "description": "Parameter a"},
            "c": {"type": "number", "description": "Parameter c", "default": 1.0},
        },
        "required": ["a"],
    }

    actual = model_dump(get_parameters(required, param_annotations, default_values))

    assert actual == expected, actual


def test_get_function_schema_no_return_type() -> None:
    def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1):
        pass

    expected = (
        "The return type of the function 'f' is not annotated. Although annotating it is "
        + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
    )

    with unittest.mock.patch("autogen.function_utils.logger.warning") as mock_logger_warning:
        get_function_schema(f, description="function g")

        mock_logger_warning.assert_called_once_with(expected)


def test_get_function_schema_unannotated_with_default() -> None:
    with unittest.mock.patch("autogen.function_utils.logger.warning") as mock_logger_warning:

        def f(
            a: Annotated[str, "Parameter a"], b=2, c: Annotated[float, "Parameter c"] = 0.1, d="whatever", e=None
        ) -> str:
            return "ok"

        get_function_schema(f, description="function f")

        mock_logger_warning.assert_called_once_with(
            "The following parameters of the function 'f' with default values are not annotated: 'b', 'd', 'e'."
        )


def test_get_function_schema_missing() -> None:
    def f(a: Annotated[str, "Parameter a"], b, c: Annotated[float, "Parameter c"] = 0.1) -> float:
        pass

    expected = (
        "All parameters of the function 'f' without default values must be annotated. "
        + "The annotations are missing for the following parameters: 'b'"
    )

    with pytest.raises(TypeError) as e:
        get_function_schema(f, description="function f")

    assert str(e.value) == expected, e.value


def test_get_function_schema() -> None:
    expected_v2 = {
        "description": "function g",
        "name": "fancy name for g",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "string", "description": "Parameter a"},
                "b": {"type": "integer", "description": "b", "default": 2},
                "c": {"type": "number", "description": "Parameter c", "default": 0.1},
                "d": {
                    "additionalProperties": {
                        "maxItems": 2,
                        "minItems": 2,
                        "prefixItems": [
                            {"anyOf": [{"type": "integer"}, {"type": "null"}]},
                            {"items": {"type": "number"}, "type": "array"},
                        ],
                        "type": "array",
                    },
                    "type": "object",
                    "description": "d",
                },
            },
            "required": ["a", "d"],
        },
    }

    # the difference is that the v1 version does not handle Union types (Optional is Union[T, None])
    expected_v1 = {
        "description": "function g",
        "name": "fancy name for g",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "string", "description": "Parameter a"},
                "b": {"type": "integer", "description": "b", "default": 2},
                "c": {"type": "number", "description": "Parameter c", "default": 0.1},
                "d": {
                    "type": "object",
                    "additionalProperties": {
                        "type": "array",
                        "minItems": 2,
                        "maxItems": 2,
                        "items": [{"type": "integer"}, {"type": "array", "items": {"type": "number"}}],
                    },
                    "description": "d",
                },
            },
            "required": ["a", "d"],
        },
    }

    actual = get_function_schema(g, description="function g", name="fancy name for g")

    if PYDANTIC_V1:
        assert actual == expected_v1, actual
    else:
        assert actual == expected_v2, actual

    actual = get_function_schema(a_g, description="function g", name="fancy name for g")
    if PYDANTIC_V1:
        assert actual == expected_v1, actual
    else:
        assert actual == expected_v2, actual


CurrencySymbol = Literal["USD", "EUR"]


class Currency(BaseModel):
    currency: Annotated[CurrencySymbol, Field(..., description="Currency code")]
    amount: Annotated[float, Field(100.0, description="Amount of money in the currency")]


def test_get_function_schema_pydantic() -> None:
    def currency_calculator(
        base: Annotated[Currency, "Base currency: amount and currency symbol"],
        quote_currency: Annotated[CurrencySymbol, "Quote currency symbol (default: 'EUR')"] = "EUR",
    ) -> Currency:
        pass

    expected = {
        "description": "Currency exchange calculator.",
        "name": "currency_calculator",
        "parameters": {
            "type": "object",
            "properties": {
                "base": {
                    "properties": {
                        "currency": {
                            "description": "Currency code",
                            "enum": ["USD", "EUR"],
                            "title": "Currency",
                            "type": "string",
                        },
                        "amount": {
                            "default": 100.0,
                            "description": "Amount of money in the currency",
                            "title": "Amount",
                            "type": "number",
                        },
                    },
                    "required": ["currency"],
                    "title": "Currency",
                    "type": "object",
                    "description": "Base currency: amount and currency symbol",
                },
                "quote_currency": {
                    "enum": ["USD", "EUR"],
                    "type": "string",
                    "default": "EUR",
                    "description": "Quote currency symbol (default: 'EUR')",
                },
            },
            "required": ["base"],
        },
    }

    actual = get_function_schema(
        currency_calculator, description="Currency exchange calculator.", name="currency_calculator"
    )

    assert actual == expected, actual


def test_get_load_param_if_needed_function() -> None:
    assert get_load_param_if_needed_function(CurrencySymbol) is None
    assert get_load_param_if_needed_function(Currency)({"currency": "USD", "amount": 123.45}, Currency) == Currency(
        currency="USD", amount=123.45
    )

    f = get_load_param_if_needed_function(Annotated[Currency, "amount and a symbol of a currency"])
    actual = f({"currency": "USD", "amount": 123.45}, Currency)
    expected = Currency(currency="USD", amount=123.45)
    assert actual == expected, actual


def test_load_basemodels_if_needed() -> None:
    @load_basemodels_if_needed
    def f(
        base: Annotated[Currency, "Base currency"],
        quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
    ) -> Tuple[Currency, CurrencySymbol]:
        return base, quote_currency

    actual = f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
    assert isinstance(actual[0], Currency)
    assert actual[0].amount == 123.45
    assert actual[0].currency == "USD"
    assert actual[1] == "EUR"


def test_serialize_to_json():
    assert serialize_to_str("abc") == "abc"
    assert serialize_to_str(123) == "123"
    assert serialize_to_str([123, 456]) == "[123, 456]"
    assert serialize_to_str({"a": 1, "b": 2.3}) == '{"a": 1, "b": 2.3}'

    class A(BaseModel):
        a: int
        b: float
        c: str

    assert serialize_to_str(A(a=1, b=2.3, c="abc")) == '{"a":1,"b":2.3,"c":"abc"}'
@ -68,6 +68,14 @@ def test_agentchat_function_call(save=False):
    run_notebook("agentchat_function_call.ipynb", save=save)


@pytest.mark.skipif(
    skip or not sys.version.startswith("3.10"),
    reason="do not run if openai is not installed or py!=3.10",
)
def test_agentchat_function_call_currency_calculator(save=False):
    run_notebook("agentchat_function_call_currency_calculator.ipynb", save=save)


@pytest.mark.skipif(
    skip or not sys.version.startswith("3.10"),
    reason="do not run if openai is not installed or py!=3.10",
@ -0,0 +1,41 @@
from typing import Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from autogen._pydantic import model_dump, model_dump_json, type2schema


def test_type2schema() -> None:
    assert type2schema(str) == {"type": "string"}
    assert type2schema(int) == {"type": "integer"}
    assert type2schema(float) == {"type": "number"}
    assert type2schema(bool) == {"type": "boolean"}
    assert type2schema(None) == {"type": "null"}
    assert type2schema(Optional[int]) == {"anyOf": [{"type": "integer"}, {"type": "null"}]}
    assert type2schema(List[int]) == {"items": {"type": "integer"}, "type": "array"}
    assert type2schema(Tuple[int, float, str]) == {
        "maxItems": 3,
        "minItems": 3,
        "prefixItems": [{"type": "integer"}, {"type": "number"}, {"type": "string"}],
        "type": "array",
    }
    assert type2schema(Dict[str, int]) == {"additionalProperties": {"type": "integer"}, "type": "object"}
    assert type2schema(Annotated[str, "some text"]) == {"type": "string"}
    assert type2schema(Union[int, float]) == {"anyOf": [{"type": "integer"}, {"type": "number"}]}


def test_model_dump() -> None:
    class A(BaseModel):
        a: str
        b: int = 2

    assert model_dump(A(a="aaa")) == {"a": "aaa", "b": 2}


def test_model_dump_json() -> None:
    class A(BaseModel):
        a: str
        b: int = 2

    assert model_dump_json(A(a="aaa")).replace(" ", "") == '{"a":"aaa","b":2}'
@ -39,6 +39,161 @@ assistant = AssistantAgent(name="assistant")
# create a UserProxyAgent instance named "user_proxy"
user_proxy = UserProxyAgent(name="user_proxy")
```
#### Function calling

Function calling enables agents to interact with external tools and APIs more efficiently.
This feature allows the AI model to intelligently choose to output a JSON object containing
arguments to call specific functions based on the user's input. Each function to be called is
specified with a JSON schema describing its parameters and their types. Writing such a JSON schema
is complex and error-prone, which is why the AutoGen framework provides two high-level function decorators that automatically generate the schema from type hints on standard Python datatypes
or Pydantic models:

1. [`ConversableAgent.register_for_llm`](../reference/agentchat/conversable_agent#register_for_llm) is used to register the function in the `llm_config` of a ConversableAgent. The ConversableAgent can propose execution of a registered function, but the actual execution will be performed by a UserProxy agent.

2. [`ConversableAgent.register_for_execution`](../reference/agentchat/conversable_agent#register_for_execution) is used to register the function in the `function_map` of a UserProxy agent.

The following example illustrates the process of registering a custom function for currency exchange calculation that uses type hints and standard Python datatypes:

``` python
from typing import Literal
from typing_extensions import Annotated
from somewhere import exchange_rate
# the agents are instances of UserProxyAgent and AssistantAgent
from myagents import agent, user_proxy

CurrencySymbol = Literal["USD", "EUR"]

# registers the function for execution (updates the function map)
@user_proxy.register_for_execution()
# creates a JSON schema from type hints and registers the function to llm_config
@agent.register_for_llm(description="Currency exchange calculator.")
# python function with type hints
def currency_calculator(
    # the Annotated type is used for attaching a description to the parameter
    base_amount: Annotated[float, "Amount of currency in base_currency"],
    # default values of parameters will be propagated to the LLM
    base_currency: Annotated[CurrencySymbol, "Base currency"] = "USD",
    quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
) -> str:  # the return type must be either str, a BaseModel, or serializable by json.dumps()
    quote_amount = exchange_rate(base_currency, quote_currency) * base_amount
    return f"{quote_amount} {quote_currency}"
```

Notice the use of [Annotated](https://docs.python.org/3/library/typing.html?highlight=annotated#typing.Annotated) to specify the type and the description of each parameter. The return value of the function must be either a string or serializable to a string using [`json.dumps()`](https://docs.python.org/3/library/json.html#json.dumps) or a [Pydantic model dump to JSON](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json) (both versions 1.x and 2.x are supported).
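
To make the serialization rule concrete, the conversion applied to return values matches the behavior of the `serialize_to_str` helper in `autogen.function_utils` added by this PR (a sketch; the `Result` model below is illustrative, not part of the library):

```python
from pydantic import BaseModel

from autogen.function_utils import serialize_to_str

class Result(BaseModel):
    amount: float
    currency: str

serialize_to_str("112.23 EUR")  # plain strings pass through unchanged
serialize_to_str({"amount": 112.23})  # other values are converted with json.dumps()
serialize_to_str(Result(amount=112.23, currency="EUR"))  # Pydantic models are dumped to JSON
```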

You can check the JSON schema generated by the decorator in `agent.llm_config["functions"]`:
```python
[{'description': 'Currency exchange calculator.',
  'name': 'currency_calculator',
  'parameters': {'type': 'object',
   'properties': {'base_amount': {'type': 'number',
     'description': 'Amount of currency in base_currency'},
    'base_currency': {'enum': ['USD', 'EUR'],
     'type': 'string',
     'default': 'USD',
     'description': 'Base currency'},
    'quote_currency': {'enum': ['USD', 'EUR'],
     'type': 'string',
     'default': 'EUR',
     'description': 'Quote currency'}},
   'required': ['base_amount']}}]
```
Agents can now use the function as follows:
```
user_proxy (to chatbot):

How much is 123.45 USD in EUR?

--------------------------------------------------------------------------------
chatbot (to user_proxy):

***** Suggested function Call: currency_calculator *****
Arguments:
{"base_amount":123.45,"base_currency":"USD","quote_currency":"EUR"}
********************************************************

--------------------------------------------------------------------------------

>>>>>>>> EXECUTING FUNCTION currency_calculator...
user_proxy (to chatbot):

***** Response from calling function "currency_calculator" *****
112.22727272727272 EUR
****************************************************************

--------------------------------------------------------------------------------
chatbot (to user_proxy):

123.45 USD is equivalent to approximately 112.23 EUR.
...

TERMINATE
```

Using Pydantic models further simplifies writing such functions. Pydantic models can be used
for both the parameters of a function and for its return type. Parameters will be
constructed from the JSON provided by the AI model, while the output will be serialized to a
JSON-encoded string automatically.

The following example shows how we could rewrite our currency exchange calculator example:

``` python
from typing import Literal
from typing_extensions import Annotated
from pydantic import BaseModel, Field
from somewhere import exchange_rate
from myagents import agent, user_proxy

CurrencySymbol = Literal["USD", "EUR"]

# defines a Pydantic model
class Currency(BaseModel):
    # parameter of type CurrencySymbol
    currency: Annotated[CurrencySymbol, Field(..., description="Currency symbol")]
    # parameter of type float, must be greater than or equal to 0, with default value 0
    amount: Annotated[float, Field(0, description="Amount of currency", ge=0)]

@user_proxy.register_for_execution()
@agent.register_for_llm(description="Currency exchange calculator.")
def currency_calculator(
    base: Annotated[Currency, "Base currency: amount and currency symbol"],
    quote_currency: Annotated[CurrencySymbol, "Quote currency symbol"] = "USD",
) -> Currency:
    quote_amount = exchange_rate(base.currency, quote_currency) * base.amount
    return Currency(amount=quote_amount, currency=quote_currency)
```
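
Under the hood, the decorated function is wrapped so that dict arguments are validated into Pydantic models and the returned model is dumped to a JSON string. The unit tests added in this commit check roughly the following (a sketch using the internal `ConversableAgent._wrap_function`; the test's stub `exchange_rate` values 1 EUR at 1.1 USD):

```python
wrapped = agent._wrap_function(currency_calculator)

# the dict passed for `base` is validated into a Currency instance,
# and the returned Currency is serialized to a JSON string
assert (
    wrapped(base={"currency": "USD", "amount": 110.11}, quote_currency="EUR")
    == '{"currency":"EUR","amount":100.1}'
)
```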

The generated JSON schema now also encodes the additional properties, such as the minimum value:
```python
[{'description': 'Currency exchange calculator.',
  'name': 'currency_calculator',
  'parameters': {'type': 'object',
   'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',
       'enum': ['USD', 'EUR'],
       'title': 'Currency',
       'type': 'string'},
      'amount': {'default': 0,
       'description': 'Amount of currency',
       'minimum': 0.0,
       'title': 'Amount',
       'type': 'number'}},
     'required': ['currency'],
     'title': 'Currency',
     'type': 'object',
     'description': 'Base currency: amount and currency symbol'},
    'quote_currency': {'enum': ['USD', 'EUR'],
     'type': 'string',
     'default': 'USD',
     'description': 'Quote currency symbol'}},
   'required': ['base']}}]
```
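
As in the notebook accompanying this PR, the registered function is then exercised simply by starting a chat; the LLM proposes the call, the user proxy executes it, and the JSON-encoded result is fed back into the conversation (`user_proxy` and `agent` are the instances assumed above):

```python
# start the conversation
user_proxy.initiate_chat(
    agent,
    message="How much is 112.23 Euros in US Dollars?",
)
```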

For more in-depth examples, please check the following:

- Currency calculator examples - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_function_call_currency_calculator.ipynb)

- Use Provided Tools as Functions - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_function_call.ipynb)

- Use Tools via Sync and Async Function Calling - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_function_call_async.ipynb)


## Multi-agent Conversations

@ -77,6 +232,7 @@ By adopting the conversation-driven control with both programming language and n
- LLM-based function call. In this approach, the LLM decides whether or not to call a particular function depending on the conversation status in each inference call.
By messaging additional agents in the called functions, the LLM can drive a dynamic multi-agent conversation. A working system showcasing this type of dynamic conversation can be found in the [multi-user math problem solving scenario](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb), where a student assistant automatically resorts to an expert via function calls.
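
A minimal sketch of that pattern (illustrative only: `ask_expert` is a hypothetical function, and the `expert_proxy`/`expert_assistant` agents are assumed to be configured elsewhere):

```python
from typing_extensions import Annotated

@user_proxy.register_for_execution()
@assistant.register_for_llm(description="Ask an expert for help with a math problem.")
def ask_expert(problem: Annotated[str, "Problem statement to forward to the expert."]) -> str:
    # messaging another agent from inside a called function spawns a nested chat,
    # turning a two-agent conversation into a dynamic multi-agent one
    expert_proxy.initiate_chat(expert_assistant, message=problem)
    # relay the expert's final answer back to the outer conversation
    return expert_proxy.last_message()["content"]
```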


### Diverse Applications Implemented with AutoGen

The figure below shows six examples of applications built using AutoGen.