diff --git a/deepod/core/base_model.py b/deepod/core/base_model.py index f4ea7dc..42d2d9a 100644 --- a/deepod/core/base_model.py +++ b/deepod/core/base_model.py @@ -6,7 +6,7 @@ some functions are adapted from the pyod library """ import sys import warnings - +import pickle import numpy as np import torch import random @@ -527,3 +527,12 @@ class BaseDeepAD(metaclass=ABCMeta): random.seed(seed) # torch.backends.cudnn.benchmark = False # torch.backends.cudnn.deterministic = True + + @classmethod + def load_model(cls, path): + with open(f"{path}.pkl", mode="rb") as f: + return pickle.load(f) + + def save_model(self, path): + with open(f"{path}.pkl", mode="wb") as f: + pickle.dump(self, f) diff --git a/deepod/models/time_series/dcdetector.py b/deepod/models/time_series/dcdetector.py index 6d80034..cb5784b 100644 --- a/deepod/models/time_series/dcdetector.py +++ b/deepod/models/time_series/dcdetector.py @@ -22,14 +22,15 @@ class DCdetector(BaseDeepAD): n_heads=1, d_model=256, e_layers=3, patch_size=None, verbose=2, random_state=42): super(DCdetector, self).__init__( - model_name='TranAD', data_type='ts', epochs=epochs, batch_size=batch_size, lr=lr, + model_name='DCdetector', data_type='ts', epochs=epochs, batch_size=batch_size, lr=lr, seq_len=seq_len, stride=stride, epoch_steps=epoch_steps, prt_steps=prt_steps, device=device, verbose=verbose, random_state=random_state ) if patch_size is None: self.patch_size = [5] # seq_len must be divisible by patch_size - self.patch_size = patch_size + else: + self.patch_size = patch_size self.n_heads = n_heads self.d_model = d_model self.e_layers = e_layers diff --git a/testbed/testbed_unsupervised_tsad.py b/testbed/testbed_unsupervised_tsad.py index a9eb3a4..ab335ca 100644 --- a/testbed/testbed_unsupervised_tsad.py +++ b/testbed/testbed_unsupervised_tsad.py @@ -23,14 +23,14 @@ parser.add_argument("--runs", type=int, default=5, "obtain the average performance") parser.add_argument("--output_dir", type=str, default='@records/', help="the output file path") -parser.add_argument("--dataset", type=str, default='MSL', +parser.add_argument("--dataset", type=str, default='ASD', help='dataset name or a list of names split by comma') parser.add_argument("--entities", type=str, - default='C-2', + default='FULL', help='FULL represents all the csv file in the folder, ' 'or a list of entity names split by comma') -parser.add_argument("--entity_combined", type=int, default=0) -parser.add_argument("--model", type=str, default='COUTA', help="") +parser.add_argument("--entity_combined", type=int, default=1) +parser.add_argument("--model", type=str, default='DCdetector', help="") parser.add_argument("--auto_hyper", default=False, action='store_true', help="") parser.add_argument('--silent_header', action='store_true')