# bad_clone_prediction/MySQLOperator.py
#
# 391 lines
# 13 KiB
# Python

import asyncio
from typing import List, Optional

import pymysql

import GlobalConstants
from ConfigOperator import ConfigOperator
from models.RepoInfo import RepoInfo
class MySQLOperator(object):
    """
    Function: thin wrapper around a PyMySQL connection for this project's tables.

    When constructed with a ``repoInfo``, the operator also knows the physical
    names of that repository's tables (``tablename_dict``), which every
    repository-scoped query method below relies on.
    The operator can be used as a context manager to close the connection.
    """

    def __init__(
        self,
        config_path: str,
        autocommit: bool = True,
        repoInfo: Optional["RepoInfo"] = None,  # optional: bind the operator to one repository
    ) -> None:
        """
        Function: load the configuration and open a database connection
        params:
            - config_path: path of the configuration file read by ConfigOperator
            - autocommit: passed through to pymysql.connect
            - repoInfo: when given, per-repository table names are derived from repoInfo.id
        """
        self.config = ConfigOperator(config_path=config_path).config
        self.connection = self.connect(autocommit)
        self.cursor = self.connection.cursor(pymysql.cursors.DictCursor)
        # Always define the attribute so repository-scoped methods fail with a
        # descriptive KeyError (see _table) rather than an AttributeError.
        self.tablename_dict = {}
        if repoInfo is not None:
            self.tablename_dict = {
                rel_table: self._repo_tablename(repoInfo.id, rel_table)
                for rel_table in GlobalConstants.REL_TABLES
            }

    @staticmethod
    def _repo_tablename(repo_id, rel_table: str) -> str:
        """
        Function: build the physical name of a per-repository table
        return:
            - "<repo_id><GlobalConstants.SEPARATOR><rel_table>"
        """
        return "{repo_id}{separator}{rel_table}".format(
            repo_id=repo_id,
            separator=GlobalConstants.SEPARATOR,
            rel_table=rel_table,
        )

    @staticmethod
    def _read_sql_template(name: str) -> str:
        """
        Function: read the SQL template sql_templates/<name>.sql
        (the path is relative to the current working directory)
        """
        template_filepath = "sql_templates/{name}.sql".format(name=name)
        with open(template_filepath, "r", encoding="utf-8") as f:
            return f.read()

    def _table(self, rel_table: str) -> str:
        """
        Function: look up the physical name of a per-repository table
        raises:
            - KeyError: if the table is not registered (e.g. no repoInfo was given)
        """
        try:
            return self.tablename_dict[rel_table]
        except KeyError:
            raise KeyError(
                "table {!r} is not registered for this MySQLOperator "
                "(was it created with a repoInfo?)".format(rel_table)
            ) from None

    def _fetch_column_dict(self, rel_table: str, key_column: str, value_column: str) -> dict:
        """
        Function: map one column to another over every row of a per-repository table
        return:
            - {row[key_column]: row[value_column]}; later rows win on duplicate keys
        """
        self.cursor.execute(
            "select {key_column}, {value_column} from `{tablename}`".format(
                key_column=key_column,
                value_column=value_column,
                tablename=self._table(rel_table),
            )
        )
        return {row[key_column]: row[value_column] for row in self.cursor.fetchall()}

    def connect(self, autocommit=True):
        """
        Function: open a MySQL connection using the "mysql" section of the loaded configuration
        params:
            - autocommit: whether the connection commits after every statement
        return:
            - the pymysql connection
        raises:
            - ValueError: if the configuration has no "mysql" section
        """
        mysql_config = self.config.get("mysql")
        if mysql_config is None:
            raise ValueError('configuration has no "mysql" section')
        # `password`/`database` are pymysql's canonical keyword names;
        # `passwd`/`db` are deprecated aliases.
        return pymysql.connect(
            host=mysql_config.get("host"),
            port=int(mysql_config.get("port")),
            user=mysql_config.get("user"),
            password=mysql_config.get("passwd"),
            database=mysql_config.get("database"),
            autocommit=autocommit,
        )

    def close(self) -> None:
        """
        Function: close the cursor and, if still open, the database connection
        """
        self.cursor.close()
        if self.connection.open:
            self.connection.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def create_tables_4_project(self):
        """
        Function: create this repository's MySQL tables from sql_templates/<rel_table>.sql
        raises:
            - ValueError: if the operator was created without a repoInfo
        """
        if not self.tablename_dict:
            raise ValueError(
                "create_tables_4_project requires a MySQLOperator created with a repoInfo"
            )
        for rel_tablename, tablename in self.tablename_dict.items():
            sql = self._read_sql_template(rel_tablename).format(tablename=tablename)
            self.cursor.execute(sql)
        self.connection.commit()

    def create_repositories_table(self):
        """
        Function: create the repositories table from sql_templates/repositories.sql
        """
        self.cursor.execute(self._read_sql_template("repositories"))
        self.connection.commit()

    def init_repositories_table(self, repoInfos: List["RepoInfo"]):
        """
        Function: insert all the repositories into the repositories table (handled = 0);
        rows that already exist are left untouched (insert ignore)
        params:
            - repoInfos: a list of RepoInfo objects
        """
        rows = [
            (repoInfo.id, repoInfo.ownername, repoInfo.reponame, 0)
            for repoInfo in repoInfos
        ]
        self.cursor.executemany(
            "insert ignore into repositories (id, ownername, reponame, handled) values (%s, %s, %s, %s)",
            rows,
        )
        self.connection.commit()

    def update_handled_repository(self, repoInfo: "RepoInfo"):
        """
        Function: mark one repository as handled (handled = 1) in the repositories table
        params:
            - repoInfo: RepoInfo object; matched by ownername and reponame
        """
        self.cursor.execute(
            "update repositories set handled = 1 where ownername=%s and reponame=%s",
            (repoInfo.ownername, repoInfo.reponame),
        )
        self.connection.commit()

    def init_steps_table(self, repoInfo: "RepoInfo"):
        """
        Function: initialize the repository's steps table with every entry of
        GlobalConstants.STEPS (handled = 0); existing rows are left untouched
        params:
            - repoInfo: RepoInfo object
        """
        steps_tablename = self._repo_tablename(repoInfo.id, "steps")
        sql = "insert ignore into `{steps_tablename}` (step_id, step_name, handled) values (%s, %s, %s)".format(
            steps_tablename=steps_tablename
        )
        self.cursor.executemany(
            sql,
            [(step_id, step_name, 0) for step_id, step_name in GlobalConstants.STEPS.items()],
        )
        self.connection.commit()

    def get_supported_blobs(self, langs: List[str]) -> List:
        """
        Function: get the shas of blobs whose file path ends with ".<lang>" for each supported language
        params:
            - langs: the list of supported languages (file extensions)
        return:
            - a list of tuples: (blob_sha, lang encoded as bytes)
        """
        if not langs:
            return []
        sql = (
            "select distinct b.sha from `{blob_commit_relation_tablename}` bcr, "
            "`{filepath_tablename}` f, `{blob_tablename}` b "
            "where f.id=bcr.filepath_id and b.id=bcr.blob_id and f.filepath like %s"
        ).format(
            blob_commit_relation_tablename=self._table("blob_commit_relations"),
            filepath_tablename=self._table("filepaths"),
            blob_tablename=self._table("blobs"),
        )
        result = []
        for lang in langs:
            # Bind the LIKE pattern as a parameter instead of splicing it into the
            # SQL text, so quotes in `lang` cannot break or alter the statement.
            self.cursor.execute(sql, ("%." + lang,))
            lang_bytes = lang.encode()
            result.extend((row["sha"], lang_bytes) for row in self.cursor.fetchall())
        return result

    def truncate_table(self, tablename: str):
        """
        Function: truncate a table
        params:
            - tablename: the physical name of the table to truncate
        """
        self.cursor.execute("truncate table `{tablename}`".format(tablename=tablename))

    def delete_all_tables_4_project(self, repoInfo: "RepoInfo"):
        """
        Function: drop (if they exist) all the per-repository tables of a target repo
        params:
            - repoInfo: RepoInfo object
        """
        for rel_tablename in GlobalConstants.REL_TABLES:
            tablename = self._repo_tablename(repoInfo.id, rel_tablename)
            self.cursor.execute("drop table if exists `{tablename}`".format(tablename=tablename))
        print("Finish deleting all mysql tables of {id}".format(id=repoInfo.id))

    def get_blob_sha_id_dict(self):
        """
        Function: get the blob {sha: id} relation dict
        """
        return self._fetch_column_dict("blobs", "sha", "id")

    def get_commit_sha_id_dict(self):
        """
        Function: get the commit {sha: id} relation dict
        """
        return self._fetch_column_dict("commits", "sha", "id")

    def get_commit_ids(self) -> List[int]:
        """
        Function: get the list of commit ids
        return:
            - commit id list in ascending order
        """
        self.cursor.execute(
            "select id from `{commit_tablename}` order by id asc".format(
                commit_tablename=self._table("commits")
            )
        )
        return [row["id"] for row in self.cursor.fetchall()]

    def get_filepath_sha_id_dict(self):
        """
        Function: get the filepath {sha: id} relation dict
        """
        return self._fetch_column_dict("filepaths", "sha", "id")

    def get_filepath_id_dict(self):
        """
        Function: get the {filepath: id} relation dict
        """
        return self._fetch_column_dict("filepaths", "filepath", "id")

    def get_method_id_func_id_and_blob_id_dict(self):
        """
        Function: get the method id -> (function id, blob id) relationship
        return:
            - {method_id: (function_id, blob_id)}
        """
        self.cursor.execute(
            "select function_id, blob_id, id as method_id from `{blob_method_tablename}`".format(
                blob_method_tablename=self._table("blob_methods")
            )
        )
        return {
            row["method_id"]: (row["function_id"], row["blob_id"])
            for row in self.cursor.fetchall()
        }

    def get_fp_lineno_method_id_dict(self, commit_id: int) -> dict:
        """
        Function: get the dict filepath_id -> lineno -> method_id for one commit
        params:
            - commit_id: int
        return:
            - {filepath_id: {lineno: method_id}}; line ranges are inclusive of start and end
        """
        self.cursor.execute(
            "select bm.id, bcr.filepath_id, bm.start, bm.end from `{blob_method_tablename}` bm, `{blob_commit_relation_tablename}` bcr where bm.blob_id=bcr.blob_id and bcr.commit_id=%s".format(
                blob_method_tablename=self._table("blob_methods"),
                blob_commit_relation_tablename=self._table("blob_commit_relations"),
            ),
            (commit_id,),
        )
        result = {}
        for row in self.cursor.fetchall():
            line_map = result.setdefault(row["filepath_id"], {})
            # When method ranges overlap, the row fetched later wins for those lines.
            line_map.update(dict.fromkeys(range(row["start"], row["end"] + 1), row["id"]))
        return result

    def get_method_id_name_dict(self) -> dict:
        """
        Function: get the method id -> name dict
        return:
            - {method_id: method_name}
        """
        return self._fetch_column_dict("blob_methods", "id", "method_name")

    def get_commit_id_by_sha(self, sha: bytes) -> Optional[int]:
        """
        Function: get the id of a commit by its sha
        return:
            - the commit id, or None if no such commit exists
        """
        self.cursor.execute(
            "select id from `{commit_tablename}` where sha=%s".format(
                commit_tablename=self._table("commits")
            ),
            (sha,),
        )
        row = self.cursor.fetchone()
        return None if row is None else row["id"]

    def get_blob_ids_by_commit_id(self, commit_id: int) -> List[int]:
        """
        Function: get all the blob ids related to a commit
        params:
            - commit_id: the id of the commit
        return:
            - a list of blob ids
        """
        self.cursor.execute(
            "select blob_id from `{blob_commit_relation_tablename}` where commit_id=%s".format(
                blob_commit_relation_tablename=self._table("blob_commit_relations")
            ),
            (commit_id,),
        )
        return [row["blob_id"] for row in self.cursor.fetchall()]

    def get_repo_id_by_names(self, repoInfo: "RepoInfo") -> Optional[int]:
        """
        Function: get the id of a repository through ownername and reponame
        params:
            - repoInfo: RepoInfo object
        return:
            - id of the repository, or None if it is not in the repositories table
        """
        self.cursor.execute(
            "select id from repositories where ownername=%s and reponame=%s",
            (repoInfo.ownername, repoInfo.reponame),
        )
        row = self.cursor.fetchone()
        return None if row is None else row["id"]
# The test of the function connect is in test.py