forked from mindspore-Ecosystem/mindspore
gcn_rectification
This commit is contained in:
parent
932b7649e7
commit
a733102db9
|
@ -36,9 +36,9 @@ sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
|
||||||
>> Launch
|
>> Launch
|
||||||
```
|
```
|
||||||
#Generate dataset in mindrecord format for cora
|
#Generate dataset in mindrecord format for cora
|
||||||
sh run_process_data.sh cora
|
sh run_process_data.sh ./data cora
|
||||||
#Generate dataset in mindrecord format for citeseer
|
#Generate dataset in mindrecord format for citeseer
|
||||||
sh run_process_data.sh citeseer
|
sh run_process_data.sh ./data citeseer
|
||||||
```
|
```
|
||||||
|
|
||||||
## Structure
|
## Structure
|
||||||
|
@ -110,4 +110,6 @@ Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.
|
||||||
Optimization Finished!
|
Optimization Finished!
|
||||||
Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083
|
Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,8 @@ else
|
||||||
fi
|
fi
|
||||||
MINDRECORD_PATH=`pwd`/data_mr
|
MINDRECORD_PATH=`pwd`/data_mr
|
||||||
|
|
||||||
rm -f $MINDRECORD_PATH/*
|
rm -f $MINDRECORD_PATH/$DATASET_NAME
|
||||||
|
rm -f $MINDRECORD_PATH/$DATASET_NAME.db
|
||||||
|
|
||||||
cd ../../../example/graph_to_mindrecord || exit
|
cd ../../../example/graph_to_mindrecord || exit
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ def get_adj_features_labels(data_dir):
|
||||||
adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
|
adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
|
||||||
nor_adj = normalize_adj(adj)
|
nor_adj = normalize_adj(adj)
|
||||||
nor_adj = np.array(nor_adj.todense())
|
nor_adj = np.array(nor_adj.todense())
|
||||||
return nor_adj, features, labels_onehot
|
return nor_adj, features, labels_onehot, labels
|
||||||
|
|
||||||
|
|
||||||
def get_mask(total, begin, end):
|
def get_mask(total, begin, end):
|
||||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.nn.layer.activation import get_activation
|
from mindspore.nn.layer.activation import get_activation
|
||||||
from src.metrics import Loss, Accuracy
|
from model_zoo.gcn.src.metrics import Loss, Accuracy
|
||||||
|
|
||||||
|
|
||||||
def glorot(shape):
|
def glorot(shape):
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 6.3 MiB |
|
@ -21,11 +21,25 @@ import time
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from matplotlib import animation
|
||||||
|
from sklearn import manifold
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
|
||||||
from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||||
from src.config import ConfigGCN
|
from model_zoo.gcn.src.config import ConfigGCN
|
||||||
from src.dataset import get_adj_features_labels, get_mask
|
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||||
|
|
||||||
|
|
||||||
|
def t_SNE(out_feature, dim):
|
||||||
|
t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
|
||||||
|
return t_sne.fit_transform(out_feature)
|
||||||
|
|
||||||
|
|
||||||
|
def update_graph(i, data, scat, plot):
|
||||||
|
scat.set_offsets(data[i])
|
||||||
|
plt.title('t-SNE visualization of Epoch:{0}'.format(i))
|
||||||
|
return scat, plot
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
@ -36,28 +50,39 @@ def train():
|
||||||
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
||||||
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
||||||
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
||||||
|
parser.add_argument('--save_TSNE', type=bool, default=False, help='Whether to save t-SNE graph')
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
np.random.seed(args_opt.seed)
|
np.random.seed(args_opt.seed)
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
device_target="Ascend", save_graphs=False)
|
device_target="Ascend", save_graphs=False)
|
||||||
config = ConfigGCN()
|
config = ConfigGCN()
|
||||||
adj, feature, label = get_adj_features_labels(args_opt.data_dir)
|
adj, feature, label_onehot, label = get_adj_features_labels(args_opt.data_dir)
|
||||||
|
|
||||||
nodes_num = label.shape[0]
|
nodes_num = label_onehot.shape[0]
|
||||||
train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
|
train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
|
||||||
eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num)
|
eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num)
|
||||||
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
|
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
|
||||||
|
|
||||||
class_num = label.shape[1]
|
class_num = label_onehot.shape[1]
|
||||||
gcn_net = GCN(config, adj, feature, class_num)
|
gcn_net = GCN(config, adj, feature, class_num)
|
||||||
gcn_net.add_flags_recursive(fp16=True)
|
gcn_net.add_flags_recursive(fp16=True)
|
||||||
|
|
||||||
eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
|
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
||||||
test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
|
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
|
||||||
train_net = TrainNetWrapper(gcn_net, label, train_mask, config)
|
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
||||||
|
|
||||||
loss_list = []
|
loss_list = []
|
||||||
|
|
||||||
|
if args_opt.save_TSNE:
|
||||||
|
out_feature = gcn_net()
|
||||||
|
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||||
|
graph_data = []
|
||||||
|
graph_data.append(tsne_result)
|
||||||
|
fig = plt.figure()
|
||||||
|
scat = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], s=2, c=label, cmap='rainbow')
|
||||||
|
plt.title('t-SNE visualization of Epoch:0', fontsize='large', fontweight='bold', verticalalignment='center')
|
||||||
|
|
||||||
for epoch in range(config.epochs):
|
for epoch in range(config.epochs):
|
||||||
t = time.time()
|
t = time.time()
|
||||||
|
|
||||||
|
@ -76,6 +101,11 @@ def train():
|
||||||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
||||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
||||||
|
|
||||||
|
if args_opt.save_TSNE:
|
||||||
|
out_feature = gcn_net()
|
||||||
|
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||||
|
graph_data.append(tsne_result)
|
||||||
|
|
||||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
||||||
print("Early stopping...")
|
print("Early stopping...")
|
||||||
break
|
break
|
||||||
|
@ -88,6 +118,10 @@ def train():
|
||||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
||||||
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
|
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
|
||||||
|
|
||||||
|
if args_opt.save_TSNE:
|
||||||
|
ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1), fargs=(graph_data, scat, plt))
|
||||||
|
ani.save('t-SNE_visualization.gif', writer='imagemagick')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train()
|
train()
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigGCN():
|
|
||||||
learning_rate = 0.01
|
|
||||||
epochs = 200
|
|
||||||
hidden1 = 16
|
|
||||||
dropout = 0.0
|
|
||||||
weight_decay = 5e-4
|
|
||||||
early_stopping = 10
|
|
|
@ -1,60 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import scipy.sparse as sp
|
|
||||||
import mindspore.dataset as ds
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_adj(adj):
|
|
||||||
rowsum = np.array(adj.sum(1))
|
|
||||||
d_inv_sqrt = np.power(rowsum, -0.5).flatten()
|
|
||||||
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
|
|
||||||
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
|
|
||||||
return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
|
|
||||||
|
|
||||||
|
|
||||||
def get_adj_features_labels(data_dir):
|
|
||||||
g = ds.GraphData(data_dir)
|
|
||||||
nodes = g.get_all_nodes(0)
|
|
||||||
nodes_list = nodes.tolist()
|
|
||||||
row_tensor = g.get_node_feature(nodes_list, [1, 2])
|
|
||||||
features = row_tensor[0]
|
|
||||||
labels = row_tensor[1]
|
|
||||||
|
|
||||||
nodes_num = labels.shape[0]
|
|
||||||
class_num = labels.max() + 1
|
|
||||||
labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)
|
|
||||||
|
|
||||||
neighbor = g.get_all_neighbors(nodes_list, 0)
|
|
||||||
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
|
|
||||||
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
|
|
||||||
for index, value in np.ndenumerate(neighbor):
|
|
||||||
# The first column of neighbor is node_id, second column to last column are neighbors of the first column.
|
|
||||||
# So we only care index[1] > 1.
|
|
||||||
# If the node does not have that many neighbors, -1 is padded. So if value < 0, we will not deal with it.
|
|
||||||
if value >= 0 and index[1] > 0:
|
|
||||||
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
|
|
||||||
adj = sp.coo_matrix(adj)
|
|
||||||
adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
|
|
||||||
nor_adj = normalize_adj(adj)
|
|
||||||
nor_adj = np.array(nor_adj.todense())
|
|
||||||
return nor_adj, features, labels_onehot
|
|
||||||
|
|
||||||
|
|
||||||
def get_mask(total, begin, end):
|
|
||||||
mask = np.zeros([total]).astype(np.float32)
|
|
||||||
mask[begin:end] = 1
|
|
||||||
return mask
|
|
|
@ -1,163 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from mindspore import nn
|
|
||||||
from mindspore.common.parameter import ParameterTuple
|
|
||||||
from mindspore.ops import composite as C
|
|
||||||
from mindspore.ops import functional as F
|
|
||||||
from mindspore.ops import operations as P
|
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore.nn.layer.activation import get_activation
|
|
||||||
from src.metrics import Loss, Accuracy
|
|
||||||
|
|
||||||
|
|
||||||
def glorot(shape):
|
|
||||||
init_range = np.sqrt(6.0/(shape[0]+shape[1]))
|
|
||||||
initial = np.random.uniform(-init_range, init_range, shape).astype(np.float32)
|
|
||||||
return Tensor(initial)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphConvolution(nn.Cell):
|
|
||||||
def __init__(self,
|
|
||||||
feature_in_dim,
|
|
||||||
feature_out_dim,
|
|
||||||
dropout_ratio=None,
|
|
||||||
activation=None):
|
|
||||||
super(GraphConvolution, self).__init__()
|
|
||||||
self.in_dim = feature_in_dim
|
|
||||||
self.out_dim = feature_out_dim
|
|
||||||
self.weight_init = glorot([self.out_dim, self.in_dim])
|
|
||||||
self.fc = nn.Dense(self.in_dim,
|
|
||||||
self.out_dim,
|
|
||||||
weight_init=self.weight_init,
|
|
||||||
has_bias=False)
|
|
||||||
self.dropout_ratio = dropout_ratio
|
|
||||||
if self.dropout_ratio is not None:
|
|
||||||
self.dropout = nn.Dropout(keep_prob=1-self.dropout_ratio)
|
|
||||||
self.dropout_flag = self.dropout_ratio is not None
|
|
||||||
self.activation = get_activation(activation)
|
|
||||||
self.activation_flag = self.activation is not None
|
|
||||||
self.matmul = P.MatMul()
|
|
||||||
|
|
||||||
def construct(self, adj, input_feature):
|
|
||||||
dropout = input_feature
|
|
||||||
if self.dropout_flag:
|
|
||||||
dropout = self.dropout(dropout)
|
|
||||||
|
|
||||||
fc = self.fc(dropout)
|
|
||||||
output_feature = self.matmul(adj, fc)
|
|
||||||
|
|
||||||
if self.activation_flag:
|
|
||||||
output_feature = self.activation(output_feature)
|
|
||||||
return output_feature
|
|
||||||
|
|
||||||
|
|
||||||
class GCN(nn.Cell):
|
|
||||||
def __init__(self, config, adj, feature, output_dim):
|
|
||||||
super(GCN, self).__init__()
|
|
||||||
self.adj = Tensor(adj)
|
|
||||||
self.feature = Tensor(feature)
|
|
||||||
input_dim = feature.shape[1]
|
|
||||||
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
|
|
||||||
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
output0 = self.layer0(self.adj, self.feature)
|
|
||||||
output1 = self.layer1(self.adj, output0)
|
|
||||||
return output1
|
|
||||||
|
|
||||||
|
|
||||||
class LossAccuracyWrapper(nn.Cell):
|
|
||||||
def __init__(self, network, label, mask, weight_decay):
|
|
||||||
super(LossAccuracyWrapper, self).__init__()
|
|
||||||
self.network = network
|
|
||||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
|
||||||
self.accuracy = Accuracy(label, mask)
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
preds = self.network()
|
|
||||||
loss = self.loss(preds)
|
|
||||||
accuracy = self.accuracy(preds)
|
|
||||||
return loss, accuracy
|
|
||||||
|
|
||||||
|
|
||||||
class LossWrapper(nn.Cell):
|
|
||||||
def __init__(self, network, label, mask, weight_decay):
|
|
||||||
super(LossWrapper, self).__init__()
|
|
||||||
self.network = network
|
|
||||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
preds = self.network()
|
|
||||||
loss = self.loss(preds)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class TrainOneStepCell(nn.Cell):
|
|
||||||
r"""
|
|
||||||
Network training package class.
|
|
||||||
|
|
||||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
|
||||||
Backward graph will be created in the construct function to do parameter updating. Different
|
|
||||||
parallel modes are available to run the training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): The training network.
|
|
||||||
optimizer (Cell): Optimizer for updating the weights.
|
|
||||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, a scalar Tensor with shape :math:`()`.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> net = Net()
|
|
||||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
||||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
||||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
|
||||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, optimizer, sens=1.0):
|
|
||||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
|
||||||
self.network = network
|
|
||||||
self.network.add_flags(defer_inline=True)
|
|
||||||
self.weights = ParameterTuple(network.trainable_params())
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
|
||||||
self.sens = sens
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
weights = self.weights
|
|
||||||
loss = self.network()
|
|
||||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
|
||||||
grads = self.grad(self.network, weights)(sens)
|
|
||||||
return F.depend(loss, self.optimizer(grads))
|
|
||||||
|
|
||||||
|
|
||||||
class TrainNetWrapper(nn.Cell):
|
|
||||||
def __init__(self, network, label, mask, config):
|
|
||||||
super(TrainNetWrapper, self).__init__(auto_prefix=True)
|
|
||||||
self.network = network
|
|
||||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
|
||||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
|
||||||
learning_rate=config.learning_rate)
|
|
||||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
|
||||||
self.accuracy = Accuracy(label, mask)
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
loss = self.loss_train_net()
|
|
||||||
accuracy = self.accuracy(self.network())
|
|
||||||
return loss, accuracy
|
|
|
@ -1,68 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
from mindspore import nn
|
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore.common import dtype as mstype
|
|
||||||
from mindspore.ops import operations as P
|
|
||||||
|
|
||||||
|
|
||||||
class Loss(nn.Cell):
|
|
||||||
def __init__(self, label, mask, weight_decay, param):
|
|
||||||
super(Loss, self).__init__()
|
|
||||||
self.label = Tensor(label)
|
|
||||||
self.mask = Tensor(mask)
|
|
||||||
self.loss = P.SoftmaxCrossEntropyWithLogits()
|
|
||||||
self.one = Tensor(1.0, mstype.float32)
|
|
||||||
self.zero = Tensor(0.0, mstype.float32)
|
|
||||||
self.mean = P.ReduceMean()
|
|
||||||
self.cast = P.Cast()
|
|
||||||
self.l2_loss = P.L2Loss()
|
|
||||||
self.reduce_sum = P.ReduceSum()
|
|
||||||
self.weight_decay = weight_decay
|
|
||||||
self.param = param
|
|
||||||
|
|
||||||
def construct(self, preds):
|
|
||||||
param = self.l2_loss(self.param)
|
|
||||||
loss = self.weight_decay * param
|
|
||||||
preds = self.cast(preds, mstype.float32)
|
|
||||||
loss = loss + self.loss(preds, self.label)[0]
|
|
||||||
mask = self.cast(self.mask, mstype.float32)
|
|
||||||
mask_reduce = self.mean(mask)
|
|
||||||
mask = mask / mask_reduce
|
|
||||||
loss = loss * mask
|
|
||||||
loss = self.mean(loss)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class Accuracy(nn.Cell):
|
|
||||||
def __init__(self, label, mask):
|
|
||||||
super(Accuracy, self).__init__()
|
|
||||||
self.label = Tensor(label)
|
|
||||||
self.mask = Tensor(mask)
|
|
||||||
self.equal = P.Equal()
|
|
||||||
self.argmax = P.Argmax()
|
|
||||||
self.cast = P.Cast()
|
|
||||||
self.mean = P.ReduceMean()
|
|
||||||
|
|
||||||
def construct(self, preds):
|
|
||||||
preds = self.cast(preds, mstype.float32)
|
|
||||||
correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
|
|
||||||
accuracy_all = self.cast(correct_prediction, mstype.float32)
|
|
||||||
mask = self.cast(self.mask, mstype.float32)
|
|
||||||
mask_reduce = self.mean(mask)
|
|
||||||
mask = mask / mask_reduce
|
|
||||||
accuracy_all *= mask
|
|
||||||
return self.mean(accuracy_all)
|
|
|
@ -17,9 +17,9 @@ import time
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||||
from src.config import ConfigGCN
|
from model_zoo.gcn.src.config import ConfigGCN
|
||||||
from src.dataset import get_adj_features_labels, get_mask
|
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
|
DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
|
||||||
|
@ -37,22 +37,23 @@ def test_gcn():
|
||||||
print("test_gcn begin")
|
print("test_gcn begin")
|
||||||
np.random.seed(SEED)
|
np.random.seed(SEED)
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
device_target="Ascend", save_graphs=True)
|
device_target="Ascend", save_graphs=False)
|
||||||
config = ConfigGCN()
|
config = ConfigGCN()
|
||||||
adj, feature, label = get_adj_features_labels(DATA_DIR)
|
config.dropout = 0.0
|
||||||
|
adj, feature, label_onehot, _ = get_adj_features_labels(DATA_DIR)
|
||||||
|
|
||||||
nodes_num = label.shape[0]
|
nodes_num = label_onehot.shape[0]
|
||||||
train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
|
train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
|
||||||
eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
|
eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
|
||||||
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
|
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
|
||||||
|
|
||||||
class_num = label.shape[1]
|
class_num = label_onehot.shape[1]
|
||||||
gcn_net = GCN(config, adj, feature, class_num)
|
gcn_net = GCN(config, adj, feature, class_num)
|
||||||
gcn_net.add_flags_recursive(fp16=True)
|
gcn_net.add_flags_recursive(fp16=True)
|
||||||
|
|
||||||
eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
|
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
||||||
test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
|
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
|
||||||
train_net = TrainNetWrapper(gcn_net, label, train_mask, config)
|
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
||||||
|
|
||||||
loss_list = []
|
loss_list = []
|
||||||
for epoch in range(config.epochs):
|
for epoch in range(config.epochs):
|
||||||
|
|
Loading…
Reference in New Issue