forked from mindspore-Ecosystem/mindspore
fix gcn import error
This commit is contained in:
parent
14868eb2b0
commit
a04d497118
|
@ -15,13 +15,9 @@
|
||||||
"""GCN."""
|
"""GCN."""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore.common.parameter import ParameterTuple
|
|
||||||
from mindspore.ops import composite as C
|
|
||||||
from mindspore.ops import functional as F
|
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.nn.layer.activation import get_activation
|
from mindspore.nn.layer.activation import get_activation
|
||||||
from model_zoo.gcn.src.metrics import Loss, Accuracy
|
|
||||||
|
|
||||||
|
|
||||||
def glorot(shape):
|
def glorot(shape):
|
||||||
|
@ -105,116 +101,3 @@ class GCN(nn.Cell):
|
||||||
output0 = self.layer0(self.adj, self.feature)
|
output0 = self.layer0(self.adj, self.feature)
|
||||||
output1 = self.layer1(self.adj, output0)
|
output1 = self.layer1(self.adj, output0)
|
||||||
return output1
|
return output1
|
||||||
|
|
||||||
|
|
||||||
class LossAccuracyWrapper(nn.Cell):
|
|
||||||
"""
|
|
||||||
Wraps the GCN model with loss and accuracy cell.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): GCN network.
|
|
||||||
label (numpy.ndarray): Dataset labels.
|
|
||||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
|
||||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, label, mask, weight_decay):
|
|
||||||
super(LossAccuracyWrapper, self).__init__()
|
|
||||||
self.network = network
|
|
||||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
|
||||||
self.accuracy = Accuracy(label, mask)
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
preds = self.network()
|
|
||||||
loss = self.loss(preds)
|
|
||||||
accuracy = self.accuracy(preds)
|
|
||||||
return loss, accuracy
|
|
||||||
|
|
||||||
|
|
||||||
class LossWrapper(nn.Cell):
|
|
||||||
"""
|
|
||||||
Wraps the GCN model with loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): GCN network.
|
|
||||||
label (numpy.ndarray): Dataset labels.
|
|
||||||
mask (numpy.ndarray): Mask for training.
|
|
||||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, label, mask, weight_decay):
|
|
||||||
super(LossWrapper, self).__init__()
|
|
||||||
self.network = network
|
|
||||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
preds = self.network()
|
|
||||||
loss = self.loss(preds)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class TrainOneStepCell(nn.Cell):
|
|
||||||
r"""
|
|
||||||
Network training package class.
|
|
||||||
|
|
||||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
|
||||||
Backward graph will be created in the construct function to do parameter updating. Different
|
|
||||||
parallel modes are available to run the training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): The training network.
|
|
||||||
optimizer (Cell): Optimizer for updating the weights.
|
|
||||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, a scalar Tensor with shape :math:`()`.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> net = Net()
|
|
||||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
||||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
||||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
|
||||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, optimizer, sens=1.0):
|
|
||||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
|
||||||
self.network = network
|
|
||||||
self.network.add_flags(defer_inline=True)
|
|
||||||
self.weights = ParameterTuple(network.trainable_params())
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
|
||||||
self.sens = sens
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
weights = self.weights
|
|
||||||
loss = self.network()
|
|
||||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
|
||||||
grads = self.grad(self.network, weights)(sens)
|
|
||||||
return F.depend(loss, self.optimizer(grads))
|
|
||||||
|
|
||||||
|
|
||||||
class TrainNetWrapper(nn.Cell):
|
|
||||||
"""
|
|
||||||
Wraps the GCN model with optimizer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): GCN network.
|
|
||||||
label (numpy.ndarray): Dataset labels.
|
|
||||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
|
||||||
config (ConfigGCN): Configuration for GCN.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, label, mask, config):
|
|
||||||
super(TrainNetWrapper, self).__init__(auto_prefix=True)
|
|
||||||
self.network = network
|
|
||||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
|
||||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
|
||||||
learning_rate=config.learning_rate)
|
|
||||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
|
||||||
self.accuracy = Accuracy(label, mask)
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
loss = self.loss_train_net()
|
|
||||||
accuracy = self.accuracy(self.network())
|
|
||||||
return loss, accuracy
|
|
||||||
|
|
|
@ -17,6 +17,9 @@ from mindspore import nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
|
|
||||||
class Loss(nn.Cell):
|
class Loss(nn.Cell):
|
||||||
|
@ -68,3 +71,116 @@ class Accuracy(nn.Cell):
|
||||||
mask = mask / mask_reduce
|
mask = mask / mask_reduce
|
||||||
accuracy_all *= mask
|
accuracy_all *= mask
|
||||||
return self.mean(accuracy_all)
|
return self.mean(accuracy_all)
|
||||||
|
|
||||||
|
|
||||||
|
class LossAccuracyWrapper(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wraps the GCN model with loss and accuracy cell.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): GCN network.
|
||||||
|
label (numpy.ndarray): Dataset labels.
|
||||||
|
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||||
|
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, label, mask, weight_decay):
|
||||||
|
super(LossAccuracyWrapper, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||||
|
self.accuracy = Accuracy(label, mask)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
preds = self.network()
|
||||||
|
loss = self.loss(preds)
|
||||||
|
accuracy = self.accuracy(preds)
|
||||||
|
return loss, accuracy
|
||||||
|
|
||||||
|
|
||||||
|
class LossWrapper(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wraps the GCN model with loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): GCN network.
|
||||||
|
label (numpy.ndarray): Dataset labels.
|
||||||
|
mask (numpy.ndarray): Mask for training.
|
||||||
|
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, label, mask, weight_decay):
|
||||||
|
super(LossWrapper, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
preds = self.network()
|
||||||
|
loss = self.loss(preds)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class TrainOneStepCell(nn.Cell):
|
||||||
|
r"""
|
||||||
|
Network training package class.
|
||||||
|
|
||||||
|
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
||||||
|
Backward graph will be created in the construct function to do parameter updating. Different
|
||||||
|
parallel modes are available to run the training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The training network.
|
||||||
|
optimizer (Cell): Optimizer for updating the weights.
|
||||||
|
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, a scalar Tensor with shape :math:`()`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = Net()
|
||||||
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
|
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
||||||
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, optimizer, sens=1.0):
|
||||||
|
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.network.add_flags(defer_inline=True)
|
||||||
|
self.weights = ParameterTuple(network.trainable_params())
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||||
|
self.sens = sens
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network()
|
||||||
|
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||||
|
grads = self.grad(self.network, weights)(sens)
|
||||||
|
return F.depend(loss, self.optimizer(grads))
|
||||||
|
|
||||||
|
|
||||||
|
class TrainNetWrapper(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wraps the GCN model with optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): GCN network.
|
||||||
|
label (numpy.ndarray): Dataset labels.
|
||||||
|
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||||
|
config (ConfigGCN): Configuration for GCN.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, label, mask, config):
|
||||||
|
super(TrainNetWrapper, self).__init__(auto_prefix=True)
|
||||||
|
self.network = network
|
||||||
|
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
||||||
|
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||||
|
learning_rate=config.learning_rate)
|
||||||
|
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||||
|
self.accuracy = Accuracy(label, mask)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
loss = self.loss_train_net()
|
||||||
|
accuracy = self.accuracy(self.network())
|
||||||
|
return loss, accuracy
|
||||||
|
|
|
@ -26,9 +26,10 @@ from matplotlib import animation
|
||||||
from sklearn import manifold
|
from sklearn import manifold
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
|
||||||
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
from src.gcn import GCN
|
||||||
from model_zoo.gcn.src.config import ConfigGCN
|
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||||
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
from src.config import ConfigGCN
|
||||||
|
from src.dataset import get_adj_features_labels, get_mask
|
||||||
|
|
||||||
|
|
||||||
def t_SNE(out_feature, dim):
|
def t_SNE(out_feature, dim):
|
||||||
|
|
|
@ -17,7 +17,8 @@ import time
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
from model_zoo.gcn.src.gcn import GCN
|
||||||
|
from model_zoo.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||||
from model_zoo.gcn.src.config import ConfigGCN
|
from model_zoo.gcn.src.config import ConfigGCN
|
||||||
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue