391 lines
13 KiB
Python
391 lines
13 KiB
Python
import asyncio
from typing import List, Optional

import pymysql

import GlobalConstants
from ConfigOperator import ConfigOperator
from models.RepoInfo import RepoInfo


class MySQLOperator(object):
    """
    Function: helper around a pymysql connection (DictCursor) for the program's MySQL tables

    When constructed with a RepoInfo, ``tablename_dict`` maps every relative table
    name in GlobalConstants.REL_TABLES to that repository's physical table name.
    The per-repository query methods below read ``tablename_dict`` and therefore
    fail (AttributeError) on an operator built without a RepoInfo.
    """

    def __init__(
        self,
        config_path: str,
        autocommit: bool = True,
        repoInfo: Optional["RepoInfo"] = None,  # whether create the operator for a specific repository
    ) -> None:
        """
        Function: load the configuration and open a database connection
        params:
        - config_path: path of the configuration file read by ConfigOperator; it must
          contain a ``mysql`` section (see connect)
        - autocommit: passed through to pymysql.connect
        - repoInfo: when given, build ``tablename_dict`` for that repository's tables
        """
        self.config = ConfigOperator(config_path=config_path).config
        self.connection = self.connect(autocommit)
        self.cursor = self.connection.cursor(pymysql.cursors.DictCursor)

        # define the tablenames: relative table name -> per-repository table name
        if repoInfo is not None:
            self.tablename_dict = {
                rel_table: self._repo_tablename(repoInfo.id, rel_table)
                for rel_table in GlobalConstants.REL_TABLES
            }

    @staticmethod
    def _repo_tablename(repo_id, rel_table: str) -> str:
        """Return the physical table name ``<repo_id><SEPARATOR><rel_table>``."""
        return "{repo_id}{separator}{rel_table}".format(
            repo_id=repo_id,
            separator=GlobalConstants.SEPARATOR,
            rel_table=rel_table,
        )

    def connect(self, autocommit=True):
        """
        Function: connect to the mysql database described by the ``mysql`` section
        of the loaded configuration (keys: host, port, user, passwd, database)
        params:
        - autocommit: whether the connection autocommits each statement
        return:
        - the pymysql connection
        """
        mysql_config = self.config.get("mysql")
        return pymysql.connect(
            host=mysql_config.get("host"),
            port=int(mysql_config.get("port")),
            user=mysql_config.get("user"),
            passwd=mysql_config.get("passwd"),
            db=mysql_config.get("database"),
            autocommit=autocommit,
        )

    def close(self) -> None:
        """
        Function: close the cursor and the database connection opened in __init__
        """
        self.cursor.close()
        self.connection.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the connection; exceptions from the with-body propagate.
        self.close()

    def create_tables_4_project(self):
        """
        Function: create this repository's tables used for running the program
        Each template ``sql_templates/<rel_tablename>.sql`` is filled via str.format
        with ``tablename`` set to the per-repository table name, then executed.
        """
        for rel_tablename, tablename in self.tablename_dict.items():
            template_filepath = "sql_templates/{rel_tablename}.sql".format(
                rel_tablename=rel_tablename
            )
            with open(template_filepath, "r", encoding="utf-8") as f:
                template = f.read()
            self.cursor.execute(template.format(tablename=tablename))
        self.connection.commit()

    def create_repositories_table(self):
        """
        Function: create the repositories table from sql_templates/repositories.sql
        """
        with open("sql_templates/repositories.sql", "r", encoding="utf-8") as f:
            sql = f.read()
        self.cursor.execute(sql)
        self.connection.commit()

    def init_repositories_table(self, repoInfos: List["RepoInfo"]):
        """
        Function: insert all the repositories into repositories table, marked as not
        handled (handled = 0); rows whose key already exists are skipped (insert ignore)
        params:
        - repoInfos: a list of RepoInfo objects
        """
        rows = [
            (repo.id, repo.ownername, repo.reponame, 0)
            for repo in repoInfos
        ]
        self.cursor.executemany(
            "insert ignore into repositories (id, ownername, reponame, handled) values (%s, %s, %s, %s)",
            rows,
        )
        self.connection.commit()

    def update_handled_repository(self, repoInfo: "RepoInfo"):
        """
        Function: mark a repository as handled (handled = 1) in the repositories table,
        matching it by ownername and reponame
        params:
        - repoInfo: RepoInfo object of the repository to mark
        """
        self.cursor.execute(
            "update repositories set handled = 1 where ownername=%s and reponame=%s",
            (repoInfo.ownername, repoInfo.reponame),
        )
        self.connection.commit()

    def init_steps_table(self, repoInfo: "RepoInfo"):
        """
        Function: insert every step of GlobalConstants.STEPS into the repository's
        ``steps`` table, marked as not handled (handled = 0); rows whose key already
        exists are skipped (insert ignore)
        params:
        - repoInfo: RepoInfo object of the target repository
        """
        steps_tablename = self._repo_tablename(repoInfo.id, "steps")
        self.cursor.executemany(
            "insert ignore into `{steps_tablename}` (step_id, step_name, handled) values (%s, %s, %s)".format(
                steps_tablename=steps_tablename
            ),
            [
                (step_id, step_name, 0)
                for step_id, step_name in GlobalConstants.STEPS.items()
            ],
        )
        self.connection.commit()

    def get_supported_blobs(self, langs: List[str]) -> List:
        """
        Function: get the shas of blobs whose file path ends with "." + lang, for each
        supported language of the target repo
        params:
        - langs: the list of supported languages (file extensions without the dot)
        return:
        - a list of tuples: (blob_sha, lang), where lang is UTF-8 encoded bytes
        """
        if not langs:
            return []
        # The LIKE pattern is bound as a parameter instead of being formatted into
        # the SQL text, so a lang value can never break or alter the statement.
        query = (
            "select distinct b.sha from `{blob_commit_relation_tablename}` bcr, "
            "`{filepath_tablename}` f, `{blob_tablename}` b "
            "where f.id=bcr.filepath_id and b.id=bcr.blob_id and f.filepath like %s"
        ).format(
            blob_commit_relation_tablename=self.tablename_dict["blob_commit_relations"],
            filepath_tablename=self.tablename_dict["filepaths"],
            blob_tablename=self.tablename_dict["blobs"],
        )
        result = []
        for lang in langs:
            self.cursor.execute(query, ("%." + lang,))
            encoded_lang = lang.encode()
            result.extend((row["sha"], encoded_lang) for row in self.cursor.fetchall())
        return result

    def truncate_table(self, tablename: str):
        """
        Function: truncate a table
        params:
        - tablename: the table waiting for truncate
        """
        sql = "truncate table `{tablename}`".format(tablename=tablename)
        self.cursor.execute(sql)

    def delete_all_tables_4_project(self, repoInfo: "RepoInfo"):
        """
        Function: drop (if they exist) all the per-repository tables of a target repo
        params:
        - repoInfo: RepoInfo object of the target repository
        """
        for rel_tablename in GlobalConstants.REL_TABLES:
            tablename = self._repo_tablename(repoInfo.id, rel_tablename)
            sql = "drop table if exists `{tablename}`".format(tablename=tablename)
            self.cursor.execute(sql)
        print("Finish deleting all mysql tables of {id}".format(id=repoInfo.id))

    def _fetch_column_dict(self, tablename: str, key_column: str, value_column: str) -> dict:
        """
        Function: read two columns of a table into a dict
        return:
        - {row[key_column]: row[value_column]} over all rows; on duplicate keys the
          last row returned wins
        """
        self.cursor.execute(
            "select {key_column}, {value_column} from `{tablename}`".format(
                key_column=key_column,
                value_column=value_column,
                tablename=tablename,
            )
        )
        return {row[key_column]: row[value_column] for row in self.cursor.fetchall()}

    def get_blob_sha_id_dict(self):
        """
        Function: get the blob sha -> id relation dict
        """
        return self._fetch_column_dict(self.tablename_dict["blobs"], "sha", "id")

    def get_commit_sha_id_dict(self):
        """
        Function: get the commit sha -> id relation dict
        """
        return self._fetch_column_dict(self.tablename_dict["commits"], "sha", "id")

    def get_commit_ids(self) -> List[int]:
        """
        Function: get the list of commit ids
        return:
        - commit id list, in ascending order
        """
        self.cursor.execute(
            "select id from `{commit_tablename}` order by id asc".format(
                commit_tablename=self.tablename_dict["commits"]
            )
        )
        return [row["id"] for row in self.cursor.fetchall()]

    def get_filepath_sha_id_dict(self):
        """
        Function: get the filepath sha -> id relation dict
        """
        return self._fetch_column_dict(self.tablename_dict["filepaths"], "sha", "id")

    def get_filepath_id_dict(self):
        """
        Function: get the filepath -> id relation dict
        """
        return self._fetch_column_dict(self.tablename_dict["filepaths"], "filepath", "id")

    def get_method_id_func_id_and_blob_id_dict(self):
        """
        Function: get method id and function id relationship, as well as get method id and blob id relationship
        return:
        - {method_id: (function_id, blob_id)}
        """
        self.cursor.execute(
            "select function_id, blob_id, id as method_id from `{blob_method_tablename}`".format(
                blob_method_tablename=self.tablename_dict["blob_methods"]
            )
        )
        return {
            row["method_id"]: (row["function_id"], row["blob_id"])
            for row in self.cursor.fetchall()
        }

    def get_fp_lineno_method_id_dict(self, commit_id: int) -> dict:
        """
        Function: get the dict: filepath -> lineno -> method_id for the blobs of one commit
        Each method covers the inclusive line range [start, end].
        params:
        - commit_id: int
        return:
        - {filepath_id: {lineno: method_id}}
        """
        self.cursor.execute(
            "select bm.id, bcr.filepath_id, bm.start, bm.end from `{blob_method_tablename}` bm, `{blob_commit_relation_tablename}` bcr where bm.blob_id=bcr.blob_id and bcr.commit_id=%s".format(
                blob_method_tablename=self.tablename_dict["blob_methods"],
                blob_commit_relation_tablename=self.tablename_dict["blob_commit_relations"],
            ),
            (commit_id,),
        )
        result = {}
        for row in self.cursor.fetchall():
            method_id = row["id"]
            line_map = result.setdefault(row["filepath_id"], {})
            # NOTE: if two methods' ranges overlap (e.g. nested definitions), the row
            # processed last wins; the query has no ORDER BY.
            for lineno in range(row["start"], row["end"] + 1):
                line_map[lineno] = method_id
        return result

    def get_method_id_name_dict(self) -> dict:
        """
        Function: get the method id -> name dict
        return:
        - {method_id: method_name}
        """
        return self._fetch_column_dict(self.tablename_dict["blob_methods"], "id", "method_name")

    def get_commit_id_by_sha(self, sha: bytes) -> Optional[int]:
        """
        Function: look up a commit id by its sha
        params:
        - sha: the commit sha
        return:
        - the commit id, or None if no commit has this sha
        """
        self.cursor.execute(
            "select id from `{commit_tablename}` where sha=%s".format(
                commit_tablename=self.tablename_dict["commits"]
            ),
            (sha,),
        )
        row = self.cursor.fetchone()
        return None if row is None else row["id"]

    def get_blob_ids_by_commit_id(self, commit_id: int) -> List[int]:
        """
        Function: get all the commit related blob ids
        params:
        - commit_id: the id of the commit
        return:
        - a list of blob ids
        """
        self.cursor.execute(
            "select blob_id from `{blob_commit_relation_tablename}` where commit_id=%s".format(
                blob_commit_relation_tablename=self.tablename_dict["blob_commit_relations"]
            ),
            (commit_id,),
        )
        return [row["blob_id"] for row in self.cursor.fetchall()]

    def get_repo_id_by_names(self, repoInfo: "RepoInfo") -> Optional[int]:
        """
        Function: get the id of repository through ownername and reponame
        params:
        - repoInfo: RepoInfo object
        return:
        - id of repository, or None if it is not in the repositories table
        """
        self.cursor.execute(
            "select id from repositories where ownername=%s and reponame=%s",
            (repoInfo.ownername, repoInfo.reponame),
        )
        row = self.cursor.fetchone()
        return None if row is None else row["id"]
# Tests for MySQLOperator.connect live in test.py.