mirror of https://github.com/microsoft/autogen.git
Organize some more modules (#48)
* Organize some more modules * cleanup model_client
This commit is contained in:
parent
19570fdd98
commit
ed0229734d
|
@ -15,7 +15,7 @@ pip install azure-identity
|
|||
## Using the Model Client
|
||||
|
||||
```python
|
||||
from agnext.components.model_client import AzureOpenAI
|
||||
from agnext.components.llm import AzureOpenAI
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
|
||||
# Create the token provider
|
||||
|
|
|
@ -16,8 +16,7 @@ from agnext.chat.types import TextMessage
|
|||
from agnext.components.function_executor._impl.in_process_function_executor import (
|
||||
InProcessFunctionExecutor,
|
||||
)
|
||||
from agnext.components.model_client import OpenAI
|
||||
from agnext.components.types import SystemMessage
|
||||
from agnext.components.llm import OpenAI, SystemMessage
|
||||
from agnext.core import Agent, AgentRuntime
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
from tavily import TavilyClient
|
||||
|
|
|
@ -5,8 +5,6 @@ from typing import Any, Coroutine, Dict, List, Mapping, Tuple
|
|||
from agnext.chat.agents.base import BaseChatAgent
|
||||
from agnext.chat.types import (
|
||||
FunctionCallMessage,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
Message,
|
||||
Reset,
|
||||
RespondNow,
|
||||
|
@ -15,12 +13,11 @@ from agnext.chat.types import (
|
|||
)
|
||||
from agnext.chat.utils import convert_messages_to_llm_messages
|
||||
from agnext.components.function_executor import FunctionExecutor
|
||||
from agnext.components.model_client import ModelClient
|
||||
from agnext.components.llm import FunctionExecutionResult, FunctionExecutionResultMessage, ModelClient, SystemMessage
|
||||
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.components.types import (
|
||||
FunctionCall,
|
||||
FunctionSignature,
|
||||
SystemMessage,
|
||||
)
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
|
||||
|
@ -141,7 +138,7 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
|
|||
results.append(FunctionExecutionResult(content=execution_result, call_id=call_id))
|
||||
|
||||
# Create a tool call result message.
|
||||
tool_call_result_msg = FunctionExecutionResultMessage(content=results, source=self.name)
|
||||
tool_call_result_msg = FunctionExecutionResultMessage(content=results)
|
||||
|
||||
# Add tool call result message.
|
||||
self._chat_messages.append(tool_call_result_msg)
|
||||
|
|
|
@ -5,6 +5,7 @@ from enum import Enum
|
|||
from typing import List, Union
|
||||
|
||||
from agnext.components.image import Image
|
||||
from agnext.components.llm import FunctionExecutionResultMessage
|
||||
from agnext.components.types import FunctionCall
|
||||
|
||||
|
||||
|
@ -29,17 +30,6 @@ class FunctionCallMessage(BaseMessage):
|
|||
content: List[FunctionCall]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResult:
|
||||
content: str
|
||||
call_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResultMessage(BaseMessage):
|
||||
content: List[FunctionExecutionResult]
|
||||
|
||||
|
||||
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
|
|
|
@ -4,22 +4,17 @@ from typing_extensions import Literal
|
|||
|
||||
from agnext.chat.types import (
|
||||
FunctionCallMessage,
|
||||
FunctionExecutionResultMessage,
|
||||
Message,
|
||||
MultiModalMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from agnext.components.types import (
|
||||
from agnext.components.llm import (
|
||||
AssistantMessage,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.components.types import (
|
||||
FunctionExecutionResult as FunctionExecutionResultType,
|
||||
)
|
||||
from agnext.components.types import (
|
||||
FunctionExecutionResultMessage as FunctionExecutionResultMessageType,
|
||||
)
|
||||
|
||||
|
||||
def convert_content_message_to_assistant_message(
|
||||
|
@ -61,11 +56,11 @@ def convert_content_message_to_user_message(
|
|||
def convert_tool_call_response_message(
|
||||
message: FunctionExecutionResultMessage,
|
||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
||||
) -> Optional[FunctionExecutionResultMessageType]:
|
||||
) -> Optional[FunctionExecutionResultMessage]:
|
||||
match message:
|
||||
case FunctionExecutionResultMessage():
|
||||
return FunctionExecutionResultMessageType(
|
||||
content=[FunctionExecutionResultType(content=x.content, call_id=x.call_id) for x in message.content]
|
||||
return FunctionExecutionResultMessage(
|
||||
content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content]
|
||||
)
|
||||
|
||||
|
||||
|
@ -93,7 +88,7 @@ def convert_messages_to_llm_messages(
|
|||
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
|
||||
if converted_message_2 is not None:
|
||||
result.append(converted_message_2)
|
||||
case FunctionExecutionResultMessage(_, source=source) if source == self_name:
|
||||
case FunctionExecutionResultMessage(_):
|
||||
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
|
||||
if converted_message_3 is not None:
|
||||
result.append(converted_message_3)
|
||||
|
|
|
@ -20,7 +20,7 @@ from typing import (
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, Literal
|
||||
|
||||
from .pydantic_compat import evaluate_forwardref, model_dump, type2schema
|
||||
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_chec
|
|||
|
||||
from typing_extensions import NotRequired, Required
|
||||
|
||||
from ..function_utils import get_function_schema
|
||||
from .._function_utils import get_function_schema
|
||||
from ..types import FunctionSignature
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
from ._model_client import ModelCapabilities, ModelClient
|
||||
from ._openai_client import (
|
||||
AzureOpenAI,
|
||||
OpenAI,
|
||||
)
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FinishReasons,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureOpenAI",
|
||||
"OpenAI",
|
||||
"ModelCapabilities",
|
||||
"ModelClient",
|
||||
"SystemMessage",
|
||||
"UserMessage",
|
||||
"AssistantMessage",
|
||||
"FunctionExecutionResult",
|
||||
"FunctionExecutionResultMessage",
|
||||
"LLMMessage",
|
||||
"RequestUsage",
|
||||
"FinishReasons",
|
||||
"CreateResult",
|
||||
]
|
|
@ -11,7 +11,8 @@ from typing_extensions import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from ..types import CreateResult, FunctionSignature, LLMMessage, RequestUsage
|
||||
from ..types import FunctionSignature
|
||||
from ._types import CreateResult, LLMMessage, RequestUsage
|
||||
|
||||
|
||||
class ModelCapabilities(TypedDict, total=False):
|
|
@ -4,11 +4,8 @@ import warnings
|
|||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
|
@ -30,25 +27,26 @@ from openai.types.chat import (
|
|||
ChatCompletionUserMessageParam,
|
||||
completion_create_params,
|
||||
)
|
||||
from typing_extensions import Required, TypedDict, Unpack
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
|
||||
|
||||
# from ..._pydantic import type2schema
|
||||
from ..image import Image
|
||||
from ..types import (
|
||||
FunctionCall,
|
||||
FunctionSignature,
|
||||
)
|
||||
from . import _model_info
|
||||
from ._model_client import ModelCapabilities, ModelClient
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FunctionCall,
|
||||
FunctionExecutionResultMessage,
|
||||
FunctionSignature,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from . import _model_info
|
||||
from ._model_client import ModelCapabilities, ModelClient
|
||||
from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
@ -202,53 +200,6 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
|
|||
)
|
||||
|
||||
|
||||
class ResponseFormat(TypedDict):
|
||||
type: Literal["text", "json_object"]
|
||||
|
||||
|
||||
class CreateArguments(TypedDict, total=False):
|
||||
frequency_penalty: Optional[float]
|
||||
logit_bias: Optional[Dict[str, int]]
|
||||
max_tokens: Optional[int]
|
||||
n: Optional[int]
|
||||
presence_penalty: Optional[float]
|
||||
response_format: ResponseFormat
|
||||
seed: Optional[int]
|
||||
stop: Union[Optional[str], List[str]]
|
||||
temperature: Optional[float]
|
||||
top_p: Optional[float]
|
||||
user: str
|
||||
|
||||
|
||||
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
|
||||
|
||||
|
||||
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
|
||||
model: str
|
||||
api_key: str
|
||||
timeout: Union[float, None]
|
||||
max_retries: int
|
||||
|
||||
|
||||
# See OpenAI docs for explanation of these parameters
|
||||
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
|
||||
organization: str
|
||||
base_url: str
|
||||
# Not required
|
||||
model_capabilities: ModelCapabilities
|
||||
|
||||
|
||||
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
|
||||
# Azure specific
|
||||
azure_endpoint: Required[str]
|
||||
azure_deployment: str
|
||||
api_version: Required[str]
|
||||
azure_ad_token: str
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider
|
||||
# Must be provided
|
||||
model_capabilities: Required[ModelCapabilities]
|
||||
|
||||
|
||||
def convert_functions(
|
||||
functions: Sequence[FunctionSignature],
|
||||
) -> List[ChatCompletionToolParam]:
|
|
@ -0,0 +1,57 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from ..image import Image
|
||||
from ..types import FunctionCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessage:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMessage:
|
||||
content: Union[str, List[Union[str, Image]]]
|
||||
|
||||
# Name of the agent that sent this message
|
||||
source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantMessage:
|
||||
content: Union[str, List[FunctionCall]]
|
||||
|
||||
# Name of the agent that sent this message
|
||||
source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResult:
|
||||
content: str
|
||||
call_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResultMessage:
|
||||
content: List[FunctionExecutionResult]
|
||||
|
||||
|
||||
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestUsage:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
|
||||
|
||||
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateResult:
|
||||
finish_reason: FinishReasons
|
||||
content: Union[str, List[FunctionCall]]
|
||||
usage: RequestUsage
|
||||
cached: bool
|
|
@ -0,0 +1,52 @@
|
|||
from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from .._model_client import ModelCapabilities
|
||||
|
||||
|
||||
class ResponseFormat(TypedDict):
|
||||
type: Literal["text", "json_object"]
|
||||
|
||||
|
||||
class CreateArguments(TypedDict, total=False):
|
||||
frequency_penalty: Optional[float]
|
||||
logit_bias: Optional[Dict[str, int]]
|
||||
max_tokens: Optional[int]
|
||||
n: Optional[int]
|
||||
presence_penalty: Optional[float]
|
||||
response_format: ResponseFormat
|
||||
seed: Optional[int]
|
||||
stop: Union[Optional[str], List[str]]
|
||||
temperature: Optional[float]
|
||||
top_p: Optional[float]
|
||||
user: str
|
||||
|
||||
|
||||
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
|
||||
|
||||
|
||||
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
|
||||
model: str
|
||||
api_key: str
|
||||
timeout: Union[float, None]
|
||||
max_retries: int
|
||||
|
||||
|
||||
# See OpenAI docs for explanation of these parameters
|
||||
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
|
||||
organization: str
|
||||
base_url: str
|
||||
# Not required
|
||||
model_capabilities: ModelCapabilities
|
||||
|
||||
|
||||
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
|
||||
# Azure specific
|
||||
azure_endpoint: Required[str]
|
||||
azure_deployment: str
|
||||
api_version: Required[str]
|
||||
azure_ad_token: str
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider
|
||||
# Must be provided
|
||||
model_capabilities: Required[ModelCapabilities]
|
|
@ -1,24 +0,0 @@
|
|||
from ._model_client import ModelCapabilities, ModelClient
|
||||
from ._openai_client import (
|
||||
AsyncAzureADTokenProvider,
|
||||
AzureOpenAI,
|
||||
AzureOpenAIClientConfiguration,
|
||||
BaseOpenAIClientConfiguration,
|
||||
CreateArguments,
|
||||
OpenAI,
|
||||
OpenAIClientConfiguration,
|
||||
ResponseFormat,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureOpenAI",
|
||||
"OpenAI",
|
||||
"OpenAIClientConfiguration",
|
||||
"AzureOpenAIClientConfiguration",
|
||||
"ResponseFormat",
|
||||
"CreateArguments",
|
||||
"AsyncAzureADTokenProvider",
|
||||
"BaseOpenAIClientConfiguration",
|
||||
"ModelCapabilities",
|
||||
"ModelClient",
|
||||
]
|
|
@ -1,11 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .image import Image
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -22,55 +18,3 @@ class FunctionSignature:
|
|||
name: str
|
||||
parameters: Dict[str, Any]
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestUsage:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessage:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMessage:
|
||||
content: Union[str, List[Union[str, Image]]]
|
||||
|
||||
# Name of the agent that sent this message
|
||||
source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantMessage:
|
||||
content: Union[str, List[FunctionCall]]
|
||||
|
||||
# Name of the agent that sent this message
|
||||
source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResult:
|
||||
content: str
|
||||
call_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResultMessage:
|
||||
content: List[FunctionExecutionResult]
|
||||
|
||||
|
||||
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateResult:
|
||||
finish_reason: FinishReasons
|
||||
content: Union[str, List[FunctionCall]]
|
||||
usage: RequestUsage
|
||||
cached: bool
|
||||
|
|
Loading…
Reference in New Issue