forked from mindspore-Ecosystem/mindspore
fix gcn import error
This commit is contained in:
parent
14868eb2b0
commit
a04d497118
|
@ -15,13 +15,9 @@
|
|||
"""GCN."""
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from model_zoo.gcn.src.metrics import Loss, Accuracy
|
||||
|
||||
|
||||
def glorot(shape):
|
||||
|
@ -105,116 +101,3 @@ class GCN(nn.Cell):
|
|||
output0 = self.layer0(self.adj, self.feature)
|
||||
output1 = self.layer1(self.adj, output0)
|
||||
return output1
|
||||
|
||||
|
||||
class LossAccuracyWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss and accuracy cell.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossAccuracyWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
accuracy = self.accuracy(preds)
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
class LossWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
r"""
|
||||
Network training package class.
|
||||
|
||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
||||
Backward graph will be created in the construct function to do parameter updating. Different
|
||||
parallel modes are available to run the training.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||
|
||||
Outputs:
|
||||
Tensor, a scalar Tensor with shape :math:`()`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
def construct(self):
|
||||
weights = self.weights
|
||||
loss = self.network()
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class TrainNetWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with optimizer.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
config (ConfigGCN): Configuration for GCN.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, config):
|
||||
super(TrainNetWrapper, self).__init__(auto_prefix=True)
|
||||
self.network = network
|
||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||
learning_rate=config.learning_rate)
|
||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
loss = self.loss_train_net()
|
||||
accuracy = self.accuracy(self.network())
|
||||
return loss, accuracy
|
||||
|
|
|
@ -17,6 +17,9 @@ from mindspore import nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
class Loss(nn.Cell):
|
||||
|
@ -68,3 +71,116 @@ class Accuracy(nn.Cell):
|
|||
mask = mask / mask_reduce
|
||||
accuracy_all *= mask
|
||||
return self.mean(accuracy_all)
|
||||
|
||||
|
||||
class LossAccuracyWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss and accuracy cell.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossAccuracyWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
accuracy = self.accuracy(preds)
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
class LossWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
r"""
|
||||
Network training package class.
|
||||
|
||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
||||
Backward graph will be created in the construct function to do parameter updating. Different
|
||||
parallel modes are available to run the training.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||
|
||||
Outputs:
|
||||
Tensor, a scalar Tensor with shape :math:`()`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
def construct(self):
|
||||
weights = self.weights
|
||||
loss = self.network()
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class TrainNetWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with optimizer.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
config (ConfigGCN): Configuration for GCN.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, config):
|
||||
super(TrainNetWrapper, self).__init__(auto_prefix=True)
|
||||
self.network = network
|
||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||
learning_rate=config.learning_rate)
|
||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
loss = self.loss_train_net()
|
||||
accuracy = self.accuracy(self.network())
|
||||
return loss, accuracy
|
||||
|
|
|
@ -26,9 +26,10 @@ from matplotlib import animation
|
|||
from sklearn import manifold
|
||||
from mindspore import context
|
||||
|
||||
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||
from model_zoo.gcn.src.config import ConfigGCN
|
||||
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||
from src.gcn import GCN
|
||||
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||
from src.config import ConfigGCN
|
||||
from src.dataset import get_adj_features_labels, get_mask
|
||||
|
||||
|
||||
def t_SNE(out_feature, dim):
|
||||
|
|
|
@ -17,7 +17,8 @@ import time
|
|||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||
from model_zoo.gcn.src.gcn import GCN
|
||||
from model_zoo.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||
from model_zoo.gcn.src.config import ConfigGCN
|
||||
from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||
|
||||
|
|
Loading…
Reference in New Issue