forked from mindspore-Ecosystem/mindspore
!8756 [ModelZoo]Change gnn seed
From: @zhan_ke Reviewed-by: @oacjiewen,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
adfd70e9d8
|
@ -19,7 +19,6 @@ import os
|
|||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import Tensor
|
||||
|
||||
from src.config import GatConfig
|
||||
|
@ -27,7 +26,6 @@ from src.dataset import load_and_process
|
|||
from src.gat import GAT
|
||||
from src.utils import LossAccuracyWrapper, TrainGAT
|
||||
|
||||
set_seed(0)
|
||||
|
||||
def train():
|
||||
"""Train GAT model."""
|
||||
|
|
|
@ -27,7 +27,6 @@ from matplotlib import animation
|
|||
from sklearn import manifold
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
|
||||
from src.gcn import GCN
|
||||
|
@ -51,7 +50,6 @@ def train():
|
|||
"""Train model."""
|
||||
parser = argparse.ArgumentParser(description='GCN')
|
||||
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
|
||||
parser.add_argument('--seed', type=int, default=0, help='Random seed')
|
||||
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
||||
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
||||
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
||||
|
@ -60,7 +58,6 @@ def train():
|
|||
if not os.path.exists("ckpts"):
|
||||
os.mkdir("ckpts")
|
||||
|
||||
set_seed(args_opt.seed)
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
|
|
Loading…
Reference in New Issue