chore: improve gfibot.model and gfibot.data

This commit is contained in:
Hao He 2022-06-13 15:43:54 +00:00
parent 6c57fa93cb
commit d36562514a
7 changed files with 196 additions and 117 deletions

View File

@ -62,8 +62,22 @@ First, determine some GitHub projects of interest and specify them in [`pyprojec
python -m gfibot.check_tokens python -m gfibot.check_tokens
``` ```
### Dataset Preparation
Next, run the following script to collect historical data for the projects of interest. This can take some time (up to days) to finish for the first run, but can perform quick incremental updates on an existing database. This script should be run periodically (e.g., as a scheduled background task) to ensure that the MongoDB database reflects the latest state in the specified repositories. Next, run the following script to collect historical data for the projects of interest. This can take some time (up to days) to finish for the first run, but can perform quick incremental updates on an existing database. This script should be run periodically (e.g., as a scheduled background task) to ensure that the MongoDB database reflects the latest state in the specified repositories.
```shell script ```shell script
python -m gfibot.data.update python -m gfibot.data.update --nprocess=4 # you can increase parallelism with more GitHub tokens
``` ```
Then, build a dataset for training and prediction as follows. This script may also take a long time but can be accelerated with more processes.
```shell script
python -m gfibot.data.dataset --since=2008.01.01 --nprocess=4
```
### Model Training
### Backend Deployment

View File

@ -41,25 +41,24 @@ class TrainingSummary(Document):
Attributes: Attributes:
owner, name, threshold: uniquely identifies a GitHub repository and a training setting. owner, name, threshold: uniquely identifies a GitHub repository and a training setting.
If owner="", name="", then this is a global summary result. If owner="", name="", then this is a global summary result.
model_file: relative path to the model file, with repository as root.
n_resolved_issues: total number of resolved issues in this repository. n_resolved_issues: total number of resolved issues in this repository.
n_newcomer_resolved: the number of issues resolved by newcomers in this repository. n_newcomer_resolved: the number of issues resolved by newcomers in this repository.
accuracy: the accuracy of the model on the training data. accuracy, auc, precision, recall, f1: performance metrics in this repo.
auc: the area under the ROC curve.
last_updated: the last time this training summary was updated. last_updated: the last time this training summary was updated.
""" """
owner: str = StringField(required=True) owner: str = StringField(required=True)
name: str = StringField(required=True) name: str = StringField(required=True)
threshold: int = IntField(required=True, min_value=1, max_value=5)
issues_train: List[list] = ListField(ListField(), default=[]) issues_train: List[list] = ListField(ListField(), default=[])
issues_test: List[list] = ListField(ListField(), default=[]) issues_test: List[list] = ListField(ListField(), default=[])
threshold: int = IntField(required=True, min_value=1, max_value=5)
model_90_file: str = StringField(required=True)
model_full_file: str = StringField(required=True)
n_resolved_issues: int = IntField(required=True) n_resolved_issues: int = IntField(required=True)
n_newcomer_resolved: int = IntField(required=True) n_newcomer_resolved: int = IntField(required=True)
accuracy: float = FloatField(required=True) accuracy: float = FloatField(null=True)
auc: float = FloatField(required=True) auc: float = FloatField(null=True)
precision: float = FloatField(null=True)
recall: float = FloatField(null=True)
f1: float = FloatField(null=True)
last_updated: datetime = DateTimeField(required=True) last_updated: datetime = DateTimeField(required=True)
meta = { meta = {
"indexes": [ "indexes": [

View File

@ -153,7 +153,7 @@ def _get_categorized_labels(labels: List[str]) -> Dataset.LabelCategory:
def _get_user_data( def _get_user_data(
owner: str, name: str, user: str, t: datetime owner: str, name: str, user: str, t: datetime, all_github: bool = True
) -> Dataset.UserFeature: ) -> Dataset.UserFeature:
"""Get user data before a certain time t""" """Get user data before a certain time t"""
feat = Dataset.UserFeature(name=user) feat = Dataset.UserFeature(name=user)
@ -194,27 +194,28 @@ def _get_user_data(
feat.resolver_commits = usr_revolver_cmts feat.resolver_commits = usr_revolver_cmts
# GitHub global features # GitHub global features
query = User.objects(login=user) if all_github:
query = User.objects(login=user)
if query.count() == 0: if query.count() == 0:
return feat return feat
user: User = query.first() user: User = query.first()
commits = [c for c in user.commit_contributions if c.created_at <= t] commits = [c for c in user.commit_contributions if c.created_at <= t]
issues = [i for i in user.issues if i.created_at <= t] issues = [i for i in user.issues if i.created_at <= t]
pulls = [p for p in user.pulls if p.created_at <= t] pulls = [p for p in user.pulls if p.created_at <= t]
reviews = [r for r in user.pull_reviews if r.created_at <= t] reviews = [r for r in user.pull_reviews if r.created_at <= t]
feat.n_commits_all = sum(c.commit_count for c in commits) feat.n_commits_all = sum(c.commit_count for c in commits)
feat.n_issues_all = len(issues) feat.n_issues_all = len(issues)
feat.n_pulls_all = len(pulls) feat.n_pulls_all = len(pulls)
feat.n_reviews_all = len(reviews) feat.n_reviews_all = len(reviews)
feat.max_stars_commit = max([c.repo_stars for c in commits] + [0]) feat.max_stars_commit = max([c.repo_stars for c in commits] + [0])
feat.max_stars_issue = max([i.repo_stars for i in issues] + [0]) feat.max_stars_issue = max([i.repo_stars for i in issues] + [0])
feat.max_stars_pull = max([p.repo_stars for p in pulls] + [0]) feat.max_stars_pull = max([p.repo_stars for p in pulls] + [0])
feat.max_stars_review = max([r.repo_stars for r in reviews] + [0]) feat.max_stars_review = max([r.repo_stars for r in reviews] + [0])
feat.n_repos = len( feat.n_repos = len(
set((x.owner, x.name) for x in commits + issues + pulls + reviews) set((x.owner, x.name) for x in commits + issues + pulls + reviews)
) )
return feat return feat
@ -262,8 +263,10 @@ def _get_dynamics_data(owner: str, name: str, events: List[IssueEvent], t: datet
comments.append(event.comment) comments.append(event.comment)
if event.actor is not None and event.actor != "ghost": if event.actor is not None and event.actor != "ghost":
comment_users.add(event.actor) comment_users.add(event.actor)
comment_users = [_get_user_data(owner, name, user, t) for user in comment_users] comment_users = [
event_users = [_get_user_data(owner, name, user, t) for user in event_users] _get_user_data(owner, name, user, t, False) for user in comment_users
]
event_users = [_get_user_data(owner, name, user, t, False) for user in event_users]
return labels, comments, comment_users, event_users return labels, comments, comment_users, event_users
@ -433,8 +436,9 @@ def get_dataset_all(since: datetime, n_process: int = None):
n_process (int, optional): Number of processes to use. Defaults to None n_process (int, optional): Number of processes to use. Defaults to None
""" """
if n_process is None: if n_process is None:
for repo in Repo.objects(): repos = [(r.owner, r.name) for r in Repo.objects()]
get_dataset_for_repo(repo.owner, repo.name, since) for owner, name in repos:
get_dataset_for_repo(owner, name, since)
else: else:
params = [(r.owner, r.name, since, None, True) for r in Repo.objects()] params = [(r.owner, r.name, since, None, True) for r in Repo.objects()]
with mp.Pool(n_process) as p: with mp.Pool(n_process) as p:
@ -443,8 +447,12 @@ def get_dataset_all(since: datetime, n_process: int = None):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--since") parser.add_argument("--since", type=str, default="2008.01.01")
since = parse_date(parser.parse_args().since) parser.add_argument("--nprocess", type=int, default=mp.cpu_count())
args = parser.parse_args()
since, nprocess = parse_date(args.since), args.nprocess
logger.info("Start!")
mongoengine.connect( mongoengine.connect(
CONFIG["mongodb"]["db"], CONFIG["mongodb"]["db"],
@ -453,6 +461,6 @@ if __name__ == "__main__":
uuidRepresentation="standard", uuidRepresentation="standard",
) )
get_dataset_all(since, mp.cpu_count()) get_dataset_all(since, nprocess)
logger.info("Finish!") logger.info("Finish!")

View File

@ -272,7 +272,7 @@ def _update_open_issues(fetcher: RepoFetcher, nums: List[int], since: datetime):
logger.info("%d open issues updated since %s", repo_open_issues.count(), since) logger.info("%d open issues updated since %s", repo_open_issues.count(), since)
open_issues = [] open_issues = []
for issue in repo_open_issues: for issue in list(repo_open_issues):
existing = OpenIssue.objects(query & Q(number=issue.number)) existing = OpenIssue.objects(query & Q(number=issue.number))
if existing.count() > 0: if existing.count() > 0:
open_issue = existing.first() open_issue = existing.first()

View File

@ -1,22 +1,35 @@
import os import os
import math import math
import logging import logging
import argparse
import mongoengine import mongoengine
import pandas as pd import pandas as pd
import xgboost as xgb import xgboost as xgb
from typing import Tuple
from datetime import datetime from datetime import datetime
from gfibot import CONFIG from gfibot import CONFIG
from gfibot.model import utils from gfibot.model import utils
from gfibot.collections import * from gfibot.collections import *
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_ROOT_DIRECTORY = "models" MODEL_ROOT_DIRECTORY = "models"
os.makedirs(MODEL_ROOT_DIRECTORY, exist_ok=True)
def model_90_path(threshold: int) -> str:
    """Return the on-disk location of the "90" model for ``threshold``.

    The file lives directly under ``MODEL_ROOT_DIRECTORY`` and is named
    ``model_90_<threshold>.model``.
    """
    # NOTE(review): "90" presumably refers to the model trained on the
    # training split only (as opposed to the "full" model) -- confirm.
    file_name = f"model_90_{threshold}.model"
    return f"{MODEL_ROOT_DIRECTORY}/{file_name}"
def model_full_path(threshold: int) -> str:
    """Return the on-disk location of the "full" model for ``threshold``.

    The file lives directly under ``MODEL_ROOT_DIRECTORY`` and is named
    ``model_full_<threshold>.model``.
    """
    file_name = f"model_full_{threshold}.model"
    return f"{MODEL_ROOT_DIRECTORY}/{file_name}"
def get_update_set(
threshold: int, dataset_batch: List[Dataset]
) -> Tuple[str, str, list]: # name, owner, number, before
update_set = [] update_set = []
for i in dataset_batch: for i in dataset_batch:
if i.resolver_commit_num == -1: if i.resolver_commit_num == -1:
@ -29,8 +42,6 @@ def get_update_set(threshold: int, dataset_batch: List[Dataset]) -> list:
owner=i.owner, owner=i.owner,
name=i.name, name=i.name,
threshold=threshold, threshold=threshold,
model_90_file="",
model_full_file="",
issues_train=[], issues_train=[],
issues_test=[], issues_test=[],
n_resolved_issues=0, n_resolved_issues=0,
@ -48,17 +59,19 @@ def get_update_set(threshold: int, dataset_batch: List[Dataset]) -> list:
issue_set = train_set + test_set issue_set = train_set + test_set
issue_number_set = [iss[0] for iss in issue_set] issue_number_set = [iss[0] for iss in issue_set]
if i.number in issue_number_set: if i.number in issue_number_set:
logger.info(f"{i.owner}/{i.name}#{i.number}): no need to update") logger.info(f"{i.owner}/{i.name}#{i.number}: no need to update")
continue continue
else: else:
update_set.append([i.name, i.owner, [i.number, i.before]]) update_set.append((i.name, i.owner, [i.number, i.before]))
logger.info(f"{i.owner}/{i.name}#{i.number}): added") logger.info(f"{i.owner}/{i.name}#{i.number}: added")
return update_set return update_set
def update_basic_training_summary( def update_basic_training_summary(
update_set: list, min_test_size: int, threshold: int update_set: Tuple[str, str, list],
) -> list: min_test_size: int,
threshold: int,
) -> List[Tuple[str, str, list]]:
train_90_add = [] train_90_add = []
for i in TrainingSummary.objects(threshold=threshold): for i in TrainingSummary.objects(threshold=threshold):
update_issues = [ update_issues = [
@ -77,7 +90,7 @@ def update_basic_training_summary(
elif len(train_set) < 9 * min_test_size: elif len(train_set) < 9 * min_test_size:
train_set.append(iss) train_set.append(iss)
train_set.append(iss) train_set.append(iss)
train_90_add.append([i.name, i.owner, iss]) train_90_add.append((i.name, i.owner, iss))
else: else:
test_set.append(iss) test_set.append(iss)
if count == 9: if count == 9:
@ -85,7 +98,7 @@ def update_basic_training_summary(
continue continue
else: else:
train_set.append(test_set[0]) train_set.append(test_set[0])
train_90_add.append([i.name, i.owner, test_set[0]]) train_90_add.append((i.name, i.owner, test_set[0]))
test_set.pop(0) test_set.pop(0)
count += 1 count += 1
@ -104,53 +117,93 @@ def update_basic_training_summary(
def update_models(
    update_set: list, train_90_add: list, batch_size: int, threshold: int
) -> xgb.core.Booster:
    """Incrementally update and persist both models for ``threshold``.

    Args:
        update_set: issues passed to ``utils.update_model`` for the "full" model.
        train_90_add: issues passed to ``utils.update_model`` for the "90" model.
        batch_size: forwarded unchanged to ``utils.update_model``.
        threshold: selects which pair of model files is updated.

    Returns:
        The updated "90" model as returned by ``utils.update_model``. This may
        be ``None``; the caller checks for that before evaluating performance.
    """
    path_90 = model_90_path(threshold)
    path_full = model_full_path(threshold)

    model_90 = utils.update_model(path_90, threshold, train_90_add, batch_size)
    model_full = utils.update_model(path_full, threshold, update_set, batch_size)

    # utils.update_model may hand back None (the caller guards against it),
    # so only persist models that actually exist instead of crashing with an
    # AttributeError on None.save_model().
    if model_90 is not None:
        model_90.save_model(path_90)
    if model_full is not None:
        model_full.save_model(path_full)

    return model_90
def _evaluate_summary(
    summary,
    test_set: list,
    model_90: xgb.core.Booster,
    batch_size: int,
    prob_thres: float,
    threshold: int,
) -> bool:
    """Score ``test_set`` with ``model_90`` and store the metrics on ``summary``.

    ``summary`` is not saved here; the caller is responsible for that.

    Returns:
        True when metrics were computed, False when every prediction falls in
        the same class (or there is nothing to predict), in which case all
        metrics are set to NaN.
    """
    y_test, y_prob = utils.predict_issues(test_set, threshold, batch_size, model_90)
    y_pred = [int(prob > prob_thres) for prob in y_prob]

    # Equivalent to "all predictions are 1 or all predictions are 0"; an empty
    # prediction list also counts as undefined.
    defined = len(set(y_pred)) > 1
    if defined:
        (
            summary.auc,
            summary.precision,
            summary.recall,
            summary.f1,
        ) = utils.get_all_metrics(y_test, y_pred, y_prob)
        summary.accuracy = accuracy_score(y_test, y_pred)
    else:
        nan = float("nan")
        summary.accuracy = summary.auc = nan
        summary.precision = summary.recall = summary.f1 = nan
    summary.last_updated = datetime.utcnow()
    return defined


def update_peformance_training_summary(
    model_90: xgb.core.Booster, batch_size: int, prob_thres: float, threshold: int
):
    """Recompute performance metrics for every repository and for the global summary.

    For each per-repository ``TrainingSummary`` with the given ``threshold``,
    the repository's test issues are scored with ``model_90`` and the metrics
    are saved. Afterwards the global summary (``owner=""``, ``name=""``) is
    created or refreshed with the concatenated train/test issues, the summed
    issue counts, and metrics over all test issues.

    Args:
        model_90: model used to predict the test issues.
        batch_size: forwarded to ``utils.predict_issues``.
        prob_thres: a probability strictly above this value counts as a
            positive prediction.
        threshold: training setting whose summaries are updated.
    """
    issues_test_all = []  # raw test issue entries, stored on the global summary
    issues_train_all = []  # raw train issue entries, stored on the global summary
    test_set_all = []  # [name, owner, issue] triples used for global prediction
    n_resolved_issues_all, n_newcomer_resolved_all = 0, 0

    for summary in TrainingSummary.objects(threshold=threshold):
        # The global summary lives in the same collection; skip it so that its
        # previously aggregated values are not counted again on later runs.
        if summary.owner == "" and summary.name == "":
            continue

        test_set = [[summary.name, summary.owner, iss] for iss in summary.issues_test]
        # Keep each issue attached to its own repository for the global
        # evaluation below.
        test_set_all.extend(test_set)
        issues_test_all.extend(summary.issues_test)
        issues_train_all.extend(summary.issues_train)
        n_resolved_issues_all += summary.n_resolved_issues
        n_newcomer_resolved_all += summary.n_newcomer_resolved

        if not _evaluate_summary(
            summary, test_set, model_90, batch_size, prob_thres, threshold
        ):
            logger.info(
                "Performance for %s/%s is undefined because of all-0 or all-1 labels",
                summary.owner,
                summary.name,
            )
        summary.save()
        logger.info(
            "Performance for %s/%s: acc = %.4f, auc = %.4f, prec = %.4f, recall = %.4f, f1 = %.4f",
            summary.owner,
            summary.name,
            summary.accuracy,
            summary.auc,
            summary.precision,
            summary.recall,
            summary.f1,
        )

    summ = TrainingSummary.objects(name="", owner="", threshold=threshold).first()
    if summ is None:
        summ = TrainingSummary(owner="", name="", threshold=threshold)
    summ.issues_train = issues_train_all
    summ.issues_test = issues_test_all
    summ.n_resolved_issues = n_resolved_issues_all
    summ.n_newcomer_resolved = n_newcomer_resolved_all

    if not _evaluate_summary(
        summ, test_set_all, model_90, batch_size, prob_thres, threshold
    ):
        logger.info(
            "Performance is undefined because of all-0 or all-1 labels",
        )
    summ.save()
    logger.info(
        "Overall Performance: acc = %.4f, auc = %.4f, prec = %.4f, recall = %.4f, f1 = %.4f",
        summ.accuracy,
        summ.auc,
        summ.precision,
        summ.recall,
        summ.f1,
    )
def update_training_summary( def update_training_summary(
@ -169,7 +222,7 @@ def update_training_summary(
""" """
dataset = Dataset.objects() dataset = Dataset.objects()
ith_batch = 0 ith_batch = 0
need_update = 0 need_update = False
while ith_batch < math.ceil(dataset.count() / dataset_size): while ith_batch < math.ceil(dataset.count() / dataset_size):
dataset_batch = dataset[ dataset_batch = dataset[
ith_batch * dataset_size : (ith_batch + 1) * dataset_size ith_batch * dataset_size : (ith_batch + 1) * dataset_size
@ -184,25 +237,20 @@ def update_training_summary(
update_set, min_test_size, threshold update_set, min_test_size, threshold
) )
model_90 = update_models(update_set, train_90_add, batch_size, threshold) model_90 = update_models(update_set, train_90_add, batch_size, threshold)
need_update = 1 need_update = True
ith_batch += 1 ith_batch += 1
logger.info("Model updated.") logger.info("Model updated.")
if need_update == 1 and model_90 is not None: if need_update and model_90 is not None:
update_peformance_training_summary(model_90, batch_size, prob_thres, threshold) update_peformance_training_summary(model_90, batch_size, prob_thres, threshold)
logger.info("Performance on each project is updated.") logger.info("Performance on each project is updated.")
else: else:
logger.info("No need to update performance.") logger.info("No need to update performance.")
def update_prediction_for_issue(i: OpenIssue, threshold: int): def update_prediction_for_issue(issue: Dataset, threshold: int):
model_full_path = TrainingSummary.objects(threshold=threshold)[0].model_full_file
model_full = xgb.Booster() model_full = xgb.Booster()
model_full.load_model(model_full_path) model_full.load_model(model_full_path(threshold))
query = Q(name=i.name, owner=i.owner, number=i.number) issue_df = pd.DataFrame(utils.get_issue_data(issue, threshold), index=[0])
iss = Dataset.objects(query)
iss = iss.first()
issue = utils.get_issue_data(iss, threshold)
issue_df = pd.DataFrame(issue, index=[0])
y_test = issue_df["is_gfi"] y_test = issue_df["is_gfi"]
X_test = issue_df.drop(["is_gfi", "owner", "name", "number"], axis=1) X_test = issue_df.drop(["is_gfi", "owner", "name", "number"], axis=1)
xg_test = xgb.DMatrix(X_test, label=y_test) xg_test = xgb.DMatrix(X_test, label=y_test)
@ -210,11 +258,11 @@ def update_prediction_for_issue(i: OpenIssue, threshold: int):
# print(xg_test) # print(xg_test)
# print(y_prob) # print(y_prob)
Prediction.objects( Prediction.objects(
Q(owner=iss.owner) & Q(name=iss.name) & Q(number=iss.number) owner=issue.owner, name=issue.name, number=issue.number
).upsert_one( ).upsert_one(
owner=iss.owner, owner=issue.owner,
name=iss.name, name=issue.name,
number=iss.number, number=issue.number,
threshold=threshold, threshold=threshold,
probability=y_prob, probability=y_prob,
last_updated=datetime.now(), last_updated=datetime.now(),
@ -223,19 +271,26 @@ def update_prediction_for_issue(i: OpenIssue, threshold: int):
def update_repo_prediction(owner: str, name: str):
    """Refresh predictions for one repository under every threshold (1 to 5).

    For each threshold, every ``Dataset`` entry of ``owner/name`` with
    ``resolver_commit_num == -1`` is passed to ``update_prediction_for_issue``.
    """
    # NOTE(review): resolver_commit_num == -1 presumably marks issues that are
    # still open (no resolver yet) -- confirm against the dataset builder.
    for threshold in range(1, 6):
        unresolved = Dataset.objects(owner=owner, name=name, resolver_commit_num=-1)
        for issue in unresolved:
            update_prediction_for_issue(issue, threshold)
def update_prediction(threshold: int):
    """Refresh predictions for all unresolved issues under ``threshold``.

    Every ``Dataset`` entry with ``resolver_commit_num == -1`` is passed to
    ``update_prediction_for_issue``; a single log line is emitted when done.
    """
    for issue in Dataset.objects(resolver_commit_num=-1):
        update_prediction_for_issue(issue, threshold)
    # The message previously contained pasted-in garbage
    # ("threshol Model updated.d"); restore the intended wording.
    logger.info("Get predictions for open issues under threshold %s.", threshold)
def update(): def update(cleanup=False):
if cleanup:
TrainingSummary.drop_collection()
Prediction.drop_collection()
os.rmdir(MODEL_ROOT_DIRECTORY)
os.makedirs(MODEL_ROOT_DIRECTORY, exist_ok=True)
for threshold in [1, 2, 3, 4, 5]: for threshold in [1, 2, 3, 4, 5]:
logger.info( logger.info(
"Update TrainingSummary and Prediction for threshold " "Update TrainingSummary and Prediction for threshold "
@ -247,11 +302,17 @@ def update():
if __name__ == "__main__":
    # --cleanup is a boolean flag (default False) forwarded to update().
    parser = argparse.ArgumentParser()
    parser.add_argument("--cleanup", action="store_true")
    args = parser.parse_args()

    # Connect to the MongoDB instance described in the project configuration.
    mongodb_config = CONFIG["mongodb"]
    mongoengine.connect(
        mongodb_config["db"],
        host=mongodb_config["url"],
        tz_aware=True,
        uuidRepresentation="standard",
    )

    update(args.cleanup)

    logger.info("Done!")

View File

@ -2,7 +2,6 @@ import math
import pandas as pd import pandas as pd
import xgboost as xgb import xgboost as xgb
from collections import defaultdict
from mongoengine import Q from mongoengine import Q
from gfibot.collections import * from gfibot.collections import *
from sklearn.feature_extraction.text import HashingVectorizer from sklearn.feature_extraction.text import HashingVectorizer
@ -298,5 +297,5 @@ def dump_dataset(file: str, threshold: int):
if __name__ == "__main__": if __name__ == "__main__":
dump_dataset() dump_dataset()

View File

@ -12,15 +12,15 @@ def test_get_update_set(mock_mongodb):
] ]
update_set = get_update_set(threshold, dataset_batch) update_set = get_update_set(threshold, dataset_batch)
assert update_set == [ assert update_set == [
["name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]], ("name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]),
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]], ("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]),
] ]
def test_update_basic_training_summary(mock_mongodb): def test_update_basic_training_summary(mock_mongodb):
update_set = [ update_set = [
["name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]], ("name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]),
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]], ("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]),
] ]
dataset_batch = [ dataset_batch = [
Dataset.objects(name="name", owner="owner", number=5).first(), Dataset.objects(name="name", owner="owner", number=5).first(),
@ -31,18 +31,18 @@ def test_update_basic_training_summary(mock_mongodb):
get_update_set(threshold, dataset_batch) get_update_set(threshold, dataset_batch)
train_90_add = update_basic_training_summary(update_set, min_test_size, threshold) train_90_add = update_basic_training_summary(update_set, min_test_size, threshold)
assert train_90_add == [ assert train_90_add == [
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]] ("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)])
] ]
def test_update_models(mock_mongodb): def test_update_models(mock_mongodb):
update_set = [ update_set = [
["name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]], ("name", "owner", [5, datetime(1970, 1, 3, 0, 0, tzinfo=timezone.utc)]),
["name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]], ("name", "owner", [6, datetime(1971, 1, 3, 0, 0, tzinfo=timezone.utc)]),
] ]
train_90_add = [ train_90_add = [
["name", "owner", [5, datetime(1970, 1, 3, tzinfo=timezone.utc)]], ("name", "owner", [5, datetime(1970, 1, 3, tzinfo=timezone.utc)]),
["name", "owner", [6, datetime(1971, 1, 3, tzinfo=timezone.utc)]], ("name", "owner", [6, datetime(1971, 1, 3, tzinfo=timezone.utc)]),
] ]
batch_size = 100 batch_size = 100
threshold = 2 threshold = 2
@ -50,8 +50,6 @@ def test_update_models(mock_mongodb):
owner="owner", owner="owner",
name="name", name="name",
threshold=threshold, threshold=threshold,
model_90_file="",
model_full_file="",
issues_train=[], issues_train=[],
issues_test=[], issues_test=[],
n_resolved_issues=0, n_resolved_issues=0,