Fix threading issue for logging (#1901)

* fix mypy errors for logging

* cleanup

* formatting

* fix threading issue in logging

* remove kwarg

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
cheng-tan 2024-03-07 19:08:22 -05:00 committed by GitHub
parent 2a62ffc566
commit 811fd6926d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 162 additions and 130 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, TYPE_CHECKING, Union
from typing import Any, Dict, List, TYPE_CHECKING, Union
import sqlite3
import uuid
@ -11,6 +11,9 @@ from openai.types.chat import ChatCompletion
if TYPE_CHECKING:
from autogen import ConversableAgent, OpenAIWrapper
ConfigItem = Dict[str, Union[str, List[str]]]
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
class BaseLogger(ABC):
@abstractmethod
@ -25,10 +28,11 @@ class BaseLogger(ABC):
@abstractmethod
def log_chat_completion(
self,
invocation_id: uuid.UUID,
client_id: int,
wrapper_id: int,
request: Dict,
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion],
is_cached: int,
cost: float,
@ -54,7 +58,7 @@ class BaseLogger(ABC):
...
@abstractmethod
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None:
def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
"""
Log the birth of a new agent.
@ -65,7 +69,7 @@ class BaseLogger(ABC):
...
@abstractmethod
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None:
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
"""
Log the birth of a new OpenAIWrapper.
@ -76,7 +80,9 @@ class BaseLogger(ABC):
...
@abstractmethod
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
def log_new_client(
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
) -> None:
"""
Log the birth of a new OpenAIWrapper.
@ -87,14 +93,14 @@ class BaseLogger(ABC):
...
@abstractmethod
def stop() -> None:
def stop(self) -> None:
"""
Close the connection to the logging database, and stop logging.
"""
...
@abstractmethod
def get_connection() -> Union[sqlite3.Connection]:
def get_connection(self) -> Union[None, sqlite3.Connection]:
"""
Return a connection to the logging database.
"""

View File

@ -5,14 +5,14 @@ from typing import Any, Dict, List, Tuple, Union
__all__ = ("get_current_ts", "to_dict")
def get_current_ts():
def get_current_ts() -> str:
return datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
def to_dict(
obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any],
exclude: Tuple[str] = (),
no_recursive: Tuple[str] = (),
exclude: Tuple[str, ...] = (),
no_recursive: Tuple[Any, ...] = (),
) -> Any:
if isinstance(obj, (int, float, str, bool)):
return obj

View File

@ -4,7 +4,7 @@ import json
import logging
import os
import sqlite3
import sys
import threading
import uuid
from autogen.logger.base_logger import BaseLogger
@ -12,17 +12,15 @@ from autogen.logger.logger_utils import get_current_ts, to_dict
from openai import OpenAI, AzureOpenAI
from openai.types.chat import ChatCompletion
from typing import Dict, TYPE_CHECKING, Union
from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Union
from .base_logger import LLMConfig
if TYPE_CHECKING:
from autogen import ConversableAgent, OpenAIWrapper
# this is a pointer to the module object instance itself
this = sys.modules[__name__]
this._session_id = None
logger = logging.getLogger(__name__)
lock = threading.Lock()
__all__ = ("SqliteLogger",)
@ -30,19 +28,19 @@ __all__ = ("SqliteLogger",)
class SqliteLogger(BaseLogger):
schema_version = 1
def __init__(self, config):
self.con = None
self.cur = None
def __init__(self, config: Dict[str, Any]):
self.config = config
def start(self) -> str:
dbname = self.config["dbname"] if "dbname" in self.config else "logs.db"
this._session_id = str(uuid.uuid4())
try:
self.con = sqlite3.connect(dbname)
self.dbname = self.config.get("dbname", "logs.db")
self.con = sqlite3.connect(self.dbname, check_same_thread=False)
self.cur = self.con.cursor()
self.session_id = str(uuid.uuid4())
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] Failed to connect to database {self.dbname}: {e}")
def start(self) -> str:
try:
query = """
CREATE TABLE IF NOT EXISTS chat_completions(
id INTEGER PRIMARY KEY,
@ -57,8 +55,7 @@ class SqliteLogger(BaseLogger):
start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
end_time DATETIME DEFAULT CURRENT_TIMESTAMP)
"""
self.cur.execute(query)
self.con.commit()
self._run_query(query=query)
query = """
CREATE TABLE IF NOT EXISTS agents (
@ -72,8 +69,7 @@ class SqliteLogger(BaseLogger):
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(agent_id, session_id))
"""
self.cur.execute(query)
self.con.commit()
self._run_query(query=query)
query = """
CREATE TABLE IF NOT EXISTS oai_wrappers (
@ -84,8 +80,7 @@ class SqliteLogger(BaseLogger):
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(wrapper_id, session_id))
"""
self.cur.execute(query)
self.con.commit()
self._run_query(query=query)
query = """
CREATE TABLE IF NOT EXISTS oai_clients (
@ -98,8 +93,7 @@ class SqliteLogger(BaseLogger):
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(client_id, session_id))
"""
self.cur.execute(query)
self.con.commit()
self._run_query(query=query)
query = """
CREATE TABLE IF NOT EXISTS version (
@ -107,31 +101,30 @@ class SqliteLogger(BaseLogger):
version_number INTEGER NOT NULL -- version of the logging database
);
"""
self.cur.execute(query)
self.con.commit()
self._run_query(query=query)
current_verion = self._get_current_db_version()
if current_verion is None:
self.cur.execute(
"INSERT INTO version (id, version_number) VALUES (1, ?);", (SqliteLogger.schema_version,)
self._run_query(
query="INSERT INTO version (id, version_number) VALUES (1, ?);", args=(SqliteLogger.schema_version,)
)
self.con.commit()
self._apply_migration(dbname)
self._apply_migration()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] start logging error: {e}")
finally:
return this._session_id
return self.session_id
def _get_current_db_version(self):
def _get_current_db_version(self) -> Union[None, int]:
self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1")
result = self.cur.fetchone()
return result[0] if result else None
return result[0] if result is not None else None
# Example migration script name format: 002_update_agents_table.sql
def _apply_migration(self, dbname, migrations_dir="./migrations"):
def _apply_migration(self, migrations_dir: str = "./migrations") -> None:
current_version = self._get_current_db_version()
current_version = SqliteLogger.schema_version if current_version is None else current_version
if os.path.isdir(migrations_dir):
migrations = sorted(os.listdir(migrations_dir))
else:
@ -143,19 +136,48 @@ class SqliteLogger(BaseLogger):
for script in migrations_to_apply:
with open(script, "r") as f:
migration_sql = f.read()
self.con.executescript(migration_sql)
self.con.commit()
self._run_query_script(script=migration_sql)
latest_version = int(script.split("_")[0])
self.cur.execute("UPDATE version SET version_number = ? WHERE id = 1", (latest_version))
query = "UPDATE version SET version_number = ? WHERE id = 1"
args = (latest_version,)
self._run_query(query=query, args=args)
def _run_query(self, query: str, args: Tuple[Any, ...] = ()) -> None:
"""
Executes a given SQL query.
Args:
query (str): The SQL query to execute.
args (Tuple): The arguments to pass to the SQL query.
"""
try:
with lock:
self.cur.execute(query, args)
self.con.commit()
except Exception as e:
logger.error("[sqlite logger]Error running query with query %s and args %s: %s", query, args, e)
def _run_query_script(self, script: str) -> None:
"""
Executes SQL script.
Args:
script (str): SQL script to execute.
"""
try:
with lock:
self.cur.executescript(script)
self.con.commit()
except Exception as e:
logger.error("[sqlite logger]Error running query script %s: %s", script, e)
def log_chat_completion(
self,
invocation_id: uuid.UUID,
client_id: int,
wrapper_id: int,
request: Dict,
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion],
is_cached: int,
cost: float,
@ -176,28 +198,22 @@ class SqliteLogger(BaseLogger):
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
args = (
invocation_id,
client_id,
wrapper_id,
self.session_id,
json.dumps(request),
response_messages,
is_cached,
cost,
start_time,
end_time,
)
try:
self.cur.execute(
query,
(
invocation_id,
client_id,
wrapper_id,
this._session_id,
json.dumps(request),
response_messages,
is_cached,
cost,
start_time,
end_time,
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_chat_completion error: {e}")
self._run_query(query=query, args=args)
def log_new_agent(self, agent: ConversableAgent, init_args: Dict) -> None:
def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
from autogen import Agent
if self.con is None:
@ -206,7 +222,7 @@ class SqliteLogger(BaseLogger):
args = to_dict(
init_args,
exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
no_recursive=(Agent),
no_recursive=(Agent,),
)
# We do an upsert since both the superclass and subclass may call this method (in that order)
@ -219,24 +235,18 @@ class SqliteLogger(BaseLogger):
init_args = excluded.init_args,
timestamp = excluded.timestamp
"""
try:
self.cur.execute(
query,
(
id(agent),
agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
this._session_id,
agent.name if hasattr(agent, "name") and agent.name is not None else "",
type(agent).__name__,
json.dumps(args),
get_current_ts(),
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_new_agent error: {e}")
args = (
id(agent),
agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
self.session_id,
agent.name if hasattr(agent, "name") and agent.name is not None else "",
type(agent).__name__,
json.dumps(args),
get_current_ts(),
)
self._run_query(query=query, args=args)
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict) -> None:
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if self.con is None:
return
@ -248,21 +258,17 @@ class SqliteLogger(BaseLogger):
INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp) VALUES (?, ?, ?, ?)
ON CONFLICT (wrapper_id, session_id) DO NOTHING;
"""
try:
self.cur.execute(
query,
(
id(wrapper),
this._session_id,
json.dumps(args),
get_current_ts(),
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_new_wrapper error: {e}")
args = (
id(wrapper),
self.session_id,
json.dumps(args),
get_current_ts(),
)
self._run_query(query=query, args=args)
def log_new_client(self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
def log_new_client(
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
) -> None:
if self.con is None:
return
@ -274,28 +280,21 @@ class SqliteLogger(BaseLogger):
INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (client_id, session_id) DO NOTHING;
"""
try:
self.cur.execute(
query,
(
id(client),
id(wrapper),
this._session_id,
type(client).__name__,
json.dumps(args),
get_current_ts(),
),
)
self.con.commit()
except sqlite3.Error as e:
logger.error(f"[SqliteLogger] log_new_client error: {e}")
args = (
id(client),
id(wrapper),
self.session_id,
type(client).__name__,
json.dumps(args),
get_current_ts(),
)
self._run_query(query=query, args=args)
def stop(self) -> None:
if self.con:
self.con.close()
self.con = None
self.cur = None
def get_connection(self) -> sqlite3.Connection:
def get_connection(self) -> Union[None, sqlite3.Connection]:
if self.con:
return self.con
return None

View File

@ -1,8 +1,11 @@
from __future__ import annotations
from autogen.logger.logger_factory import LoggerFactory
from autogen.logger.base_logger import LLMConfig
import logging
import sqlite3
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
import uuid
from openai import OpenAI, AzureOpenAI
@ -11,6 +14,8 @@ from openai.types.chat import ChatCompletion
if TYPE_CHECKING:
from autogen import ConversableAgent, OpenAIWrapper
logger = logging.getLogger(__name__)
autogen_logger = None
is_logging = False
@ -19,39 +24,57 @@ def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None)
global autogen_logger
global is_logging
if autogen_logger is None:
autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
session_id = autogen_logger.start()
is_logging = True
return session_id
try:
session_id = autogen_logger.start()
is_logging = True
except Exception as e:
logger.error(f"[runtime logging] Failed to start logging: {e}")
finally:
return session_id
def log_chat_completion(
invocation_id: uuid.UUID,
client_id: int,
wrapper_id: int,
request: Dict,
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion],
is_cached: int,
cost: float,
start_time: str,
) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_chat_completion: autogen logger is None")
return
autogen_logger.log_chat_completion(
invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time
)
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None:
def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_agent: autogen logger is None")
return
autogen_logger.log_new_agent(agent, init_args)
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None:
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_wrapper: autogen logger is None")
return
autogen_logger.log_new_wrapper(wrapper, init_args)
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_client: autogen logger is None")
return
autogen_logger.log_new_client(client, wrapper, init_args)
@ -62,7 +85,11 @@ def stop() -> None:
is_logging = False
def get_connection() -> Union[sqlite3.Connection]:
def get_connection() -> Union[None, sqlite3.Connection]:
if autogen_logger is None:
logger.error("[runtime logging] get_connection: autogen logger is None")
return None
return autogen_logger.get_connection()

View File

@ -265,4 +265,4 @@ def test_logging_exception_will_not_crash_only_print_error(mock_logger_error, db
args, _ = mock_logger_error.call_args
error_message = args[0]
assert error_message.startswith("[SqliteLogger] log_chat_completion error:")
assert error_message.startswith("[sqlite logger]Error running query with query")