fix gcn import error

This commit is contained in:
chentingting 2020-07-02 10:17:27 +08:00
parent 14868eb2b0
commit a04d497118
4 changed files with 122 additions and 121 deletions

View File

@ -15,13 +15,9 @@
"""GCN."""
import numpy as np
from mindspore import nn
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.nn.layer.activation import get_activation
from model_zoo.gcn.src.metrics import Loss, Accuracy
def glorot(shape):
@ -105,116 +101,3 @@ class GCN(nn.Cell):
output0 = self.layer0(self.adj, self.feature)
output1 = self.layer1(self.adj, output0)
return output1
class LossAccuracyWrapper(nn.Cell):
"""
Wraps the GCN model with loss and accuracy cell.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training, evaluation or test.
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
"""
def __init__(self, network, label, mask, weight_decay):
super(LossAccuracyWrapper, self).__init__()
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
self.accuracy = Accuracy(label, mask)
def construct(self):
preds = self.network()
loss = self.loss(preds)
accuracy = self.accuracy(preds)
return loss, accuracy
class LossWrapper(nn.Cell):
"""
Wraps the GCN model with loss.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training.
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
"""
def __init__(self, network, label, mask, weight_decay):
super(LossWrapper, self).__init__()
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
def construct(self):
preds = self.network()
loss = self.loss(preds)
return loss
class TrainOneStepCell(nn.Cell):
r"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
def construct(self):
weights = self.weights
loss = self.network()
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
return F.depend(loss, self.optimizer(grads))
class TrainNetWrapper(nn.Cell):
"""
Wraps the GCN model with optimizer.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training, evaluation or test.
config (ConfigGCN): Configuration for GCN.
"""
def __init__(self, network, label, mask, config):
super(TrainNetWrapper, self).__init__(auto_prefix=True)
self.network = network
loss_net = LossWrapper(network, label, mask, config.weight_decay)
optimizer = nn.Adam(loss_net.trainable_params(),
learning_rate=config.learning_rate)
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy = Accuracy(label, mask)
def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy(self.network())
return loss, accuracy

View File

@ -17,6 +17,9 @@ from mindspore import nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
class Loss(nn.Cell):
@ -68,3 +71,116 @@ class Accuracy(nn.Cell):
mask = mask / mask_reduce
accuracy_all *= mask
return self.mean(accuracy_all)
class LossAccuracyWrapper(nn.Cell):
"""
Wraps the GCN model with loss and accuracy cell.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training, evaluation or test.
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
"""
def __init__(self, network, label, mask, weight_decay):
super(LossAccuracyWrapper, self).__init__()
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
self.accuracy = Accuracy(label, mask)
def construct(self):
preds = self.network()
loss = self.loss(preds)
accuracy = self.accuracy(preds)
return loss, accuracy
class LossWrapper(nn.Cell):
"""
Wraps the GCN model with loss.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training.
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
"""
def __init__(self, network, label, mask, weight_decay):
super(LossWrapper, self).__init__()
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
def construct(self):
preds = self.network()
loss = self.loss(preds)
return loss
class TrainOneStepCell(nn.Cell):
r"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
def construct(self):
weights = self.weights
loss = self.network()
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
return F.depend(loss, self.optimizer(grads))
class TrainNetWrapper(nn.Cell):
"""
Wraps the GCN model with optimizer.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training, evaluation or test.
config (ConfigGCN): Configuration for GCN.
"""
def __init__(self, network, label, mask, config):
super(TrainNetWrapper, self).__init__(auto_prefix=True)
self.network = network
loss_net = LossWrapper(network, label, mask, config.weight_decay)
optimizer = nn.Adam(loss_net.trainable_params(),
learning_rate=config.learning_rate)
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy = Accuracy(label, mask)
def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy(self.network())
return loss, accuracy

View File

@ -26,9 +26,10 @@ from matplotlib import animation
from sklearn import manifold
from mindspore import context
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
from model_zoo.gcn.src.config import ConfigGCN
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
from src.gcn import GCN
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
from src.config import ConfigGCN
from src.dataset import get_adj_features_labels, get_mask
def t_SNE(out_feature, dim):

View File

@ -17,7 +17,8 @@ import time
import pytest
import numpy as np
from mindspore import context
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
from model_zoo.gcn.src.gcn import GCN
from model_zoo.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
from model_zoo.gcn.src.config import ConfigGCN
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask