beimingwu/backend/context.py

155 lines
4.2 KiB
Python

"""context.py: All globals should be defined here.
"""
import os
from config import Config
from database import Database, SQLAlchemy
from typing import List
from learnware.market import instantiate_learnware_market
from learnware.config import C as leanrware_conf
import logging
import redis
import re
import concurrent_log_handler
database: Database = None
engine = None
engine_config = None
redis_pool = None
stats = 0
sensitive_pattern = None
config = Config()
logger = logging.getLogger("")
class MyConcurrentTimedRotatingFileHandler(concurrent_log_handler.ConcurrentTimedRotatingFileHandler):
def getFilesToDelete(self) -> List[str]:
dirName, baseName = os.path.split(self.baseFilename)
fileNames = os.listdir(dirName)
results = []
for fileName in fileNames:
if fileName.startswith(baseName):
results.append(os.path.join(dirName, fileName))
pass
pass
results.sort()
if self.backupCount <= 0:
return results
else:
return results[: -self.backupCount]
pass
def init_database(admin_password: str = None):
global config, database
if config.database["type"] == "sqlalchemy":
database = SQLAlchemy(config=config.database, admin_password=admin_password)
pass
if database is None:
raise ValueError(f"Database type {config.database['type']} is not supproted.")
pass
def init_engine():
global config, engine, engine_config
if config.engine["type"] == "easy":
engine = instantiate_learnware_market(market_id="default", name="eazy", rebuild=False)
engine_config = leanrware_conf
elif config.engine["type"] == "hetero":
engine = instantiate_learnware_market(
market_id="default", name="hetero", rebuild=False, organizer_kwargs={"auto_update": False}
)
engine_config = leanrware_conf
else:
raise ValueError(f"Learnware engine type {config.engine['type']} is not supproted.")
pass
def init_backend():
global config
os.makedirs(config.upload_path, exist_ok=True)
os.makedirs(config.temp_path, exist_ok=True)
os.makedirs(config.backup_path, exist_ok=True)
os.makedirs(config.datasets_path, exist_ok=True)
os.makedirs(config.log_path, exist_ok=True)
os.makedirs(config.env_path, exist_ok=True)
pass
def init_redis():
global config, redis_pool
redis_pool = redis.ConnectionPool(
host=config.redis["host"], port=config.redis["port"], decode_responses=True, max_connections=256
)
pass
def get_redis_client():
global redis_pool
return redis.Redis(connection_pool=redis_pool, decode_responses=True)
def init_logger():
global config, logger
target = config["log_target"]
if len(logger.handlers) > 0:
pass
elif target == "console":
handler = logging.StreamHandler()
elif target == "file":
log_path = config["log_path"]
hostname = os.environ.get("HOSTNAME")
if hostname is None:
hostname = "learnware-backend"
pass
filename = os.path.join(log_path, f"{hostname}.log")
handler = MyConcurrentTimedRotatingFileHandler(
filename=filename, when="d", interval=1, backupCount=180, encoding="utf-8"
)
pass
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
pass
def init_sensitive_words():
global sensitive_pattern
word_file = config["sensitive_word_file"]
sensitive_words = []
if os.path.exists(word_file):
with open(word_file) as fin:
for line in fin:
w = line.strip()
if len(w) > 0:
if w.isascii():
w = " " + w + " "
pass
sensitive_words.append(re.escape(w))
pass
pass
pass
sensitive_pattern = re.compile("(" + "|".join(sensitive_words) + ")")
pass
pass
def get_learnware_verify_file_path(learnware_id):
return os.path.join(config.upload_path, f"{learnware_id}.zip")