mirror of https://github.com/microsoft/autogen.git
Fix threading issue for logging (#1901)
* fix mypy errors for logging * cleanup * formatting * fix threading issue in logging * remove kwarg --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
2a62ffc566
commit
811fd6926d
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, List, TYPE_CHECKING, Union
|
||||
import sqlite3
|
||||
import uuid
|
||||
|
||||
|
@ -11,6 +11,9 @@ from openai.types.chat import ChatCompletion
|
|||
if TYPE_CHECKING:
|
||||
from autogen import ConversableAgent, OpenAIWrapper
|
||||
|
||||
ConfigItem = Dict[str, Union[str, List[str]]]
|
||||
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
|
||||
|
||||
|
||||
class BaseLogger(ABC):
|
||||
@abstractmethod
|
||||
|
@ -25,10 +28,11 @@ class BaseLogger(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def log_chat_completion(
|
||||
self,
|
||||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
request: Dict,
|
||||
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||
response: Union[str, ChatCompletion],
|
||||
is_cached: int,
|
||||
cost: float,
|
||||
|
@ -54,7 +58,7 @@ class BaseLogger(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None:
|
||||
def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log the birth of a new agent.
|
||||
|
||||
|
@ -65,7 +69,7 @@ class BaseLogger(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
||||
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
|
||||
"""
|
||||
Log the birth of a new OpenAIWrapper.
|
||||
|
||||
|
@ -76,7 +80,9 @@ class BaseLogger(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
||||
def log_new_client(
|
||||
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Log the birth of a new OpenAIWrapper.
|
||||
|
||||
|
@ -87,14 +93,14 @@ class BaseLogger(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def stop() -> None:
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Close the connection to the logging database, and stop logging.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_connection() -> Union[sqlite3.Connection]:
|
||||
def get_connection(self) -> Union[None, sqlite3.Connection]:
|
||||
"""
|
||||
Return a connection to the logging database.
|
||||
"""
|
||||
|
|
|
@ -5,14 +5,14 @@ from typing import Any, Dict, List, Tuple, Union
|
|||
__all__ = ("get_current_ts", "to_dict")
|
||||
|
||||
|
||||
def get_current_ts():
|
||||
def get_current_ts() -> str:
|
||||
return datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
|
||||
def to_dict(
|
||||
obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any],
|
||||
exclude: Tuple[str] = (),
|
||||
no_recursive: Tuple[str] = (),
|
||||
exclude: Tuple[str, ...] = (),
|
||||
no_recursive: Tuple[Any, ...] = (),
|
||||
) -> Any:
|
||||
if isinstance(obj, (int, float, str, bool)):
|
||||
return obj
|
||||
|
|
|
@ -4,7 +4,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from autogen.logger.base_logger import BaseLogger
|
||||
|
@ -12,17 +12,15 @@ from autogen.logger.logger_utils import get_current_ts, to_dict
|
|||
|
||||
from openai import OpenAI, AzureOpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Union
|
||||
from .base_logger import LLMConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogen import ConversableAgent, OpenAIWrapper
|
||||
|
||||
|
||||
# this is a pointer to the module object instance itself
|
||||
this = sys.modules[__name__]
|
||||
this._session_id = None
|
||||
logger = logging.getLogger(__name__)
|
||||
lock = threading.Lock()
|
||||
|
||||
__all__ = ("SqliteLogger",)
|
||||
|
||||
|
@ -30,19 +28,19 @@ __all__ = ("SqliteLogger",)
|
|||
class SqliteLogger(BaseLogger):
|
||||
schema_version = 1
|
||||
|
||||
def __init__(self, config):
|
||||
self.con = None
|
||||
self.cur = None
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
|
||||
def start(self) -> str:
|
||||
dbname = self.config["dbname"] if "dbname" in self.config else "logs.db"
|
||||
this._session_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
self.con = sqlite3.connect(dbname)
|
||||
self.dbname = self.config.get("dbname", "logs.db")
|
||||
self.con = sqlite3.connect(self.dbname, check_same_thread=False)
|
||||
self.cur = self.con.cursor()
|
||||
self.session_id = str(uuid.uuid4())
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"[SqliteLogger] Failed to connect to database {self.dbname}: {e}")
|
||||
|
||||
def start(self) -> str:
|
||||
try:
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS chat_completions(
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
@ -57,8 +55,7 @@ class SqliteLogger(BaseLogger):
|
|||
start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
end_time DATETIME DEFAULT CURRENT_TIMESTAMP)
|
||||
"""
|
||||
self.cur.execute(query)
|
||||
self.con.commit()
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS agents (
|
||||
|
@ -72,8 +69,7 @@ class SqliteLogger(BaseLogger):
|
|||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(agent_id, session_id))
|
||||
"""
|
||||
self.cur.execute(query)
|
||||
self.con.commit()
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS oai_wrappers (
|
||||
|
@ -84,8 +80,7 @@ class SqliteLogger(BaseLogger):
|
|||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(wrapper_id, session_id))
|
||||
"""
|
||||
self.cur.execute(query)
|
||||
self.con.commit()
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS oai_clients (
|
||||
|
@ -98,8 +93,7 @@ class SqliteLogger(BaseLogger):
|
|||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(client_id, session_id))
|
||||
"""
|
||||
self.cur.execute(query)
|
||||
self.con.commit()
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS version (
|
||||
|
@ -107,31 +101,30 @@ class SqliteLogger(BaseLogger):
|
|||
version_number INTEGER NOT NULL -- version of the logging database
|
||||
);
|
||||
"""
|
||||
self.cur.execute(query)
|
||||
self.con.commit()
|
||||
self._run_query(query=query)
|
||||
|
||||
current_verion = self._get_current_db_version()
|
||||
if current_verion is None:
|
||||
self.cur.execute(
|
||||
"INSERT INTO version (id, version_number) VALUES (1, ?);", (SqliteLogger.schema_version,)
|
||||
self._run_query(
|
||||
query="INSERT INTO version (id, version_number) VALUES (1, ?);", args=(SqliteLogger.schema_version,)
|
||||
)
|
||||
self.con.commit()
|
||||
|
||||
self._apply_migration(dbname)
|
||||
self._apply_migration()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"[SqliteLogger] start logging error: {e}")
|
||||
finally:
|
||||
return this._session_id
|
||||
return self.session_id
|
||||
|
||||
def _get_current_db_version(self):
|
||||
def _get_current_db_version(self) -> Union[None, int]:
|
||||
self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1")
|
||||
result = self.cur.fetchone()
|
||||
return result[0] if result else None
|
||||
return result[0] if result is not None else None
|
||||
|
||||
# Example migration script name format: 002_update_agents_table.sql
|
||||
def _apply_migration(self, dbname, migrations_dir="./migrations"):
|
||||
def _apply_migration(self, migrations_dir: str = "./migrations") -> None:
|
||||
current_version = self._get_current_db_version()
|
||||
current_version = SqliteLogger.schema_version if current_version is None else current_version
|
||||
|
||||
if os.path.isdir(migrations_dir):
|
||||
migrations = sorted(os.listdir(migrations_dir))
|
||||
else:
|
||||
|
@ -143,19 +136,48 @@ class SqliteLogger(BaseLogger):
|
|||
for script in migrations_to_apply:
|
||||
with open(script, "r") as f:
|
||||
migration_sql = f.read()
|
||||
self.con.executescript(migration_sql)
|
||||
self.con.commit()
|
||||
self._run_query_script(script=migration_sql)
|
||||
|
||||
latest_version = int(script.split("_")[0])
|
||||
self.cur.execute("UPDATE version SET version_number = ? WHERE id = 1", (latest_version))
|
||||
query = "UPDATE version SET version_number = ? WHERE id = 1"
|
||||
args = (latest_version,)
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def _run_query(self, query: str, args: Tuple[Any, ...] = ()) -> None:
|
||||
"""
|
||||
Executes a given SQL query.
|
||||
|
||||
Args:
|
||||
query (str): The SQL query to execute.
|
||||
args (Tuple): The arguments to pass to the SQL query.
|
||||
"""
|
||||
try:
|
||||
with lock:
|
||||
self.cur.execute(query, args)
|
||||
self.con.commit()
|
||||
except Exception as e:
|
||||
logger.error("[sqlite logger]Error running query with query %s and args %s: %s", query, args, e)
|
||||
|
||||
def _run_query_script(self, script: str) -> None:
|
||||
"""
|
||||
Executes SQL script.
|
||||
|
||||
Args:
|
||||
script (str): SQL script to execute.
|
||||
"""
|
||||
try:
|
||||
with lock:
|
||||
self.cur.executescript(script)
|
||||
self.con.commit()
|
||||
except Exception as e:
|
||||
logger.error("[sqlite logger]Error running query script %s: %s", script, e)
|
||||
|
||||
def log_chat_completion(
|
||||
self,
|
||||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
request: Dict,
|
||||
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||
response: Union[str, ChatCompletion],
|
||||
is_cached: int,
|
||||
cost: float,
|
||||
|
@ -176,28 +198,22 @@ class SqliteLogger(BaseLogger):
|
|||
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
try:
|
||||
self.cur.execute(
|
||||
query,
|
||||
(
|
||||
args = (
|
||||
invocation_id,
|
||||
client_id,
|
||||
wrapper_id,
|
||||
this._session_id,
|
||||
self.session_id,
|
||||
json.dumps(request),
|
||||
response_messages,
|
||||
is_cached,
|
||||
cost,
|
||||
start_time,
|
||||
end_time,
|
||||
),
|
||||
)
|
||||
self.con.commit()
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"[SqliteLogger] log_chat_completion error: {e}")
|
||||
|
||||
def log_new_agent(self, agent: ConversableAgent, init_args: Dict) -> None:
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
|
||||
from autogen import Agent
|
||||
|
||||
if self.con is None:
|
||||
|
@ -206,7 +222,7 @@ class SqliteLogger(BaseLogger):
|
|||
args = to_dict(
|
||||
init_args,
|
||||
exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
|
||||
no_recursive=(Agent),
|
||||
no_recursive=(Agent,),
|
||||
)
|
||||
|
||||
# We do an upsert since both the superclass and subclass may call this method (in that order)
|
||||
|
@ -219,24 +235,18 @@ class SqliteLogger(BaseLogger):
|
|||
init_args = excluded.init_args,
|
||||
timestamp = excluded.timestamp
|
||||
"""
|
||||
try:
|
||||
self.cur.execute(
|
||||
query,
|
||||
(
|
||||
args = (
|
||||
id(agent),
|
||||
agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
|
||||
this._session_id,
|
||||
self.session_id,
|
||||
agent.name if hasattr(agent, "name") and agent.name is not None else "",
|
||||
type(agent).__name__,
|
||||
json.dumps(args),
|
||||
get_current_ts(),
|
||||
),
|
||||
)
|
||||
self.con.commit()
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"[SqliteLogger] log_new_agent error: {e}")
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
||||
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
|
||||
if self.con is None:
|
||||
return
|
||||
|
||||
|
@ -248,21 +258,17 @@ class SqliteLogger(BaseLogger):
|
|||
INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp) VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT (wrapper_id, session_id) DO NOTHING;
|
||||
"""
|
||||
try:
|
||||
self.cur.execute(
|
||||
query,
|
||||
(
|
||||
args = (
|
||||
id(wrapper),
|
||||
this._session_id,
|
||||
self.session_id,
|
||||
json.dumps(args),
|
||||
get_current_ts(),
|
||||
),
|
||||
)
|
||||
self.con.commit()
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"[SqliteLogger] log_new_wrapper error: {e}")
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def log_new_client(self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
||||
def log_new_client(
|
||||
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
|
||||
) -> None:
|
||||
if self.con is None:
|
||||
return
|
||||
|
||||
|
@ -274,28 +280,21 @@ class SqliteLogger(BaseLogger):
|
|||
INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (client_id, session_id) DO NOTHING;
|
||||
"""
|
||||
try:
|
||||
self.cur.execute(
|
||||
query,
|
||||
(
|
||||
args = (
|
||||
id(client),
|
||||
id(wrapper),
|
||||
this._session_id,
|
||||
self.session_id,
|
||||
type(client).__name__,
|
||||
json.dumps(args),
|
||||
get_current_ts(),
|
||||
),
|
||||
)
|
||||
self.con.commit()
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"[SqliteLogger] log_new_client error: {e}")
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.con:
|
||||
self.con.close()
|
||||
self.con = None
|
||||
self.cur = None
|
||||
|
||||
def get_connection(self) -> sqlite3.Connection:
|
||||
def get_connection(self) -> Union[None, sqlite3.Connection]:
|
||||
if self.con:
|
||||
return self.con
|
||||
return None
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from autogen.logger.logger_factory import LoggerFactory
|
||||
from autogen.logger.base_logger import LLMConfig
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
import uuid
|
||||
|
||||
from openai import OpenAI, AzureOpenAI
|
||||
|
@ -11,6 +14,8 @@ from openai.types.chat import ChatCompletion
|
|||
if TYPE_CHECKING:
|
||||
from autogen import ConversableAgent, OpenAIWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
autogen_logger = None
|
||||
is_logging = False
|
||||
|
||||
|
@ -19,12 +24,14 @@ def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None)
|
|||
global autogen_logger
|
||||
global is_logging
|
||||
|
||||
if autogen_logger is None:
|
||||
autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
|
||||
|
||||
try:
|
||||
session_id = autogen_logger.start()
|
||||
is_logging = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[runtime logging] Failed to start logging: {e}")
|
||||
finally:
|
||||
return session_id
|
||||
|
||||
|
||||
|
@ -32,26 +39,42 @@ def log_chat_completion(
|
|||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
request: Dict,
|
||||
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||
response: Union[str, ChatCompletion],
|
||||
is_cached: int,
|
||||
cost: float,
|
||||
start_time: str,
|
||||
) -> None:
|
||||
if autogen_logger is None:
|
||||
logger.error("[runtime logging] log_chat_completion: autogen logger is None")
|
||||
return
|
||||
|
||||
autogen_logger.log_chat_completion(
|
||||
invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time
|
||||
)
|
||||
|
||||
|
||||
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None:
|
||||
def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
|
||||
if autogen_logger is None:
|
||||
logger.error("[runtime logging] log_new_agent: autogen logger is None")
|
||||
return
|
||||
|
||||
autogen_logger.log_new_agent(agent, init_args)
|
||||
|
||||
|
||||
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
||||
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
|
||||
if autogen_logger is None:
|
||||
logger.error("[runtime logging] log_new_wrapper: autogen logger is None")
|
||||
return
|
||||
|
||||
autogen_logger.log_new_wrapper(wrapper, init_args)
|
||||
|
||||
|
||||
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
||||
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None:
|
||||
if autogen_logger is None:
|
||||
logger.error("[runtime logging] log_new_client: autogen logger is None")
|
||||
return
|
||||
|
||||
autogen_logger.log_new_client(client, wrapper, init_args)
|
||||
|
||||
|
||||
|
@ -62,7 +85,11 @@ def stop() -> None:
|
|||
is_logging = False
|
||||
|
||||
|
||||
def get_connection() -> Union[sqlite3.Connection]:
|
||||
def get_connection() -> Union[None, sqlite3.Connection]:
|
||||
if autogen_logger is None:
|
||||
logger.error("[runtime logging] get_connection: autogen logger is None")
|
||||
return None
|
||||
|
||||
return autogen_logger.get_connection()
|
||||
|
||||
|
||||
|
|
|
@ -265,4 +265,4 @@ def test_logging_exception_will_not_crash_only_print_error(mock_logger_error, db
|
|||
|
||||
args, _ = mock_logger_error.call_args
|
||||
error_message = args[0]
|
||||
assert error_message.startswith("[SqliteLogger] log_chat_completion error:")
|
||||
assert error_message.startswith("[sqlite logger]Error running query with query")
|
||||
|
|
Loading…
Reference in New Issue