Adds a standard logging / log-printing class to TeamOne (#194)

* Added initial code for TeamOne utils.

* Fixed hatch errors.

* Updated examples.

* Fixed more hatch errors.

* examples/example_coder.py

* Added standard logging for TeamOne

* Read time from log record.
This commit is contained in:
afourney 2024-07-09 13:51:05 -07:00 committed by GitHub
parent 05e72084e8
commit 699f024a6d
6 changed files with 75 additions and 66 deletions

View File

@ -1,11 +1,13 @@
import asyncio
import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.components.models import UserMessage
from agnext.application.logging import EVENT_LOGGER_NAME
from team_one.agents.coder import Coder, Executor
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.messages import BroadcastMessage
from team_one.utils import create_completion_client_from_env
from team_one.agents.user_proxy import UserProxy
from team_one.messages import RequestReplyMessage
from team_one.utils import LogHandler, create_completion_client_from_env
async def main() -> None:
@ -17,25 +19,24 @@ async def main() -> None:
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor]))
task = input("Enter a task: ")
run_context = runtime.start()
await runtime.publish_message(
BroadcastMessage(content=UserMessage(content=task, source="human")), namespace="default"
user_proxy = runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
# Run the runtime until the task is completed.
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)
await run_context.stop_when_idle()
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.INFO)
log_handler = LogHandler()
logger.handlers = [log_handler]
asyncio.run(main())

View File

@ -1,37 +1,42 @@
import asyncio
import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.components.models import UserMessage
from agnext.application.logging import EVENT_LOGGER_NAME
from team_one.agents.file_surfer import FileSurfer
from team_one.messages import BroadcastMessage, RequestReplyMessage
from team_one.utils import create_completion_client_from_env
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
from team_one.messages import RequestReplyMessage
from team_one.utils import LogHandler, create_completion_client_from_env
async def main() -> None:
# Create the runtime.
runtime = SingleThreadedAgentRuntime()
# Get an appropriate client
client = create_completion_client_from_env()
# Register agents.
file_surfer = runtime.register_and_get(
file_surfer = runtime.register_and_get_proxy(
"file_surfer",
lambda: FileSurfer(model_client=create_completion_client_from_env()),
lambda: FileSurfer(model_client=client),
)
task = input(f"Enter a task for {file_surfer.name}: ")
msg = BroadcastMessage(content=UserMessage(content=task, source="human"))
user_proxy = runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))
run_context = runtime.start()
# Send a task to the tool user.
await runtime.publish_message(msg, namespace="default")
await runtime.publish_message(RequestReplyMessage(), namespace="default")
# Run the runtime until the task is completed.
await runtime.send_message(RequestReplyMessage(), user_proxy.id)
await run_context.stop_when_idle()
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.INFO)
log_handler = LogHandler()
logger.handlers = [log_handler]
asyncio.run(main())

View File

@ -1,10 +1,13 @@
import asyncio
import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components.models import UserMessage
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.reflex_agents import ReflexAgent
from team_one.messages import BroadcastMessage
from team_one.utils import LogHandler
async def main() -> None:
@ -23,9 +26,8 @@ async def main() -> None:
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.INFO)
log_handler = LogHandler()
logger.handlers = [log_handler]
asyncio.run(main())

View File

@ -1,6 +0,0 @@
"""
A team that use multiple agents including, coder and file_surfer
to solve tasks.
"""
# TODO: Add code here to implement the team.

View File

@ -7,8 +7,8 @@ from agnext.application.logging import EVENT_LOGGER_NAME
from team_one.agents.coder import Coder
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
from team_one.messages import OrchestrationEvent, RequestReplyMessage
from team_one.utils import create_completion_client_from_env
from team_one.messages import RequestReplyMessage
from team_one.utils import LogHandler, create_completion_client_from_env
async def main() -> None:
@ -35,27 +35,9 @@ async def main() -> None:
await run_context.stop_when_idle()
class MyHandler(logging.Handler):
def __init__(self) -> None:
super().__init__()
def emit(self, record: logging.LogRecord) -> None:
try:
if isinstance(record.msg, OrchestrationEvent):
print(
f"""---------------------------------------------------------------------------
\033[91m{record.msg.source}:\033[0m
{record.msg.message}""",
flush=True,
)
except Exception:
self.handleError(record)
if __name__ == "__main__":
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.INFO)
my_handler = MyHandler()
logger.handlers = [my_handler]
log_handler = LogHandler()
logger.handlers = [log_handler]
asyncio.run(main())

View File

@ -1,5 +1,7 @@
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict
from agnext.components.models import (
@ -9,6 +11,8 @@ from agnext.components.models import (
OpenAIChatCompletionClient,
)
from .messages import OrchestrationEvent
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER = "CHAT_COMPLETION_PROVIDER"
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON = "CHAT_COMPLETION_KWARGS_JSON"
@ -64,3 +68,24 @@ def create_completion_client_from_env(env: Dict[str, str] | None = None, **kwarg
return AzureOpenAIChatCompletionClient(**_kwargs)
else:
raise ValueError(f"Unknown OAI provider '{_provider}'")
# TeamOne log event handler
class LogHandler(logging.Handler):
def __init__(self) -> None:
super().__init__()
def emit(self, record: logging.LogRecord) -> None:
try:
if isinstance(record.msg, OrchestrationEvent):
ts = datetime.fromtimestamp(record.created).isoformat()
print(
f"""
---------------------------------------------------------------------------
\033[91m[{ts}], {record.msg.source}:\033[0m
{record.msg.message}""",
flush=True,
)
except Exception:
self.handleError(record)