forked from JointCloud/JCC-DeepOD
commits to merge
This commit is contained in:
parent
09eca95d19
commit
94f8f9de85
|
@ -6,7 +6,7 @@ some functions are adapted from the pyod library
|
|||
"""
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
|
@ -527,3 +527,12 @@ class BaseDeepAD(metaclass=ABCMeta):
|
|||
random.seed(seed)
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, path):
|
||||
with open(f"{path}.pkl", mode="rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def save_model(self, path):
|
||||
with open(f"{path}.pkl", mode="wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
|
|
@ -22,14 +22,15 @@ class DCdetector(BaseDeepAD):
|
|||
n_heads=1, d_model=256, e_layers=3, patch_size=None,
|
||||
verbose=2, random_state=42):
|
||||
super(DCdetector, self).__init__(
|
||||
model_name='TranAD', data_type='ts', epochs=epochs, batch_size=batch_size, lr=lr,
|
||||
model_name='DCdetector', data_type='ts', epochs=epochs, batch_size=batch_size, lr=lr,
|
||||
seq_len=seq_len, stride=stride,
|
||||
epoch_steps=epoch_steps, prt_steps=prt_steps, device=device,
|
||||
verbose=verbose, random_state=random_state
|
||||
)
|
||||
if patch_size is None:
|
||||
self.patch_size = [5] # seq_len must be divisible by patch_size
|
||||
self.patch_size = patch_size
|
||||
else:
|
||||
self.patch_size = patch_size
|
||||
self.n_heads = n_heads
|
||||
self.d_model = d_model
|
||||
self.e_layers = e_layers
|
||||
|
|
|
@ -23,14 +23,14 @@ parser.add_argument("--runs", type=int, default=5,
|
|||
"obtain the average performance")
|
||||
parser.add_argument("--output_dir", type=str, default='@records/',
|
||||
help="the output file path")
|
||||
parser.add_argument("--dataset", type=str, default='MSL',
|
||||
parser.add_argument("--dataset", type=str, default='ASD',
|
||||
help='dataset name or a list of names split by comma')
|
||||
parser.add_argument("--entities", type=str,
|
||||
default='C-2',
|
||||
default='FULL',
|
||||
help='FULL represents all the csv file in the folder, '
|
||||
'or a list of entity names split by comma')
|
||||
parser.add_argument("--entity_combined", type=int, default=0)
|
||||
parser.add_argument("--model", type=str, default='COUTA', help="")
|
||||
parser.add_argument("--entity_combined", type=int, default=1)
|
||||
parser.add_argument("--model", type=str, default='DCdetector', help="")
|
||||
parser.add_argument("--auto_hyper", default=False, action='store_true', help="")
|
||||
|
||||
parser.add_argument('--silent_header', action='store_true')
|
||||
|
|
Loading…
Reference in New Issue