mirror of https://github.com/microsoft/autogen.git
Fix threading issue for logging (#1901)
* fix mypy errors for logging * cleanup * formatting * fix threading issue in logging * remove kwarg --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
2a62ffc566
commit
811fd6926d
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, TYPE_CHECKING, Union
|
from typing import Any, Dict, List, TYPE_CHECKING, Union
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
@ -11,6 +11,9 @@ from openai.types.chat import ChatCompletion
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from autogen import ConversableAgent, OpenAIWrapper
|
from autogen import ConversableAgent, OpenAIWrapper
|
||||||
|
|
||||||
|
ConfigItem = Dict[str, Union[str, List[str]]]
|
||||||
|
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
|
||||||
|
|
||||||
|
|
||||||
class BaseLogger(ABC):
|
class BaseLogger(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -25,10 +28,11 @@ class BaseLogger(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log_chat_completion(
|
def log_chat_completion(
|
||||||
|
self,
|
||||||
invocation_id: uuid.UUID,
|
invocation_id: uuid.UUID,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
wrapper_id: int,
|
wrapper_id: int,
|
||||||
request: Dict,
|
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||||
response: Union[str, ChatCompletion],
|
response: Union[str, ChatCompletion],
|
||||||
is_cached: int,
|
is_cached: int,
|
||||||
cost: float,
|
cost: float,
|
||||||
|
@ -54,7 +58,7 @@ class BaseLogger(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None:
|
def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
Log the birth of a new agent.
|
Log the birth of a new agent.
|
||||||
|
|
||||||
|
@ -65,7 +69,7 @@ class BaseLogger(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
|
||||||
"""
|
"""
|
||||||
Log the birth of a new OpenAIWrapper.
|
Log the birth of a new OpenAIWrapper.
|
||||||
|
|
||||||
|
@ -76,7 +80,9 @@ class BaseLogger(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
def log_new_client(
|
||||||
|
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Log the birth of a new OpenAIWrapper.
|
Log the birth of a new OpenAIWrapper.
|
||||||
|
|
||||||
|
@ -87,14 +93,14 @@ class BaseLogger(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def stop() -> None:
|
def stop(self) -> None:
|
||||||
"""
|
"""
|
||||||
Close the connection to the logging database, and stop logging.
|
Close the connection to the logging database, and stop logging.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_connection() -> Union[sqlite3.Connection]:
|
def get_connection(self) -> Union[None, sqlite3.Connection]:
|
||||||
"""
|
"""
|
||||||
Return a connection to the logging database.
|
Return a connection to the logging database.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -5,14 +5,14 @@ from typing import Any, Dict, List, Tuple, Union
|
||||||
__all__ = ("get_current_ts", "to_dict")
|
__all__ = ("get_current_ts", "to_dict")
|
||||||
|
|
||||||
|
|
||||||
def get_current_ts():
|
def get_current_ts() -> str:
|
||||||
return datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
|
return datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||||
|
|
||||||
|
|
||||||
def to_dict(
|
def to_dict(
|
||||||
obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any],
|
obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any],
|
||||||
exclude: Tuple[str] = (),
|
exclude: Tuple[str, ...] = (),
|
||||||
no_recursive: Tuple[str] = (),
|
no_recursive: Tuple[Any, ...] = (),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if isinstance(obj, (int, float, str, bool)):
|
if isinstance(obj, (int, float, str, bool)):
|
||||||
return obj
|
return obj
|
||||||
|
|
|
@ -4,7 +4,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import sys
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from autogen.logger.base_logger import BaseLogger
|
from autogen.logger.base_logger import BaseLogger
|
||||||
|
@ -12,17 +12,15 @@ from autogen.logger.logger_utils import get_current_ts, to_dict
|
||||||
|
|
||||||
from openai import OpenAI, AzureOpenAI
|
from openai import OpenAI, AzureOpenAI
|
||||||
from openai.types.chat import ChatCompletion
|
from openai.types.chat import ChatCompletion
|
||||||
from typing import Dict, TYPE_CHECKING, Union
|
from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Union
|
||||||
|
from .base_logger import LLMConfig
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from autogen import ConversableAgent, OpenAIWrapper
|
from autogen import ConversableAgent, OpenAIWrapper
|
||||||
|
|
||||||
|
|
||||||
# this is a pointer to the module object instance itself
|
|
||||||
this = sys.modules[__name__]
|
|
||||||
this._session_id = None
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
__all__ = ("SqliteLogger",)
|
__all__ = ("SqliteLogger",)
|
||||||
|
|
||||||
|
@ -30,19 +28,19 @@ __all__ = ("SqliteLogger",)
|
||||||
class SqliteLogger(BaseLogger):
|
class SqliteLogger(BaseLogger):
|
||||||
schema_version = 1
|
schema_version = 1
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: Dict[str, Any]):
|
||||||
self.con = None
|
|
||||||
self.cur = None
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def start(self) -> str:
|
|
||||||
dbname = self.config["dbname"] if "dbname" in self.config else "logs.db"
|
|
||||||
this._session_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.con = sqlite3.connect(dbname)
|
self.dbname = self.config.get("dbname", "logs.db")
|
||||||
|
self.con = sqlite3.connect(self.dbname, check_same_thread=False)
|
||||||
self.cur = self.con.cursor()
|
self.cur = self.con.cursor()
|
||||||
|
self.session_id = str(uuid.uuid4())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
logger.error(f"[SqliteLogger] Failed to connect to database {self.dbname}: {e}")
|
||||||
|
|
||||||
|
def start(self) -> str:
|
||||||
|
try:
|
||||||
query = """
|
query = """
|
||||||
CREATE TABLE IF NOT EXISTS chat_completions(
|
CREATE TABLE IF NOT EXISTS chat_completions(
|
||||||
id INTEGER PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
|
@ -57,8 +55,7 @@ class SqliteLogger(BaseLogger):
|
||||||
start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
end_time DATETIME DEFAULT CURRENT_TIMESTAMP)
|
end_time DATETIME DEFAULT CURRENT_TIMESTAMP)
|
||||||
"""
|
"""
|
||||||
self.cur.execute(query)
|
self._run_query(query=query)
|
||||||
self.con.commit()
|
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
CREATE TABLE IF NOT EXISTS agents (
|
CREATE TABLE IF NOT EXISTS agents (
|
||||||
|
@ -72,8 +69,7 @@ class SqliteLogger(BaseLogger):
|
||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
UNIQUE(agent_id, session_id))
|
UNIQUE(agent_id, session_id))
|
||||||
"""
|
"""
|
||||||
self.cur.execute(query)
|
self._run_query(query=query)
|
||||||
self.con.commit()
|
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
CREATE TABLE IF NOT EXISTS oai_wrappers (
|
CREATE TABLE IF NOT EXISTS oai_wrappers (
|
||||||
|
@ -84,8 +80,7 @@ class SqliteLogger(BaseLogger):
|
||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
UNIQUE(wrapper_id, session_id))
|
UNIQUE(wrapper_id, session_id))
|
||||||
"""
|
"""
|
||||||
self.cur.execute(query)
|
self._run_query(query=query)
|
||||||
self.con.commit()
|
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
CREATE TABLE IF NOT EXISTS oai_clients (
|
CREATE TABLE IF NOT EXISTS oai_clients (
|
||||||
|
@ -98,8 +93,7 @@ class SqliteLogger(BaseLogger):
|
||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
UNIQUE(client_id, session_id))
|
UNIQUE(client_id, session_id))
|
||||||
"""
|
"""
|
||||||
self.cur.execute(query)
|
self._run_query(query=query)
|
||||||
self.con.commit()
|
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
CREATE TABLE IF NOT EXISTS version (
|
CREATE TABLE IF NOT EXISTS version (
|
||||||
|
@ -107,31 +101,30 @@ class SqliteLogger(BaseLogger):
|
||||||
version_number INTEGER NOT NULL -- version of the logging database
|
version_number INTEGER NOT NULL -- version of the logging database
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
self.cur.execute(query)
|
self._run_query(query=query)
|
||||||
self.con.commit()
|
|
||||||
|
|
||||||
current_verion = self._get_current_db_version()
|
current_verion = self._get_current_db_version()
|
||||||
if current_verion is None:
|
if current_verion is None:
|
||||||
self.cur.execute(
|
self._run_query(
|
||||||
"INSERT INTO version (id, version_number) VALUES (1, ?);", (SqliteLogger.schema_version,)
|
query="INSERT INTO version (id, version_number) VALUES (1, ?);", args=(SqliteLogger.schema_version,)
|
||||||
)
|
)
|
||||||
self.con.commit()
|
self._apply_migration()
|
||||||
|
|
||||||
self._apply_migration(dbname)
|
|
||||||
|
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
logger.error(f"[SqliteLogger] start logging error: {e}")
|
logger.error(f"[SqliteLogger] start logging error: {e}")
|
||||||
finally:
|
finally:
|
||||||
return this._session_id
|
return self.session_id
|
||||||
|
|
||||||
def _get_current_db_version(self):
|
def _get_current_db_version(self) -> Union[None, int]:
|
||||||
self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1")
|
self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1")
|
||||||
result = self.cur.fetchone()
|
result = self.cur.fetchone()
|
||||||
return result[0] if result else None
|
return result[0] if result is not None else None
|
||||||
|
|
||||||
# Example migration script name format: 002_update_agents_table.sql
|
# Example migration script name format: 002_update_agents_table.sql
|
||||||
def _apply_migration(self, dbname, migrations_dir="./migrations"):
|
def _apply_migration(self, migrations_dir: str = "./migrations") -> None:
|
||||||
current_version = self._get_current_db_version()
|
current_version = self._get_current_db_version()
|
||||||
|
current_version = SqliteLogger.schema_version if current_version is None else current_version
|
||||||
|
|
||||||
if os.path.isdir(migrations_dir):
|
if os.path.isdir(migrations_dir):
|
||||||
migrations = sorted(os.listdir(migrations_dir))
|
migrations = sorted(os.listdir(migrations_dir))
|
||||||
else:
|
else:
|
||||||
|
@ -143,19 +136,48 @@ class SqliteLogger(BaseLogger):
|
||||||
for script in migrations_to_apply:
|
for script in migrations_to_apply:
|
||||||
with open(script, "r") as f:
|
with open(script, "r") as f:
|
||||||
migration_sql = f.read()
|
migration_sql = f.read()
|
||||||
self.con.executescript(migration_sql)
|
self._run_query_script(script=migration_sql)
|
||||||
self.con.commit()
|
|
||||||
|
|
||||||
latest_version = int(script.split("_")[0])
|
latest_version = int(script.split("_")[0])
|
||||||
self.cur.execute("UPDATE version SET version_number = ? WHERE id = 1", (latest_version))
|
query = "UPDATE version SET version_number = ? WHERE id = 1"
|
||||||
|
args = (latest_version,)
|
||||||
|
self._run_query(query=query, args=args)
|
||||||
|
|
||||||
|
def _run_query(self, query: str, args: Tuple[Any, ...] = ()) -> None:
|
||||||
|
"""
|
||||||
|
Executes a given SQL query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The SQL query to execute.
|
||||||
|
args (Tuple): The arguments to pass to the SQL query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with lock:
|
||||||
|
self.cur.execute(query, args)
|
||||||
self.con.commit()
|
self.con.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("[sqlite logger]Error running query with query %s and args %s: %s", query, args, e)
|
||||||
|
|
||||||
|
def _run_query_script(self, script: str) -> None:
|
||||||
|
"""
|
||||||
|
Executes SQL script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script (str): SQL script to execute.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with lock:
|
||||||
|
self.cur.executescript(script)
|
||||||
|
self.con.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("[sqlite logger]Error running query script %s: %s", script, e)
|
||||||
|
|
||||||
def log_chat_completion(
|
def log_chat_completion(
|
||||||
self,
|
self,
|
||||||
invocation_id: uuid.UUID,
|
invocation_id: uuid.UUID,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
wrapper_id: int,
|
wrapper_id: int,
|
||||||
request: Dict,
|
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||||
response: Union[str, ChatCompletion],
|
response: Union[str, ChatCompletion],
|
||||||
is_cached: int,
|
is_cached: int,
|
||||||
cost: float,
|
cost: float,
|
||||||
|
@ -176,28 +198,22 @@ class SqliteLogger(BaseLogger):
|
||||||
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
|
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
"""
|
"""
|
||||||
|
args = (
|
||||||
|
invocation_id,
|
||||||
|
client_id,
|
||||||
|
wrapper_id,
|
||||||
|
self.session_id,
|
||||||
|
json.dumps(request),
|
||||||
|
response_messages,
|
||||||
|
is_cached,
|
||||||
|
cost,
|
||||||
|
start_time,
|
||||||
|
end_time,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
self._run_query(query=query, args=args)
|
||||||
self.cur.execute(
|
|
||||||
query,
|
|
||||||
(
|
|
||||||
invocation_id,
|
|
||||||
client_id,
|
|
||||||
wrapper_id,
|
|
||||||
this._session_id,
|
|
||||||
json.dumps(request),
|
|
||||||
response_messages,
|
|
||||||
is_cached,
|
|
||||||
cost,
|
|
||||||
start_time,
|
|
||||||
end_time,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.con.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
logger.error(f"[SqliteLogger] log_chat_completion error: {e}")
|
|
||||||
|
|
||||||
def log_new_agent(self, agent: ConversableAgent, init_args: Dict) -> None:
|
def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
|
||||||
from autogen import Agent
|
from autogen import Agent
|
||||||
|
|
||||||
if self.con is None:
|
if self.con is None:
|
||||||
|
@ -206,7 +222,7 @@ class SqliteLogger(BaseLogger):
|
||||||
args = to_dict(
|
args = to_dict(
|
||||||
init_args,
|
init_args,
|
||||||
exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
|
exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
|
||||||
no_recursive=(Agent),
|
no_recursive=(Agent,),
|
||||||
)
|
)
|
||||||
|
|
||||||
# We do an upsert since both the superclass and subclass may call this method (in that order)
|
# We do an upsert since both the superclass and subclass may call this method (in that order)
|
||||||
|
@ -219,24 +235,18 @@ class SqliteLogger(BaseLogger):
|
||||||
init_args = excluded.init_args,
|
init_args = excluded.init_args,
|
||||||
timestamp = excluded.timestamp
|
timestamp = excluded.timestamp
|
||||||
"""
|
"""
|
||||||
try:
|
args = (
|
||||||
self.cur.execute(
|
id(agent),
|
||||||
query,
|
agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
|
||||||
(
|
self.session_id,
|
||||||
id(agent),
|
agent.name if hasattr(agent, "name") and agent.name is not None else "",
|
||||||
agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
|
type(agent).__name__,
|
||||||
this._session_id,
|
json.dumps(args),
|
||||||
agent.name if hasattr(agent, "name") and agent.name is not None else "",
|
get_current_ts(),
|
||||||
type(agent).__name__,
|
)
|
||||||
json.dumps(args),
|
self._run_query(query=query, args=args)
|
||||||
get_current_ts(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.con.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
logger.error(f"[SqliteLogger] log_new_agent error: {e}")
|
|
||||||
|
|
||||||
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
|
||||||
if self.con is None:
|
if self.con is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -248,21 +258,17 @@ class SqliteLogger(BaseLogger):
|
||||||
INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp) VALUES (?, ?, ?, ?)
|
INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp) VALUES (?, ?, ?, ?)
|
||||||
ON CONFLICT (wrapper_id, session_id) DO NOTHING;
|
ON CONFLICT (wrapper_id, session_id) DO NOTHING;
|
||||||
"""
|
"""
|
||||||
try:
|
args = (
|
||||||
self.cur.execute(
|
id(wrapper),
|
||||||
query,
|
self.session_id,
|
||||||
(
|
json.dumps(args),
|
||||||
id(wrapper),
|
get_current_ts(),
|
||||||
this._session_id,
|
)
|
||||||
json.dumps(args),
|
self._run_query(query=query, args=args)
|
||||||
get_current_ts(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.con.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
logger.error(f"[SqliteLogger] log_new_wrapper error: {e}")
|
|
||||||
|
|
||||||
def log_new_client(self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
def log_new_client(
|
||||||
|
self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
if self.con is None:
|
if self.con is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -274,28 +280,21 @@ class SqliteLogger(BaseLogger):
|
||||||
INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?)
|
INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT (client_id, session_id) DO NOTHING;
|
ON CONFLICT (client_id, session_id) DO NOTHING;
|
||||||
"""
|
"""
|
||||||
try:
|
args = (
|
||||||
self.cur.execute(
|
id(client),
|
||||||
query,
|
id(wrapper),
|
||||||
(
|
self.session_id,
|
||||||
id(client),
|
type(client).__name__,
|
||||||
id(wrapper),
|
json.dumps(args),
|
||||||
this._session_id,
|
get_current_ts(),
|
||||||
type(client).__name__,
|
)
|
||||||
json.dumps(args),
|
self._run_query(query=query, args=args)
|
||||||
get_current_ts(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.con.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
logger.error(f"[SqliteLogger] log_new_client error: {e}")
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
if self.con:
|
if self.con:
|
||||||
self.con.close()
|
self.con.close()
|
||||||
self.con = None
|
|
||||||
self.cur = None
|
|
||||||
|
|
||||||
def get_connection(self) -> sqlite3.Connection:
|
def get_connection(self) -> Union[None, sqlite3.Connection]:
|
||||||
if self.con:
|
if self.con:
|
||||||
return self.con
|
return self.con
|
||||||
|
return None
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from autogen.logger.logger_factory import LoggerFactory
|
from autogen.logger.logger_factory import LoggerFactory
|
||||||
|
from autogen.logger.base_logger import LLMConfig
|
||||||
|
|
||||||
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from openai import OpenAI, AzureOpenAI
|
from openai import OpenAI, AzureOpenAI
|
||||||
|
@ -11,6 +14,8 @@ from openai.types.chat import ChatCompletion
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from autogen import ConversableAgent, OpenAIWrapper
|
from autogen import ConversableAgent, OpenAIWrapper
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
autogen_logger = None
|
autogen_logger = None
|
||||||
is_logging = False
|
is_logging = False
|
||||||
|
|
||||||
|
@ -19,39 +24,57 @@ def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None)
|
||||||
global autogen_logger
|
global autogen_logger
|
||||||
global is_logging
|
global is_logging
|
||||||
|
|
||||||
if autogen_logger is None:
|
autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
|
||||||
autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
|
|
||||||
|
|
||||||
session_id = autogen_logger.start()
|
try:
|
||||||
is_logging = True
|
session_id = autogen_logger.start()
|
||||||
|
is_logging = True
|
||||||
return session_id
|
except Exception as e:
|
||||||
|
logger.error(f"[runtime logging] Failed to start logging: {e}")
|
||||||
|
finally:
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
def log_chat_completion(
|
def log_chat_completion(
|
||||||
invocation_id: uuid.UUID,
|
invocation_id: uuid.UUID,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
wrapper_id: int,
|
wrapper_id: int,
|
||||||
request: Dict,
|
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||||
response: Union[str, ChatCompletion],
|
response: Union[str, ChatCompletion],
|
||||||
is_cached: int,
|
is_cached: int,
|
||||||
cost: float,
|
cost: float,
|
||||||
start_time: str,
|
start_time: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if autogen_logger is None:
|
||||||
|
logger.error("[runtime logging] log_chat_completion: autogen logger is None")
|
||||||
|
return
|
||||||
|
|
||||||
autogen_logger.log_chat_completion(
|
autogen_logger.log_chat_completion(
|
||||||
invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time
|
invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_new_agent(agent: ConversableAgent, init_args: Dict) -> None:
|
def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
|
||||||
|
if autogen_logger is None:
|
||||||
|
logger.error("[runtime logging] log_new_agent: autogen logger is None")
|
||||||
|
return
|
||||||
|
|
||||||
autogen_logger.log_new_agent(agent, init_args)
|
autogen_logger.log_new_agent(agent, init_args)
|
||||||
|
|
||||||
|
|
||||||
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
|
||||||
|
if autogen_logger is None:
|
||||||
|
logger.error("[runtime logging] log_new_wrapper: autogen logger is None")
|
||||||
|
return
|
||||||
|
|
||||||
autogen_logger.log_new_wrapper(wrapper, init_args)
|
autogen_logger.log_new_wrapper(wrapper, init_args)
|
||||||
|
|
||||||
|
|
||||||
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict) -> None:
|
def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None:
|
||||||
|
if autogen_logger is None:
|
||||||
|
logger.error("[runtime logging] log_new_client: autogen logger is None")
|
||||||
|
return
|
||||||
|
|
||||||
autogen_logger.log_new_client(client, wrapper, init_args)
|
autogen_logger.log_new_client(client, wrapper, init_args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,7 +85,11 @@ def stop() -> None:
|
||||||
is_logging = False
|
is_logging = False
|
||||||
|
|
||||||
|
|
||||||
def get_connection() -> Union[sqlite3.Connection]:
|
def get_connection() -> Union[None, sqlite3.Connection]:
|
||||||
|
if autogen_logger is None:
|
||||||
|
logger.error("[runtime logging] get_connection: autogen logger is None")
|
||||||
|
return None
|
||||||
|
|
||||||
return autogen_logger.get_connection()
|
return autogen_logger.get_connection()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -265,4 +265,4 @@ def test_logging_exception_will_not_crash_only_print_error(mock_logger_error, db
|
||||||
|
|
||||||
args, _ = mock_logger_error.call_args
|
args, _ = mock_logger_error.call_args
|
||||||
error_message = args[0]
|
error_message = args[0]
|
||||||
assert error_message.startswith("[SqliteLogger] log_chat_completion error:")
|
assert error_message.startswith("[sqlite logger]Error running query with query")
|
||||||
|
|
Loading…
Reference in New Issue