commits to merge

This commit is contained in:
xuhongzuo 2023-09-22 12:09:18 +08:00
parent 09eca95d19
commit 94f8f9de85
3 changed files with 17 additions and 7 deletions

View File

@ -6,7 +6,7 @@ some functions are adapted from the pyod library
"""
import sys
import warnings
import pickle
import numpy as np
import torch
import random
@ -527,3 +527,12 @@ class BaseDeepAD(metaclass=ABCMeta):
random.seed(seed)
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
@classmethod
def load_model(cls, path):
with open(f"{path}.pkl", mode="rb") as f:
return pickle.load(f)
def save_model(self, path):
with open(f"{path}.pkl", mode="wb") as f:
pickle.dump(self, f)

View File

@ -22,14 +22,15 @@ class DCdetector(BaseDeepAD):
n_heads=1, d_model=256, e_layers=3, patch_size=None,
verbose=2, random_state=42):
super(DCdetector, self).__init__(
model_name='TranAD', data_type='ts', epochs=epochs, batch_size=batch_size, lr=lr,
model_name='DCdetector', data_type='ts', epochs=epochs, batch_size=batch_size, lr=lr,
seq_len=seq_len, stride=stride,
epoch_steps=epoch_steps, prt_steps=prt_steps, device=device,
verbose=verbose, random_state=random_state
)
if patch_size is None:
self.patch_size = [5] # seq_len must be divisible by patch_size
self.patch_size = patch_size
else:
self.patch_size = patch_size
self.n_heads = n_heads
self.d_model = d_model
self.e_layers = e_layers

View File

@ -23,14 +23,14 @@ parser.add_argument("--runs", type=int, default=5,
"obtain the average performance")
parser.add_argument("--output_dir", type=str, default='@records/',
help="the output file path")
parser.add_argument("--dataset", type=str, default='MSL',
parser.add_argument("--dataset", type=str, default='ASD',
help='dataset name or a list of names split by comma')
parser.add_argument("--entities", type=str,
default='C-2',
default='FULL',
help='FULL represents all the csv file in the folder, '
'or a list of entity names split by comma')
parser.add_argument("--entity_combined", type=int, default=0)
parser.add_argument("--model", type=str, default='COUTA', help="")
parser.add_argument("--entity_combined", type=int, default=1)
parser.add_argument("--model", type=str, default='DCdetector', help="")
parser.add_argument("--auto_hyper", default=False, action='store_true', help="")
parser.add_argument('--silent_header', action='store_true')