forked from mindspore-Ecosystem/mindspore
!2465 GCN rectification
Merge pull request !2465 from chentingting/gcn_rectification
This commit is contained in:
commit
14868eb2b0
|
@ -36,9 +36,9 @@ sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
|
|||
>> Launch
|
||||
```
|
||||
#Generate dataset in mindrecord format for cora
|
||||
sh run_process_data.sh cora
|
||||
sh run_process_data.sh ./data cora
|
||||
#Generate dataset in mindrecord format for citeseer
|
||||
sh run_process_data.sh citeseer
|
||||
sh run_process_data.sh ./data citeseer
|
||||
```
|
||||
|
||||
## Structure
|
||||
|
@ -110,4 +110,6 @@ Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.
|
|||
Optimization Finished!
|
||||
Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083
|
||||
...
|
||||
```
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -40,7 +40,8 @@ else
|
|||
fi
|
||||
MINDRECORD_PATH=`pwd`/data_mr
|
||||
|
||||
rm -f $MINDRECORD_PATH/*
|
||||
rm -f $MINDRECORD_PATH/$DATASET_NAME
|
||||
rm -f $MINDRECORD_PATH/$DATASET_NAME.db
|
||||
|
||||
cd ../../../example/graph_to_mindrecord || exit
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ def get_adj_features_labels(data_dir):
|
|||
adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
|
||||
nor_adj = normalize_adj(adj)
|
||||
nor_adj = np.array(nor_adj.todense())
|
||||
return nor_adj, features, labels_onehot
|
||||
return nor_adj, features, labels_onehot, labels
|
||||
|
||||
|
||||
def get_mask(total, begin, end):
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from src.metrics import Loss, Accuracy
|
||||
from model_zoo.gcn.src.metrics import Loss, Accuracy
|
||||
|
||||
|
||||
def glorot(shape):
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 6.3 MiB |
|
@ -21,11 +21,25 @@ import time
|
|||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib import animation
|
||||
from sklearn import manifold
|
||||
from mindspore import context
|
||||
|
||||
from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||
from src.config import ConfigGCN
|
||||
from src.dataset import get_adj_features_labels, get_mask
|
||||
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||
from model_zoo.gcn.src.config import ConfigGCN
|
||||
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||
|
||||
|
||||
def t_SNE(out_feature, dim):
|
||||
t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
|
||||
return t_sne.fit_transform(out_feature)
|
||||
|
||||
|
||||
def update_graph(i, data, scat, plot):
|
||||
scat.set_offsets(data[i])
|
||||
plt.title('t-SNE visualization of Epoch:{0}'.format(i))
|
||||
return scat, plot
|
||||
|
||||
|
||||
def train():
|
||||
|
@ -36,28 +50,39 @@ def train():
|
|||
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
||||
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
||||
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
||||
parser.add_argument('--save_TSNE', type=bool, default=False, help='Whether to save t-SNE graph')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
np.random.seed(args_opt.seed)
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
adj, feature, label = get_adj_features_labels(args_opt.data_dir)
|
||||
adj, feature, label_onehot, label = get_adj_features_labels(args_opt.data_dir)
|
||||
|
||||
nodes_num = label.shape[0]
|
||||
nodes_num = label_onehot.shape[0]
|
||||
train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
|
||||
eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num)
|
||||
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
|
||||
|
||||
class_num = label.shape[1]
|
||||
class_num = label_onehot.shape[1]
|
||||
gcn_net = GCN(config, adj, feature, class_num)
|
||||
gcn_net.add_flags_recursive(fp16=True)
|
||||
|
||||
eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
|
||||
test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
|
||||
train_net = TrainNetWrapper(gcn_net, label, train_mask, config)
|
||||
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
||||
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
|
||||
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
||||
|
||||
loss_list = []
|
||||
|
||||
if args_opt.save_TSNE:
|
||||
out_feature = gcn_net()
|
||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||
graph_data = []
|
||||
graph_data.append(tsne_result)
|
||||
fig = plt.figure()
|
||||
scat = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], s=2, c=label, cmap='rainbow')
|
||||
plt.title('t-SNE visualization of Epoch:0', fontsize='large', fontweight='bold', verticalalignment='center')
|
||||
|
||||
for epoch in range(config.epochs):
|
||||
t = time.time()
|
||||
|
||||
|
@ -76,6 +101,11 @@ def train():
|
|||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
||||
|
||||
if args_opt.save_TSNE:
|
||||
out_feature = gcn_net()
|
||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||
graph_data.append(tsne_result)
|
||||
|
||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
||||
print("Early stopping...")
|
||||
break
|
||||
|
@ -88,6 +118,10 @@ def train():
|
|||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
||||
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
|
||||
|
||||
if args_opt.save_TSNE:
|
||||
ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1), fargs=(graph_data, scat, plt))
|
||||
ani.save('t-SNE_visualization.gif', writer='imagemagick')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConfigGCN():
|
||||
learning_rate = 0.01
|
||||
epochs = 200
|
||||
hidden1 = 16
|
||||
dropout = 0.0
|
||||
weight_decay = 5e-4
|
||||
early_stopping = 10
|
|
@ -1,60 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
def normalize_adj(adj):
|
||||
rowsum = np.array(adj.sum(1))
|
||||
d_inv_sqrt = np.power(rowsum, -0.5).flatten()
|
||||
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
|
||||
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
|
||||
return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
|
||||
|
||||
|
||||
def get_adj_features_labels(data_dir):
|
||||
g = ds.GraphData(data_dir)
|
||||
nodes = g.get_all_nodes(0)
|
||||
nodes_list = nodes.tolist()
|
||||
row_tensor = g.get_node_feature(nodes_list, [1, 2])
|
||||
features = row_tensor[0]
|
||||
labels = row_tensor[1]
|
||||
|
||||
nodes_num = labels.shape[0]
|
||||
class_num = labels.max() + 1
|
||||
labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)
|
||||
|
||||
neighbor = g.get_all_neighbors(nodes_list, 0)
|
||||
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
|
||||
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
|
||||
for index, value in np.ndenumerate(neighbor):
|
||||
# The first column of neighbor is node_id, second column to last column are neighbors of the first column.
|
||||
# So we only care index[1] > 1.
|
||||
# If the node does not have that many neighbors, -1 is padded. So if value < 0, we will not deal with it.
|
||||
if value >= 0 and index[1] > 0:
|
||||
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
|
||||
adj = sp.coo_matrix(adj)
|
||||
adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
|
||||
nor_adj = normalize_adj(adj)
|
||||
nor_adj = np.array(nor_adj.todense())
|
||||
return nor_adj, features, labels_onehot
|
||||
|
||||
|
||||
def get_mask(total, begin, end):
|
||||
mask = np.zeros([total]).astype(np.float32)
|
||||
mask[begin:end] = 1
|
||||
return mask
|
|
@ -1,163 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from src.metrics import Loss, Accuracy
|
||||
|
||||
|
||||
def glorot(shape):
|
||||
init_range = np.sqrt(6.0/(shape[0]+shape[1]))
|
||||
initial = np.random.uniform(-init_range, init_range, shape).astype(np.float32)
|
||||
return Tensor(initial)
|
||||
|
||||
|
||||
class GraphConvolution(nn.Cell):
|
||||
def __init__(self,
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
dropout_ratio=None,
|
||||
activation=None):
|
||||
super(GraphConvolution, self).__init__()
|
||||
self.in_dim = feature_in_dim
|
||||
self.out_dim = feature_out_dim
|
||||
self.weight_init = glorot([self.out_dim, self.in_dim])
|
||||
self.fc = nn.Dense(self.in_dim,
|
||||
self.out_dim,
|
||||
weight_init=self.weight_init,
|
||||
has_bias=False)
|
||||
self.dropout_ratio = dropout_ratio
|
||||
if self.dropout_ratio is not None:
|
||||
self.dropout = nn.Dropout(keep_prob=1-self.dropout_ratio)
|
||||
self.dropout_flag = self.dropout_ratio is not None
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, adj, input_feature):
|
||||
dropout = input_feature
|
||||
if self.dropout_flag:
|
||||
dropout = self.dropout(dropout)
|
||||
|
||||
fc = self.fc(dropout)
|
||||
output_feature = self.matmul(adj, fc)
|
||||
|
||||
if self.activation_flag:
|
||||
output_feature = self.activation(output_feature)
|
||||
return output_feature
|
||||
|
||||
|
||||
class GCN(nn.Cell):
|
||||
def __init__(self, config, adj, feature, output_dim):
|
||||
super(GCN, self).__init__()
|
||||
self.adj = Tensor(adj)
|
||||
self.feature = Tensor(feature)
|
||||
input_dim = feature.shape[1]
|
||||
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
|
||||
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
|
||||
|
||||
def construct(self):
|
||||
output0 = self.layer0(self.adj, self.feature)
|
||||
output1 = self.layer1(self.adj, output0)
|
||||
return output1
|
||||
|
||||
|
||||
class LossAccuracyWrapper(nn.Cell):
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossAccuracyWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
accuracy = self.accuracy(preds)
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
class LossWrapper(nn.Cell):
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
r"""
|
||||
Network training package class.
|
||||
|
||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
||||
Backward graph will be created in the construct function to do parameter updating. Different
|
||||
parallel modes are available to run the training.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||
|
||||
Outputs:
|
||||
Tensor, a scalar Tensor with shape :math:`()`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
def construct(self):
|
||||
weights = self.weights
|
||||
loss = self.network()
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class TrainNetWrapper(nn.Cell):
|
||||
def __init__(self, network, label, mask, config):
|
||||
super(TrainNetWrapper, self).__init__(auto_prefix=True)
|
||||
self.network = network
|
||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||
learning_rate=config.learning_rate)
|
||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
loss = self.loss_train_net()
|
||||
accuracy = self.accuracy(self.network())
|
||||
return loss, accuracy
|
|
@ -1,68 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Loss(nn.Cell):
|
||||
def __init__(self, label, mask, weight_decay, param):
|
||||
super(Loss, self).__init__()
|
||||
self.label = Tensor(label)
|
||||
self.mask = Tensor(mask)
|
||||
self.loss = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.one = Tensor(1.0, mstype.float32)
|
||||
self.zero = Tensor(0.0, mstype.float32)
|
||||
self.mean = P.ReduceMean()
|
||||
self.cast = P.Cast()
|
||||
self.l2_loss = P.L2Loss()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.weight_decay = weight_decay
|
||||
self.param = param
|
||||
|
||||
def construct(self, preds):
|
||||
param = self.l2_loss(self.param)
|
||||
loss = self.weight_decay * param
|
||||
preds = self.cast(preds, mstype.float32)
|
||||
loss = loss + self.loss(preds, self.label)[0]
|
||||
mask = self.cast(self.mask, mstype.float32)
|
||||
mask_reduce = self.mean(mask)
|
||||
mask = mask / mask_reduce
|
||||
loss = loss * mask
|
||||
loss = self.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class Accuracy(nn.Cell):
|
||||
def __init__(self, label, mask):
|
||||
super(Accuracy, self).__init__()
|
||||
self.label = Tensor(label)
|
||||
self.mask = Tensor(mask)
|
||||
self.equal = P.Equal()
|
||||
self.argmax = P.Argmax()
|
||||
self.cast = P.Cast()
|
||||
self.mean = P.ReduceMean()
|
||||
|
||||
def construct(self, preds):
|
||||
preds = self.cast(preds, mstype.float32)
|
||||
correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
|
||||
accuracy_all = self.cast(correct_prediction, mstype.float32)
|
||||
mask = self.cast(self.mask, mstype.float32)
|
||||
mask_reduce = self.mean(mask)
|
||||
mask = mask / mask_reduce
|
||||
accuracy_all *= mask
|
||||
return self.mean(accuracy_all)
|
|
@ -17,9 +17,9 @@ import time
|
|||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||
from src.config import ConfigGCN
|
||||
from src.dataset import get_adj_features_labels, get_mask
|
||||
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||
from model_zoo.gcn.src.config import ConfigGCN
|
||||
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||
|
||||
|
||||
DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
|
||||
|
@ -37,22 +37,23 @@ def test_gcn():
|
|||
print("test_gcn begin")
|
||||
np.random.seed(SEED)
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=True)
|
||||
device_target="Ascend", save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
adj, feature, label = get_adj_features_labels(DATA_DIR)
|
||||
config.dropout = 0.0
|
||||
adj, feature, label_onehot, _ = get_adj_features_labels(DATA_DIR)
|
||||
|
||||
nodes_num = label.shape[0]
|
||||
nodes_num = label_onehot.shape[0]
|
||||
train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
|
||||
eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
|
||||
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
|
||||
|
||||
class_num = label.shape[1]
|
||||
class_num = label_onehot.shape[1]
|
||||
gcn_net = GCN(config, adj, feature, class_num)
|
||||
gcn_net.add_flags_recursive(fp16=True)
|
||||
|
||||
eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
|
||||
test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
|
||||
train_net = TrainNetWrapper(gcn_net, label, train_mask, config)
|
||||
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
||||
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
|
||||
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
||||
|
||||
loss_list = []
|
||||
for epoch in range(config.epochs):
|
||||
|
|
Loading…
Reference in New Issue