chore: improve gfibot.model and gfibot.data
This commit is contained in:
parent
6c57fa93cb
commit
d36562514a
16
README.md
16
README.md
|
@ -62,8 +62,22 @@ First, determine some GitHub projects of interest and specify them in [`pyprojec
|
|||
python -m gfibot.check_tokens
|
||||
```
|
||||
|
||||
### Dataset Preparation
|
||||
|
||||
Next, run the following script to collect historical data for the interested projects. This can take some time (up to days) to finish for the first run, but can perform quick incremental update on an existing database. This script should be done periodically (e.g., as a scheduled background task) to ensure that the MongoDB database reflect the latest state in the specified repositories.
|
||||
|
||||
```shell script
|
||||
python -m gfibot.data.update
|
||||
python -m gfibot.data.update --nprocess=4 # you can increase parallelism with more GitHub tokens
|
||||
```
|
||||
|
||||
Then, build a dataset for training and prediction as follows. This script may also take a long time but can be accelerated with more processes.
|
||||
|
||||
```shell script
|
||||
python -m gfibot.data.dataset --since=2008.01.01 --nprocess=4
|
||||
```
|
||||
|
||||
### Model Training
|
||||
|
||||
|
||||
|
||||
### Backend Deployment
|
||||
|
|
|
@ -41,25 +41,24 @@ class TrainingSummary(Document):
|
|||
Attributes:
|
||||
owner, name, threshold: uniquely identifies a GitHub repository and a training setting.
|
||||
If owner="", name="", then this is a global summary result.
|
||||
model_file: relative path to the model file, with repository as root.
|
||||
n_resolved_issues: total number of resolved issues in this repository.
|
||||
n_newcomer_resolved: the number of issues resolved by newcomers in this repository.
|
||||
accuracy: the accuracy of the model on the training data.
|
||||
auc: the area under the ROC curve.
|
||||
accuracy, auc, precision, recall, f1: performance metrics in this repo.
|
||||
last_updated: the last time this training summary was updated.
|
||||
"""
|
||||
|
||||
owner: str = StringField(required=True)
|
||||
name: str = StringField(required=True)
|
||||
threshold: int = IntField(required=True, min_value=1, max_value=5)
|
||||
issues_train: List[list] = ListField(ListField(), default=[])
|
||||
issues_test: List[list] = ListField(ListField(), default=[])
|
||||
threshold: int = IntField(required=True, min_value=1, max_value=5)
|
||||
model_90_file: str = StringField(required=True)
|
||||
model_full_file: str = StringField(required=True)
|
||||
n_resolved_issues: int = IntField(required=True)
|
||||
n_newcomer_resolved: int = IntField(required=True)
|
||||
accuracy: float = FloatField(required=True)
|
||||
auc: float = FloatField(required=True)
|
||||
accuracy: float = FloatField(null=True)
|
||||
auc: float = FloatField(null=True)
|
||||
precision: float = FloatField(null=True)
|
||||
recall: float = FloatField(null=True)
|
||||
f1: float = FloatField(null=True)
|
||||
last_updated: datetime = DateTimeField(required=True)
|
||||
meta = {
|
||||
"indexes": [
|
||||
|
|
|
@ -153,7 +153,7 @@ def _get_categorized_labels(labels: List[str]) -> Dataset.LabelCategory:
|
|||
|
||||
|
||||
def _get_user_data(
|
||||
owner: str, name: str, user: str, t: datetime
|
||||
owner: str, name: str, user: str, t: datetime, all_github: bool = True
|
||||
) -> Dataset.UserFeature:
|
||||
"""Get user data before a certain time t"""
|
||||
feat = Dataset.UserFeature(name=user)
|
||||
|
@ -194,27 +194,28 @@ def _get_user_data(
|
|||
feat.resolver_commits = usr_revolver_cmts
|
||||
|
||||
# GitHub global features
|
||||
query = User.objects(login=user)
|
||||
if all_github:
|
||||
query = User.objects(login=user)
|
||||
|
||||
if query.count() == 0:
|
||||
return feat
|
||||
if query.count() == 0:
|
||||
return feat
|
||||
|
||||
user: User = query.first()
|
||||
commits = [c for c in user.commit_contributions if c.created_at <= t]
|
||||
issues = [i for i in user.issues if i.created_at <= t]
|
||||
pulls = [p for p in user.pulls if p.created_at <= t]
|
||||
reviews = [r for r in user.pull_reviews if r.created_at <= t]
|
||||
feat.n_commits_all = sum(c.commit_count for c in commits)
|
||||
feat.n_issues_all = len(issues)
|
||||
feat.n_pulls_all = len(pulls)
|
||||
feat.n_reviews_all = len(reviews)
|
||||
feat.max_stars_commit = max([c.repo_stars for c in commits] + [0])
|
||||
feat.max_stars_issue = max([i.repo_stars for i in issues] + [0])
|
||||
feat.max_stars_pull = max([p.repo_stars for p in pulls] + [0])
|
||||
feat.max_stars_review = max([r.repo_stars for r in reviews] + [0])
|
||||
feat.n_repos = len(
|
||||
set((x.owner, x.name) for x in commits + issues + pulls + reviews)
|
||||
)
|
||||
user: User = query.first()
|
||||
commits = [c for c in user.commit_contributions if c.created_at <= t]
|
||||
issues = [i for i in user.issues if i.created_at <= t]
|
||||
pulls = [p for p in user.pulls if p.created_at <= t]
|
||||
reviews = [r for r in user.pull_reviews if r.created_at <= t]
|
||||
feat.n_commits_all = sum(c.commit_count for c in commits)
|
||||
feat.n_issues_all = len(issues)
|
||||
feat.n_pulls_all = len(pulls)
|
||||
feat.n_reviews_all = len(reviews)
|
||||
feat.max_stars_commit = max([c.repo_stars for c in commits] + [0])
|
||||
feat.max_stars_issue = max([i.repo_stars for i in issues] + [0])
|
||||
feat.max_stars_pull = max([p.repo_stars for p in pulls] + [0])
|
||||
feat.max_stars_review = max([r.repo_stars for r in reviews] + [0])
|
||||
feat.n_repos = len(
|
||||
set((x.owner, x.name) for x in commits + issues + pulls + reviews)
|
||||
)
|
||||
|
||||
return feat
|
||||
|
||||
|
@ -262,8 +263,10 @@ def _get_dynamics_data(owner: str, name: str, events: List[IssueEvent], t: datet
|
|||
comments.append(event.comment)
|
||||
if event.actor is not None and event.actor != "ghost":
|
||||
comment_users.add(event.actor)
|
||||
comment_users = [_get_user_data(owner, name, user, t) for user in comment_users]
|
||||
event_users = [_get_user_data(owner, name, user, t) for user in event_users]
|
||||
comment_users = [
|
||||
_get_user_data(owner, name, user, t, False) for user in comment_users
|
||||
]
|
||||
event_users = [_get_user_data(owner, name, user, t, False) for user in event_users]
|
||||
return labels, comments, comment_users, event_users
|
||||
|
||||
|
||||
|
@ -433,8 +436,9 @@ def get_dataset_all(since: datetime, n_process: int = None):
|
|||
n_process (int, optional): Number of processes to use. Defaults to None
|
||||
"""
|
||||
if n_process is None:
|
||||
for repo in Repo.objects():
|
||||
get_dataset_for_repo(repo.owner, repo.name, since)
|
||||
repos = [(r.owner, r.name) for r in Repo.objects()]
|
||||
for owner, name in repos:
|
||||
get_dataset_for_repo(owner, name, since)
|
||||
else:
|
||||
params = [(r.owner, r.name, since, None, True) for r in Repo.objects()]
|
||||
with mp.Pool(n_process) as p:
|
||||
|
@ -443,8 +447,12 @@ def get_dataset_all(since: datetime, n_process: int = None):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--since")
|
||||
since = parse_date(parser.parse_args().since)
|
||||
parser.add_argument("--since", type=str, default="2008.01.01")
|
||||
parser.add_argument("--nprocess", type=int, default=mp.cpu_count())
|
||||
args = parser.parse_args()
|
||||
since, nprocess = parse_date(args.since), args.nprocess
|
||||
|
||||
logger.info("Start!")
|
||||
|
||||
mongoengine.connect(
|
||||
CONFIG["mongodb"]["db"],
|
||||
|
@ -453,6 +461,6 @@ if __name__ == "__main__":
|
|||
uuidRepresentation="standard",
|
||||
)
|
||||
|
||||
get_dataset_all(since, mp.cpu_count())
|
||||
get_dataset_all(since, nprocess)
|
||||
|
||||
logger.info("Finish!")
|
||||
|
|
|
@ -272,7 +272,7 @@ def _update_open_issues(fetcher: RepoFetcher, nums: List[int], since: datetime):
|
|||
logger.info("%d open issues updated since %s", repo_open_issues.count(), since)
|
||||
|
||||
open_issues = []
|
||||
for issue in repo_open_issues:
|
||||
for issue in list(repo_open_issues):
|
||||
existing = OpenIssue.objects(query & Q(number=issue.number))
|
||||
if existing.count() > 0:
|
||||
open_issue = existing.first()
|
||||
|
|
|
@ -1,22 +1,35 @@
|
|||
import os
|
||||
import math
|
||||
import logging
|
||||
import argparse
|
||||
import mongoengine
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
|
||||
from typing import Tuple
|
||||
from datetime import datetime
|
||||
from gfibot import CONFIG
|
||||
from gfibot.model import utils
|
||||
from gfibot.collections import *
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_ROOT_DIRECTORY = "models"
|
||||
os.makedirs(MODEL_ROOT_DIRECTORY, exist_ok=True)
|
||||
|
||||
|
||||
def get_update_set(threshold: int, dataset_batch: List[Dataset]) -> list:
|
||||
def model_90_path(threshold: int) -> str:
|
||||
return f"{MODEL_ROOT_DIRECTORY}/model_90_{threshold}.model"
|
||||
|
||||
|
||||
def model_full_path(threshold: int) -> str:
|
||||
return f"{MODEL_ROOT_DIRECTORY}/model_full_{threshold}.model"
|
||||
|
||||
|
||||
def get_update_set(
|
||||
threshold: int, dataset_batch: List[Dataset]
|
||||
) -> Tuple[str, str, list]: # name, owner, number, before
|
||||
update_set = []
|
||||
for i in dataset_batch:
|
||||
if i.resolver_commit_num == -1:
|
||||
|
@ -29,8 +42,6 @@ def get_update_set(threshold: int, dataset_batch: List[Dataset]) -> list:
|
|||
owner=i.owner,
|
||||
name=i.name,
|
||||
threshold=threshold,
|
||||
model_90_file="",
|
||||
model_full_file="",
|
||||
issues_train=[],
|
||||
issues_test=[],
|
||||
n_resolved_issues=0,
|
||||
|
@ -48,17 +59,19 @@ def get_update_set(threshold: int, dataset_batch: List[Dataset]) -> list:
|
|||
issue_set = train_set + test_set
|
||||
issue_number_set = [iss[0] for iss in issue_set]
|
||||
if i.number in issue_number_set:
|
||||
logger.info(f"{i.owner}/{i.name}#{i.number}): no need to update")
|
||||
logger.info(f"{i.owner}/{i.name}#{i.number}: no need to update")
|
||||
continue
|
||||
else:
|
||||
update_set.append([i.name, i.owner, [i.number, i.before]])
|
||||
logger.info(f"{i.owner}/{i.name}#{i.number}): added")
|
||||
update_set.append((i.name, i.owner, [i.number, i.before]))
|
||||
logger.info(f"{i.owner}/{i.name}#{i.number}: added")
|
||||
return update_set
|
||||
|
||||
|
||||
def update_basic_training_summary(
|
||||
update_set: list, min_test_size: int, threshold: int
|
||||
) -> list:
|
||||
update_set: Tuple[str, str, list],
|
||||
min_test_size: int,
|
||||
threshold: int,
|
||||
) -> List[Tuple[str, str, list]]:
|
||||
train_90_add = []
|
||||
for i in TrainingSummary.objects(threshold=threshold):
|
||||
update_issues = [
|
||||
|
@ -77,7 +90,7 @@ def update_basic_training_summary(
|
|||
elif len(train_set) < 9 * min_test_size:
|
||||
train_set.append(iss)
|
||||
train_set.append(iss)
|
||||
train_90_add.append([i.name, i.owner, iss])
|
||||
train_90_add.append((i.name, i.owner, iss))
|
||||
else:
|
||||
test_set.append(iss)
|
||||
if count == 9:
|
||||
|
@ -85,7 +98,7 @@ def update_basic_training_summary(
|
|||
continue
|
||||
else:
|
||||
train_set.append(test_set[0])
|
||||
train_90_add.append([i.name, i.owner, test_set[0]])
|
||||
train_90_add.append((i.name, i.owner, test_set[0]))
|
||||
test_set.pop(0)
|
||||
count += 1
|
||||
|
||||
|
@ -104,53 +117,93 @@ def update_basic_training_summary(
|
|||
def update_models(
|
||||
update_set: list, train_90_add: list, batch_size: int, threshold: int
|
||||
) -> xgb.core.Booster:
|
||||
model_90_path = TrainingSummary.objects(threshold=threshold)[0].model_90_file
|
||||
if model_90_path == "":
|
||||
model_90_path = None
|
||||
model_90 = utils.update_model(model_90_path, threshold, train_90_add, batch_size)
|
||||
model_full_path = TrainingSummary.objects(threshold=threshold)[0].model_full_file
|
||||
if model_full_path == "":
|
||||
model_full_path = None
|
||||
model_full = utils.update_model(model_full_path, threshold, update_set, batch_size)
|
||||
if model_90 is not None:
|
||||
if 1 - os.path.exists(MODEL_ROOT_DIRECTORY):
|
||||
os.makedirs(MODEL_ROOT_DIRECTORY)
|
||||
model_90.save_model(
|
||||
"{}/model_90_".format(MODEL_ROOT_DIRECTORY) + str(threshold) + ".model"
|
||||
)
|
||||
if model_full is not None:
|
||||
if 1 - os.path.exists(MODEL_ROOT_DIRECTORY):
|
||||
os.makedirs(MODEL_ROOT_DIRECTORY)
|
||||
model_full.save_model(
|
||||
"{}/model_full_".format(MODEL_ROOT_DIRECTORY) + str(threshold) + ".model"
|
||||
)
|
||||
model_90 = utils.update_model(
|
||||
model_90_path(threshold), threshold, train_90_add, batch_size
|
||||
)
|
||||
model_full = utils.update_model(
|
||||
model_full_path(threshold), threshold, update_set, batch_size
|
||||
)
|
||||
model_90.save_model(model_90_path(threshold))
|
||||
model_full.save_model(model_full_path(threshold))
|
||||
return model_90
|
||||
|
||||
|
||||
def update_peformance_training_summary(
|
||||
model_90: xgb.core.Booster, batch_size: int, prob_thres: float, threshold: int
|
||||
):
|
||||
test_set_all = []
|
||||
training_set_all = []
|
||||
n_resolved_issues_all, n_newcomer_resolved_all = 0, 0
|
||||
for i in TrainingSummary.objects(threshold=threshold):
|
||||
test_set = i.issues_test
|
||||
test_set = [[i.name, i.owner, iss] for iss in test_set]
|
||||
test_set_all.extend(i.issues_test)
|
||||
training_set_all.extend(i.issues_train)
|
||||
n_resolved_issues_all += i.n_resolved_issues
|
||||
n_newcomer_resolved_all += i.n_newcomer_resolved
|
||||
|
||||
test_set = [[i.name, i.owner, iss] for iss in i.issues_test]
|
||||
y_test, y_prob = utils.predict_issues(test_set, threshold, batch_size, model_90)
|
||||
y_pred = [int(i > prob_thres) for i in y_prob]
|
||||
|
||||
auc, precision, recall, f1 = utils.get_all_metrics(y_test, y_pred, y_prob)
|
||||
acc = accuracy_score(y_test, y_pred)
|
||||
if math.isnan(auc):
|
||||
auc = 0.0
|
||||
|
||||
i.model_90_file = (
|
||||
"{}/model_90_".format(MODEL_ROOT_DIRECTORY) + str(threshold) + ".model"
|
||||
)
|
||||
i.model_full_file = (
|
||||
"{}/model_full_".format(MODEL_ROOT_DIRECTORY) + str(threshold) + ".model"
|
||||
)
|
||||
i.accuracy = accuracy_score(y_test, y_pred)
|
||||
i.auc = auc
|
||||
i.last_updated = datetime.now()
|
||||
if all(y == 1 for y in y_pred) or all(y == 0 for y in y_pred):
|
||||
logger.info(
|
||||
"Performance for %s/%s is undefined because of all-0 or all-1 labels",
|
||||
i.name,
|
||||
i.owner,
|
||||
)
|
||||
i.accuracy = i.auc = i.precision = i.recall = i.f1 = float("nan")
|
||||
else:
|
||||
i.auc, i.precision, i.recall, i.f1 = utils.get_all_metrics(
|
||||
y_test, y_pred, y_prob
|
||||
)
|
||||
i.accuracy = accuracy_score(y_test, y_pred)
|
||||
i.last_updated = datetime.utcnow()
|
||||
i.save()
|
||||
logger.info(
|
||||
"Performance for %s/%s: acc = %.4f, auc = %.4f, prec = %.4f, recall = %.4f, f1 = %.4f",
|
||||
i.owner,
|
||||
i.name,
|
||||
i.accuracy,
|
||||
i.auc,
|
||||
i.precision,
|
||||
i.recall,
|
||||
i.f1,
|
||||
)
|
||||
|
||||
summ = TrainingSummary.objects(name="", owner="", threshold=threshold)
|
||||
if summ.count() == 0:
|
||||
summ = TrainingSummary(owner="", name="", threshold=threshold)
|
||||
else:
|
||||
summ = summ.first()
|
||||
|
||||
summ.issues_train = training_set_all
|
||||
summ.issues_test = test_set_all
|
||||
summ.n_resolved_issues = n_resolved_issues_all
|
||||
summ.n_newcomer_resolved = n_newcomer_resolved_all
|
||||
|
||||
test_set = [[i.name, i.owner, iss] for iss in test_set_all]
|
||||
y_test, y_prob = utils.predict_issues(test_set, threshold, batch_size, model_90)
|
||||
y_pred = [int(i > prob_thres) for i in y_prob]
|
||||
|
||||
if all(y == 1 for y in y_pred) or all(y == 0 for y in y_pred):
|
||||
logger.info(
|
||||
"Performance is undefined because of all-0 or all-1 labels",
|
||||
)
|
||||
summ.accuracy = summ.auc = summ.precision = summ.recall = summ.f1 = float("nan")
|
||||
else:
|
||||
summ.auc, summ.precision, summ.recall, summ.f1 = utils.get_all_metrics(
|
||||
y_test, y_pred, y_prob
|
||||
)
|
||||
summ.acc = accuracy_score(y_test, y_pred)
|
||||
summ.last_updated = datetime.utcnow()
|
||||
summ.save()
|
||||
logger.info(
|
||||
"Overall Performance: acc = %.4f, auc = %.4f, prec = %.4f, recall = %.4f, f1 = %.4f",
|
||||
summ.accuracy,
|
||||
summ.auc,
|
||||
summ.precision,
|
||||
summ.recall,
|
||||
summ.f1,
|
||||
)
|
||||
|
||||
|
||||
def update_training_summary(
|
||||
|
@ -169,7 +222,7 @@ def update_training_summary(
|
|||
"""
|
||||
dataset = Dataset.objects()
|
||||
ith_batch = 0
|
||||
need_update = 0
|
||||
need_update = False
|
||||
while ith_batch < math.ceil(dataset.count() / dataset_size):
|
||||
dataset_batch = dataset[
|
||||
ith_batch * dataset_size : (ith_batch + 1) * dataset_size
|
||||
|
@ -184,25 +237,20 @@ def update_training_summary(
|
|||
update_set, min_test_size, threshold
|
||||
)
|
||||
model_90 = update_models(update_set, train_90_add, batch_size, threshold)
|
||||
need_update = 1
|
||||
need_update = True
|
||||
ith_batch += 1
|
||||
logger.info("Model updated.")
|
||||
if need_update == 1 and model_90 is not None:
|
||||
if need_update and model_90 is not None:
|
||||
update_peformance_training_summary(model_90, batch_size, prob_thres, threshold)
|
||||
logger.info("Performance on each project is updated.")
|
||||
else:
|
||||
logger.info("No need to update performance.")
|
||||
|
||||
|
||||
def update_prediction_for_issue(i: OpenIssue, threshold: int):
|
||||
model_full_path = TrainingSummary.objects(threshold=threshold)[0].model_full_file
|
||||
def update_prediction_for_issue(issue: Dataset, threshold: int):
|
||||
model_full = xgb.Booster()
|
||||
model_full.load_model(model_full_path)
|
||||
query = Q(name=i.name, owner=i.owner, number=i.number)
|
||||
iss = Dataset.objects(query)
|
||||
iss = iss.first()
|
||||
issue = utils.get_issue_data(iss, threshold)
|
||||
issue_df = pd.DataFrame(issue, index=[0])
|
||||
model_full.load_model(model_full_path(threshold))
|
||||
issue_df = pd.DataFrame(utils.get_issue_data(issue, threshold), index=[0])
|
||||
y_test = issue_df["is_gfi"]
|
||||
X_test = issue_df.drop(["is_gfi", "owner", "name", "number"], axis=1)
|
||||
xg_test = xgb.DMatrix(X_test, label=y_test)
|
||||
|
@ -210,11 +258,11 @@ def update_prediction_for_issue(i: OpenIssue, threshold: int):
|
|||
# print(xg_test)
|
||||
# print(y_prob)
|
||||
Prediction.objects(
|
||||
Q(owner=iss.owner) & Q(name=iss.name) & Q(number=iss.number)
|
||||
owner=issue.owner, name=issue.name, number=issue.number
|
||||
).upsert_one(
|
||||
owner=iss.owner,
|
||||
name=iss.name,
|
||||
number=iss.number,
|
||||
owner=issue.owner,
|
||||
name=issue.name,
|
||||
number=issue.number,
|
||||
threshold=threshold,
|
||||
probability=y_prob,
|
||||
last_updated=datetime.now(),
|
||||
|
@ -223,19 +271,26 @@ def update_prediction_for_issue(i: OpenIssue, threshold: int):
|
|||
|
||||
def update_repo_prediction(owner: str, name: str):
|
||||
for threshold in [1, 2, 3, 4, 5]:
|
||||
for i in OpenIssue.objects(Q(owner=owner) & Q(name=name)):
|
||||
for i in Dataset.objects(owner=owner, name=name, resolver_commit_num=-1):
|
||||
update_prediction_for_issue(i, threshold)
|
||||
|
||||
|
||||
def update_prediction(threshold: int):
|
||||
for i in OpenIssue.objects():
|
||||
for i in Dataset.objects(resolver_commit_num=-1):
|
||||
update_prediction_for_issue(i, threshold)
|
||||
logger.info(
|
||||
"Get predictions for open issues under threshold " + str(threshold) + "."
|
||||
"Get predictions for open issues under threshol Model updated.d "
|
||||
+ str(threshold)
|
||||
+ "."
|
||||
)
|
||||
|
||||
|
||||
def update():
|
||||
def update(cleanup=False):
|
||||
if cleanup:
|
||||
TrainingSummary.drop_collection()
|
||||
Prediction.drop_collection()
|
||||
os.rmdir(MODEL_ROOT_DIRECTORY)
|
||||
os.makedirs(MODEL_ROOT_DIRECTORY, exist_ok=True)
|
||||
for threshold in [1, 2, 3, 4, 5]:
|
||||
logger.info(
|
||||
"Update TrainingSummary and Prediction for threshold "
|
||||
|
@ -247,11 +302,17 @@ def update():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--cleanup", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
mongoengine.connect(
|
||||
CONFIG["mongodb"]["db"],
|
||||
host=CONFIG["mongodb"]["url"],
|
||||
tz_aware=True,
|
||||
uuidRepresentation="standard",
|
||||
)
|
||||
update()
|
||||
|
||||
update(args.cleanup)
|
||||
|
||||
logger.info("Done!")
|
||||
|
|
|
@ -2,7 +2,6 @@ import math
|
|||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
|
||||
from collections import defaultdict
|
||||
from mongoengine import Q
|
||||
from gfibot.collections import *
|
||||
from sklearn.feature_extraction.text import HashingVectorizer
|
||||
|
@ -298,5 +297,5 @@ def dump_dataset(file: str, threshold: int):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
dump_dataset()
|
||||
|
|
|
@ -12,15 +12,15 @@ def test_get_update_set(mock_mongodb):
|
|||
]
|
||||
update_set = get_update_set(threshold, dataset_batch)
|
||||
assert update_set == [
|
||||
["name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]],
|
||||
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]],
|
||||
("name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]),
|
||||
("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]),
|
||||
]
|
||||
|
||||
|
||||
def test_update_basic_training_summary(mock_mongodb):
|
||||
update_set = [
|
||||
["name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]],
|
||||
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]],
|
||||
("name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]),
|
||||
("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]),
|
||||
]
|
||||
dataset_batch = [
|
||||
Dataset.objects(name="name", owner="owner", number=5).first(),
|
||||
|
@ -31,18 +31,18 @@ def test_update_basic_training_summary(mock_mongodb):
|
|||
get_update_set(threshold, dataset_batch)
|
||||
train_90_add = update_basic_training_summary(update_set, min_test_size, threshold)
|
||||
assert train_90_add == [
|
||||
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]]
|
||||
("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)])
|
||||
]
|
||||
|
||||
|
||||
def test_update_models(mock_mongodb):
|
||||
update_set = [
|
||||
["name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]],
|
||||
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]],
|
||||
("name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]),
|
||||
("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]),
|
||||
]
|
||||
train_90_add = [
|
||||
["name", "owner", [5, datetime(1970, 1, 3, tzinfo=timezone.utc)]],
|
||||
["name", "owner", [6, datetime(1971, 1, 3, tzinfo=timezone.utc)]],
|
||||
("name", "owner", [5, datetime(1970, 1, 3, tzinfo=timezone.utc)]),
|
||||
("name", "owner", [6, datetime(1971, 1, 3, tzinfo=timezone.utc)]),
|
||||
]
|
||||
batch_size = 100
|
||||
threshold = 2
|
||||
|
@ -50,8 +50,6 @@ def test_update_models(mock_mongodb):
|
|||
owner="owner",
|
||||
name="name",
|
||||
threshold=threshold,
|
||||
model_90_file="",
|
||||
model_full_file="",
|
||||
issues_train=[],
|
||||
issues_test=[],
|
||||
n_resolved_issues=0,
|
||||
|
|
Loading…
Reference in New Issue