Organize some more modules (#48)

* Organize some more modules

* cleanup model_client
This commit is contained in:
Jack Gerrits 2024-06-04 11:13:13 -04:00 committed by GitHub
parent 19570fdd98
commit ed0229734d
16 changed files with 166 additions and 172 deletions

View File

@ -15,7 +15,7 @@ pip install azure-identity
## Using the Model Client
```python
from agnext.components.model_client import AzureOpenAI
from agnext.components.llm import AzureOpenAI
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
# Create the token provider

View File

@ -16,8 +16,7 @@ from agnext.chat.types import TextMessage
from agnext.components.function_executor._impl.in_process_function_executor import (
InProcessFunctionExecutor,
)
from agnext.components.model_client import OpenAI
from agnext.components.types import SystemMessage
from agnext.components.llm import OpenAI, SystemMessage
from agnext.core import Agent, AgentRuntime
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from tavily import TavilyClient

View File

@ -5,8 +5,6 @@ from typing import Any, Coroutine, Dict, List, Mapping, Tuple
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import (
FunctionCallMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
Message,
Reset,
RespondNow,
@ -15,12 +13,11 @@ from agnext.chat.types import (
)
from agnext.chat.utils import convert_messages_to_llm_messages
from agnext.components.function_executor import FunctionExecutor
from agnext.components.model_client import ModelClient
from agnext.components.llm import FunctionExecutionResult, FunctionExecutionResultMessage, ModelClient, SystemMessage
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.components.types import (
FunctionCall,
FunctionSignature,
SystemMessage,
)
from agnext.core import AgentRuntime, CancellationToken
@ -141,7 +138,7 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
results.append(FunctionExecutionResult(content=execution_result, call_id=call_id))
# Create a tool call result message.
tool_call_result_msg = FunctionExecutionResultMessage(content=results, source=self.name)
tool_call_result_msg = FunctionExecutionResultMessage(content=results)
# Add tool call result message.
self._chat_messages.append(tool_call_result_msg)

View File

@ -5,6 +5,7 @@ from enum import Enum
from typing import List, Union
from agnext.components.image import Image
from agnext.components.llm import FunctionExecutionResultMessage
from agnext.components.types import FunctionCall
@ -29,17 +30,6 @@ class FunctionCallMessage(BaseMessage):
content: List[FunctionCall]
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage(BaseMessage):
content: List[FunctionExecutionResult]
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]

View File

@ -4,22 +4,17 @@ from typing_extensions import Literal
from agnext.chat.types import (
FunctionCallMessage,
FunctionExecutionResultMessage,
Message,
MultiModalMessage,
TextMessage,
)
from agnext.components.types import (
from agnext.components.llm import (
AssistantMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
UserMessage,
)
from agnext.components.types import (
FunctionExecutionResult as FunctionExecutionResultType,
)
from agnext.components.types import (
FunctionExecutionResultMessage as FunctionExecutionResultMessageType,
)
def convert_content_message_to_assistant_message(
@ -61,11 +56,11 @@ def convert_content_message_to_user_message(
def convert_tool_call_response_message(
message: FunctionExecutionResultMessage,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[FunctionExecutionResultMessageType]:
) -> Optional[FunctionExecutionResultMessage]:
match message:
case FunctionExecutionResultMessage():
return FunctionExecutionResultMessageType(
content=[FunctionExecutionResultType(content=x.content, call_id=x.call_id) for x in message.content]
return FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content]
)
@ -93,7 +88,7 @@ def convert_messages_to_llm_messages(
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case FunctionExecutionResultMessage(_, source=source) if source == self_name:
case FunctionExecutionResultMessage(_):
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
if converted_message_3 is not None:
result.append(converted_message_3)

View File

@ -20,7 +20,7 @@ from typing import (
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Literal
from .pydantic_compat import evaluate_forwardref, model_dump, type2schema
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
logger = getLogger(__name__)

View File

@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_chec
from typing_extensions import NotRequired, Required
from ..function_utils import get_function_schema
from .._function_utils import get_function_schema
from ..types import FunctionSignature

View File

@ -0,0 +1,32 @@
from ._model_client import ModelCapabilities, ModelClient
from ._openai_client import (
AzureOpenAI,
OpenAI,
)
from ._types import (
AssistantMessage,
CreateResult,
FinishReasons,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
__all__ = [
"AzureOpenAI",
"OpenAI",
"ModelCapabilities",
"ModelClient",
"SystemMessage",
"UserMessage",
"AssistantMessage",
"FunctionExecutionResult",
"FunctionExecutionResultMessage",
"LLMMessage",
"RequestUsage",
"FinishReasons",
"CreateResult",
]

View File

@ -11,7 +11,8 @@ from typing_extensions import (
Union,
)
from ..types import CreateResult, FunctionSignature, LLMMessage, RequestUsage
from ..types import FunctionSignature
from ._types import CreateResult, LLMMessage, RequestUsage
class ModelCapabilities(TypedDict, total=False):

View File

@ -4,11 +4,8 @@ import warnings
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@ -30,25 +27,26 @@ from openai.types.chat import (
ChatCompletionUserMessageParam,
completion_create_params,
)
from typing_extensions import Required, TypedDict, Unpack
from typing_extensions import Unpack
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
# from ..._pydantic import type2schema
from ..image import Image
from ..types import (
FunctionCall,
FunctionSignature,
)
from . import _model_info
from ._model_client import ModelCapabilities, ModelClient
from ._types import (
AssistantMessage,
CreateResult,
FunctionCall,
FunctionExecutionResultMessage,
FunctionSignature,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from . import _model_info
from ._model_client import ModelCapabilities, ModelClient
from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration
logger = logging.getLogger(EVENT_LOGGER_NAME)
@ -202,53 +200,6 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
)
class ResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class CreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
logit_bias: Optional[Dict[str, int]]
max_tokens: Optional[int]
n: Optional[int]
presence_penalty: Optional[float]
response_format: ResponseFormat
seed: Optional[int]
stop: Union[Optional[str], List[str]]
temperature: Optional[float]
top_p: Optional[float]
user: str
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
model: str
api_key: str
timeout: Union[float, None]
max_retries: int
# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
# Azure specific
azure_endpoint: Required[str]
azure_deployment: str
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]
def convert_functions(
functions: Sequence[FunctionSignature],
) -> List[ChatCompletionToolParam]:

View File

@ -0,0 +1,57 @@
from dataclasses import dataclass
from typing import List, Literal, Union
from ..image import Image
from ..types import FunctionCall
@dataclass
class SystemMessage:
content: str
@dataclass
class UserMessage:
content: Union[str, List[Union[str, Image]]]
# Name of the agent that sent this message
source: str
@dataclass
class AssistantMessage:
content: Union[str, List[FunctionCall]]
# Name of the agent that sent this message
source: str
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage:
content: List[FunctionExecutionResult]
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
@dataclass
class RequestUsage:
prompt_tokens: int
completion_tokens: int
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
@dataclass
class CreateResult:
finish_reason: FinishReasons
content: Union[str, List[FunctionCall]]
usage: RequestUsage
cached: bool

View File

@ -0,0 +1,52 @@
from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union
from typing_extensions import Required, TypedDict
from .._model_client import ModelCapabilities
class ResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class CreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
logit_bias: Optional[Dict[str, int]]
max_tokens: Optional[int]
n: Optional[int]
presence_penalty: Optional[float]
response_format: ResponseFormat
seed: Optional[int]
stop: Union[Optional[str], List[str]]
temperature: Optional[float]
top_p: Optional[float]
user: str
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
model: str
api_key: str
timeout: Union[float, None]
max_retries: int
# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
# Azure specific
azure_endpoint: Required[str]
azure_deployment: str
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]

View File

@ -1,24 +0,0 @@
from ._model_client import ModelCapabilities, ModelClient
from ._openai_client import (
AsyncAzureADTokenProvider,
AzureOpenAI,
AzureOpenAIClientConfiguration,
BaseOpenAIClientConfiguration,
CreateArguments,
OpenAI,
OpenAIClientConfiguration,
ResponseFormat,
)
__all__ = [
"AzureOpenAI",
"OpenAI",
"OpenAIClientConfiguration",
"AzureOpenAIClientConfiguration",
"ResponseFormat",
"CreateArguments",
"AsyncAzureADTokenProvider",
"BaseOpenAIClientConfiguration",
"ModelCapabilities",
"ModelClient",
]

View File

@ -1,11 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from typing_extensions import Literal
from .image import Image
from typing import Any, Dict
@dataclass
@ -22,55 +18,3 @@ class FunctionSignature:
name: str
parameters: Dict[str, Any]
description: str
@dataclass
class RequestUsage:
prompt_tokens: int
completion_tokens: int
@dataclass
class SystemMessage:
content: str
@dataclass
class UserMessage:
content: Union[str, List[Union[str, Image]]]
# Name of the agent that sent this message
source: str
@dataclass
class AssistantMessage:
content: Union[str, List[FunctionCall]]
# Name of the agent that sent this message
source: str
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage:
content: List[FunctionExecutionResult]
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
@dataclass
class CreateResult:
finish_reason: FinishReasons
content: Union[str, List[FunctionCall]]
usage: RequestUsage
cached: bool