Fix threading issue for logging (#1901)

* fix mypy errors for logging

* cleanup

* formatting

* fix threading issue in logging

* remove kwarg

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
cheng-tan 2024-03-07 19:08:22 -05:00 committed by GitHub
parent 2a62ffc566
commit 811fd6926d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 162 additions and 130 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, TYPE_CHECKING, Union from typing import Any, Dict, List, TYPE_CHECKING, Union
import sqlite3 import sqlite3
import uuid import uuid
@ -11,6 +11,9 @@ from openai.types.chat import ChatCompletion
if TYPE_CHECKING: if TYPE_CHECKING:
from autogen import ConversableAgent, OpenAIWrapper from autogen import ConversableAgent, OpenAIWrapper
ConfigItem = Dict[str, Union[str, List[str]]]
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
class BaseLogger(ABC): class BaseLogger(ABC):
@abstractmethod @abstractmethod
@ -25,10 +28,11 @@ class BaseLogger(ABC):
@abstractmethod @abstractmethod
def log_chat_completion( def log_chat_completion(
self,
invocation_id: uuid.UUID, invocation_id: uuid.UUID,
client_id: int, client_id: int,
wrapper_id: int, wrapper_id: int,
request: Dict, request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion], response: Union[str, ChatCompletion],
is_cached: int, is_cached: int,
cost: float, cost: float,
@ -54,7 +58,7 @@ class BaseLogger(ABC):
... ...
@abstractmethod @abstractmethod
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None: def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
""" """
Log the birth of a new agent. Log the birth of a new agent.
@ -65,7 +69,7 @@ class BaseLogger(ABC):
... ...
@abstractmethod @abstractmethod
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None: def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
""" """
Log the birth of a new OpenAIWrapper. Log the birth of a new OpenAIWrapper.
@ -76,7 +80,9 @@ class BaseLogger(ABC):
... ...
@abstractmethod @abstractmethod
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None: def log_new_client(
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
) -> None:
""" """
Log the birth of a new OpenAIWrapper. Log the birth of a new OpenAIWrapper.
@ -87,14 +93,14 @@ class BaseLogger(ABC):
... ...
@abstractmethod @abstractmethod
def stop() -> None: def stop(self) -> None:
""" """
Close the connection to the logging database, and stop logging. Close the connection to the logging database, and stop logging.
""" """
... ...
@abstractmethod @abstractmethod
def get_connection() -> Union[sqlite3.Connection]: def get_connection(self) -> Union[None, sqlite3.Connection]:
""" """
Return a connection to the logging database. Return a connection to the logging database.
""" """

View File

@ -5,14 +5,14 @@ from typing import Any, Dict, List, Tuple, Union
__all__ = ("get_current_ts", "to_dict") __all__ = ("get_current_ts", "to_dict")
def get_current_ts(): def get_current_ts() -> str:
return datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f") return datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
def to_dict( def to_dict(
obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any], obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any],
exclude: Tuple[str] = (), exclude: Tuple[str, ...] = (),
no_recursive: Tuple[str] = (), no_recursive: Tuple[Any, ...] = (),
) -> Any: ) -> Any:
if isinstance(obj, (int, float, str, bool)): if isinstance(obj, (int, float, str, bool)):
return obj return obj

View File

@ -4,7 +4,7 @@ import json
import logging import logging
import os import os
import sqlite3 import sqlite3
import sys import threading
import uuid import uuid
from autogen.logger.base_logger import BaseLogger from autogen.logger.base_logger import BaseLogger
@ -12,17 +12,15 @@ from autogen.logger.logger_utils import get_current_ts, to_dict
from openai import OpenAI, AzureOpenAI from openai import OpenAI, AzureOpenAI
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion
from typing import Dict, TYPE_CHECKING, Union from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Union
from .base_logger import LLMConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from autogen import ConversableAgent, OpenAIWrapper from autogen import ConversableAgent, OpenAIWrapper
# this is a pointer to the module object instance itself
this = sys.modules[__name__]
this._session_id = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
lock = threading.Lock()
__all__ = ("SqliteLogger",) __all__ = ("SqliteLogger",)
@ -30,19 +28,19 @@ __all__ = ("SqliteLogger",)
class SqliteLogger(BaseLogger): class SqliteLogger(BaseLogger):
schema_version = 1 schema_version = 1
def __init__(self, config): def __init__(self, config: Dict[str, Any]):
self.con = None
self.cur = None
self.config = config self.config = config
def start(self) -> str:
dbname = self.config["dbname"] if "dbname" in self.config else "logs.db"
this._session_id = str(uuid.uuid4())
try: try:
self.con = sqlite3.connect(dbname) self.dbname = self.config.get("dbname", "logs.db")
self.con = sqlite3.connect(self.dbname, check_same_thread=False)
self.cur = self.con.cursor() self.cur = self.con.cursor()
self.session_id = str(uuid.uuid4())
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] Failed to connect to database {self.dbname}: {e}")
def start(self) -> str:
try:
query = """ query = """
CREATE TABLE IF NOT EXISTS chat_completions( CREATE TABLE IF NOT EXISTS chat_completions(
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
@ -57,8 +55,7 @@ class SqliteLogger(BaseLogger):
start_time DATETIME DEFAULT CURRENT_TIMESTAMP, start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
end_time DATETIME DEFAULT CURRENT_TIMESTAMP) end_time DATETIME DEFAULT CURRENT_TIMESTAMP)
""" """
self.cur.execute(query) self._run_query(query=query)
self.con.commit()
query = """ query = """
CREATE TABLE IF NOT EXISTS agents ( CREATE TABLE IF NOT EXISTS agents (
@ -72,8 +69,7 @@ class SqliteLogger(BaseLogger):
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(agent_id, session_id)) UNIQUE(agent_id, session_id))
""" """
self.cur.execute(query) self._run_query(query=query)
self.con.commit()
query = """ query = """
CREATE TABLE IF NOT EXISTS oai_wrappers ( CREATE TABLE IF NOT EXISTS oai_wrappers (
@ -84,8 +80,7 @@ class SqliteLogger(BaseLogger):
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(wrapper_id, session_id)) UNIQUE(wrapper_id, session_id))
""" """
self.cur.execute(query) self._run_query(query=query)
self.con.commit()
query = """ query = """
CREATE TABLE IF NOT EXISTS oai_clients ( CREATE TABLE IF NOT EXISTS oai_clients (
@ -98,8 +93,7 @@ class SqliteLogger(BaseLogger):
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(client_id, session_id)) UNIQUE(client_id, session_id))
""" """
self.cur.execute(query) self._run_query(query=query)
self.con.commit()
query = """ query = """
CREATE TABLE IF NOT EXISTS version ( CREATE TABLE IF NOT EXISTS version (
@ -107,31 +101,30 @@ class SqliteLogger(BaseLogger):
version_number INTEGER NOT NULL -- version of the logging database version_number INTEGER NOT NULL -- version of the logging database
); );
""" """
self.cur.execute(query) self._run_query(query=query)
self.con.commit()
current_verion = self._get_current_db_version() current_verion = self._get_current_db_version()
if current_verion is None: if current_verion is None:
self.cur.execute( self._run_query(
"INSERT INTO version (id, version_number) VALUES (1, ?);", (SqliteLogger.schema_version,) query="INSERT INTO version (id, version_number) VALUES (1, ?);", args=(SqliteLogger.schema_version,)
) )
self.con.commit() self._apply_migration()
self._apply_migration(dbname)
except sqlite3.Error as e: except sqlite3.Error as e:
logger.error(f"[SqliteLogger] start logging error: {e}") logger.error(f"[SqliteLogger] start logging error: {e}")
finally: finally:
return this._session_id return self.session_id
def _get_current_db_version(self): def _get_current_db_version(self) -> Union[None, int]:
self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1") self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1")
result = self.cur.fetchone() result = self.cur.fetchone()
return result[0] if result else None return result[0] if result is not None else None
# Example migration script name format: 002_update_agents_table.sql # Example migration script name format: 002_update_agents_table.sql
def _apply_migration(self, dbname, migrations_dir="./migrations"): def _apply_migration(self, migrations_dir: str = "./migrations") -> None:
current_version = self._get_current_db_version() current_version = self._get_current_db_version()
current_version = SqliteLogger.schema_version if current_version is None else current_version
if os.path.isdir(migrations_dir): if os.path.isdir(migrations_dir):
migrations = sorted(os.listdir(migrations_dir)) migrations = sorted(os.listdir(migrations_dir))
else: else:
@ -143,19 +136,48 @@ class SqliteLogger(BaseLogger):
for script in migrations_to_apply: for script in migrations_to_apply:
with open(script, "r") as f: with open(script, "r") as f:
migration_sql = f.read() migration_sql = f.read()
self.con.executescript(migration_sql) self._run_query_script(script=migration_sql)
self.con.commit()
latest_version = int(script.split("_")[0]) latest_version = int(script.split("_")[0])
self.cur.execute("UPDATE version SET version_number = ? WHERE id = 1", (latest_version)) query = "UPDATE version SET version_number = ? WHERE id = 1"
args = (latest_version,)
self._run_query(query=query, args=args)
def _run_query(self, query: str, args: Tuple[Any, ...] = ()) -> None:
"""
Executes a given SQL query.
Args:
query (str): The SQL query to execute.
args (Tuple): The arguments to pass to the SQL query.
"""
try:
with lock:
self.cur.execute(query, args)
self.con.commit() self.con.commit()
except Exception as e:
logger.error("[sqlite logger]Error running query with query %s and args %s: %s", query, args, e)
def _run_query_script(self, script: str) -> None:
"""
Executes SQL script.
Args:
script (str): SQL script to execute.
"""
try:
with lock:
self.cur.executescript(script)
self.con.commit()
except Exception as e:
logger.error("[sqlite logger]Error running query script %s: %s", script, e)
def log_chat_completion( def log_chat_completion(
self, self,
invocation_id: uuid.UUID, invocation_id: uuid.UUID,
client_id: int, client_id: int,
wrapper_id: int, wrapper_id: int,
request: Dict, request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion], response: Union[str, ChatCompletion],
is_cached: int, is_cached: int,
cost: float, cost: float,
@ -176,28 +198,22 @@ class SqliteLogger(BaseLogger):
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""" """
args = (
invocation_id,
client_id,
wrapper_id,
self.session_id,
json.dumps(request),
response_messages,
is_cached,
cost,
start_time,
end_time,
)
try: self._run_query(query=query, args=args)
self.cur.execute(
query,
(
invocation_id,
client_id,
wrapper_id,
this._session_id,
json.dumps(request),
response_messages,
is_cached,
cost,
start_time,
end_time,
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_chat_completion error: {e}")
def log_new_agent(self, agent: ConversableAgent, init_args: Dict) -> None: def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
from autogen import Agent from autogen import Agent
if self.con is None: if self.con is None:
@ -206,7 +222,7 @@ class SqliteLogger(BaseLogger):
args = to_dict( args = to_dict(
init_args, init_args,
exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"), exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
no_recursive=(Agent), no_recursive=(Agent,),
) )
# We do an upsert since both the superclass and subclass may call this method (in that order) # We do an upsert since both the superclass and subclass may call this method (in that order)
@ -219,24 +235,18 @@ class SqliteLogger(BaseLogger):
init_args = excluded.init_args, init_args = excluded.init_args,
timestamp = excluded.timestamp timestamp = excluded.timestamp
""" """
try: args = (
self.cur.execute( id(agent),
query, agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
( self.session_id,
id(agent), agent.name if hasattr(agent, "name") and agent.name is not None else "",
agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "", type(agent).__name__,
this._session_id, json.dumps(args),
agent.name if hasattr(agent, "name") and agent.name is not None else "", get_current_ts(),
type(agent).__name__, )
json.dumps(args), self._run_query(query=query, args=args)
get_current_ts(),
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_new_agent error: {e}")
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict) -> None: def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if self.con is None: if self.con is None:
return return
@ -248,21 +258,17 @@ class SqliteLogger(BaseLogger):
INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp) VALUES (?, ?, ?, ?) INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp) VALUES (?, ?, ?, ?)
ON CONFLICT (wrapper_id, session_id) DO NOTHING; ON CONFLICT (wrapper_id, session_id) DO NOTHING;
""" """
try: args = (
self.cur.execute( id(wrapper),
query, self.session_id,
( json.dumps(args),
id(wrapper), get_current_ts(),
this._session_id, )
json.dumps(args), self._run_query(query=query, args=args)
get_current_ts(),
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_new_wrapper error: {e}")
def log_new_client(self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None: def log_new_client(
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
) -> None:
if self.con is None: if self.con is None:
return return
@ -274,28 +280,21 @@ class SqliteLogger(BaseLogger):
INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?) INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (client_id, session_id) DO NOTHING; ON CONFLICT (client_id, session_id) DO NOTHING;
""" """
try: args = (
self.cur.execute( id(client),
query, id(wrapper),
( self.session_id,
id(client), type(client).__name__,
id(wrapper), json.dumps(args),
this._session_id, get_current_ts(),
type(client).__name__, )
json.dumps(args), self._run_query(query=query, args=args)
get_current_ts(),
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_new_client error: {e}")
def stop(self) -> None: def stop(self) -> None:
if self.con: if self.con:
self.con.close() self.con.close()
self.con = None
self.cur = None
def get_connection(self) -> sqlite3.Connection: def get_connection(self) -> Union[None, sqlite3.Connection]:
if self.con: if self.con:
return self.con return self.con
return None

View File

@ -1,8 +1,11 @@
from __future__ import annotations from __future__ import annotations
from autogen.logger.logger_factory import LoggerFactory from autogen.logger.logger_factory import LoggerFactory
from autogen.logger.base_logger import LLMConfig
import logging
import sqlite3 import sqlite3
from typing import Any, Dict, Optional, TYPE_CHECKING, Union from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
import uuid import uuid
from openai import OpenAI, AzureOpenAI from openai import OpenAI, AzureOpenAI
@ -11,6 +14,8 @@ from openai.types.chat import ChatCompletion
if TYPE_CHECKING: if TYPE_CHECKING:
from autogen import ConversableAgent, OpenAIWrapper from autogen import ConversableAgent, OpenAIWrapper
logger = logging.getLogger(__name__)
autogen_logger = None autogen_logger = None
is_logging = False is_logging = False
@ -19,39 +24,57 @@ def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None)
global autogen_logger global autogen_logger
global is_logging global is_logging
if autogen_logger is None: autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
session_id = autogen_logger.start() try:
is_logging = True session_id = autogen_logger.start()
is_logging = True
return session_id except Exception as e:
logger.error(f"[runtime logging] Failed to start logging: {e}")
finally:
return session_id
def log_chat_completion( def log_chat_completion(
invocation_id: uuid.UUID, invocation_id: uuid.UUID,
client_id: int, client_id: int,
wrapper_id: int, wrapper_id: int,
request: Dict, request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion], response: Union[str, ChatCompletion],
is_cached: int, is_cached: int,
cost: float, cost: float,
start_time: str, start_time: str,
) -> None: ) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_chat_completion: autogen logger is None")
return
autogen_logger.log_chat_completion( autogen_logger.log_chat_completion(
invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time
) )
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None: def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_agent: autogen logger is None")
return
autogen_logger.log_new_agent(agent, init_args) autogen_logger.log_new_agent(agent, init_args)
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None: def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_wrapper: autogen logger is None")
return
autogen_logger.log_new_wrapper(wrapper, init_args) autogen_logger.log_new_wrapper(wrapper, init_args)
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None: def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_client: autogen logger is None")
return
autogen_logger.log_new_client(client, wrapper, init_args) autogen_logger.log_new_client(client, wrapper, init_args)
@ -62,7 +85,11 @@ def stop() -> None:
is_logging = False is_logging = False
def get_connection() -> Union[sqlite3.Connection]: def get_connection() -> Union[None, sqlite3.Connection]:
if autogen_logger is None:
logger.error("[runtime logging] get_connection: autogen logger is None")
return None
return autogen_logger.get_connection() return autogen_logger.get_connection()

View File

@ -265,4 +265,4 @@ def test_logging_exception_will_not_crash_only_print_error(mock_logger_error, db
args, _ = mock_logger_error.call_args args, _ = mock_logger_error.call_args
error_message = args[0] error_message = args[0]
assert error_message.startswith("[SqliteLogger] log_chat_completion error:") assert error_message.startswith("[sqlite logger]Error running query with query")