diff --git a/model_zoo/official/gnn/gat/src/gat.py b/model_zoo/official/gnn/gat/src/gat.py index b600590ca2e..97edeada79f 100644 --- a/model_zoo/official/gnn/gat/src/gat.py +++ b/model_zoo/official/gnn/gat/src/gat.py @@ -19,7 +19,6 @@ from mindspore.ops import functional as F from mindspore._extends import cell_attr_register from mindspore import Tensor, Parameter from mindspore.common.initializer import initializer -from mindspore._checkparam import Validator from mindspore.nn.layer.activation import get_activation @@ -72,9 +71,9 @@ class GNNFeatureTransform(nn.Cell): bias_init='zeros', has_bias=True): super(GNNFeatureTransform, self).__init__() - self.in_channels = Validator.check_positive_int(in_channels) - self.out_channels = Validator.check_positive_int(out_channels) - self.has_bias = Validator.check_bool(has_bias) + self.in_channels = in_channels + self.out_channels = out_channels + self.has_bias = has_bias if isinstance(weight_init, Tensor): if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ @@ -259,8 +258,8 @@ class AttentionHead(nn.Cell): coef_activation=nn.LeakyReLU(), activation=nn.ELU()): super(AttentionHead, self).__init__() - self.in_channel = Validator.check_positive_int(in_channel) - self.out_channel = Validator.check_positive_int(out_channel) + self.in_channel = in_channel + self.out_channel = out_channel self.in_drop_ratio = in_drop_ratio self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio) self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio) @@ -284,7 +283,7 @@ class AttentionHead(nn.Cell): self.matmul = P.MatMul() self.bias_add = P.BiasAdd() self.bias = Parameter(initializer('zeros', self.out_channel), name='bias') - self.residual = Validator.check_bool(residual) + self.residual = residual if self.residual: if in_channel != out_channel: self.residual_transform_flag = True @@ -436,8 +435,6 @@ class GAT(nn.Cell): """ def __init__(self, - features, - biases, ftr_dims, num_class, num_nodes, @@ -448,17 +445,15 @@ class GAT(nn.Cell): activation=nn.ELU(), residual=False): super(GAT, self).__init__() - self.features = Tensor(features) - self.biases = Tensor(biases) - self.ftr_dims = Validator.check_positive_int(ftr_dims) - self.num_class = Validator.check_positive_int(num_class) - self.num_nodes = Validator.check_positive_int(num_nodes) + self.ftr_dims = ftr_dims + self.num_class = num_class + self.num_nodes = num_nodes self.hidden_units = hidden_units self.num_heads = num_heads self.attn_drop = attn_drop self.ftr_drop = ftr_drop self.activation = activation - self.residual = Validator.check_bool(residual) + self.residual = residual self.layers = [] # first layer self.layers.append(AttentionAggregator( @@ -491,9 +486,9 @@ class GAT(nn.Cell): output_transform='sum')) self.layers = nn.layer.CellList(self.layers) - def construct(self, training=True): - input_data = self.features - bias_mat = self.biases + def construct(self, feature, biases, training=True): + input_data = feature + bias_mat = biases for cell in self.layers: input_data = cell(input_data, bias_mat, training) return input_data/self.num_heads[-1] diff --git a/model_zoo/official/gnn/gat/src/utils.py b/model_zoo/official/gnn/gat/src/utils.py index 23e0c6c3065..c7bae8c8b86 100644 --- a/model_zoo/official/gnn/gat/src/utils.py +++ b/model_zoo/official/gnn/gat/src/utils.py @@ -103,8 +103,8 @@ class LossAccuracyWrapper(nn.Cell): self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params()) self.acc_func = MaskedAccuracy(num_class, label, mask) - def construct(self): - logits = self.network(training=False) + def construct(self, feature, biases): + logits = self.network(feature, biases, training=False) loss = self.loss_func(logits) accuracy = self.acc_func(logits) return loss, accuracy @@ -120,8 +120,8 @@ class LossNetWrapper(nn.Cell): params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias') self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params) - def construct(self): - logits = self.network() + def construct(self, feature, biases): + logits = self.network(feature, biases) loss = self.loss_func(logits) return loss @@ -145,11 +145,11 @@ class TrainOneStepCell(nn.Cell): self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.sens = sens - def construct(self): + def construct(self, feature, biases): weights = self.weights - loss = self.network() + loss = self.network(feature, biases) sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) - grads = self.grad(self.network, weights)(sens) + grads = self.grad(self.network, weights)(feature, biases, sens) return F.depend(loss, self.optimizer(grads)) @@ -174,7 +174,7 @@ class TrainGAT(nn.Cell): self.loss_train_net = TrainOneStepCell(loss_net, optimizer) self.accuracy_func = MaskedAccuracy(num_class, label, mask) - def construct(self): - loss = self.loss_train_net() - accuracy = self.accuracy_func(self.network()) + def construct(self, feature, biases): + loss = self.loss_train_net(feature, biases) + accuracy = self.accuracy_func(self.network(feature, biases)) return loss, accuracy diff --git a/model_zoo/official/gnn/gat/train.py b/model_zoo/official/gnn/gat/train.py index 4898cbf0718..1e022cbd81d 100644 --- a/model_zoo/official/gnn/gat/train.py +++ b/model_zoo/official/gnn/gat/train.py @@ -20,6 +20,7 @@ import numpy as np import mindspore.context as context from mindspore.train.serialization import save_checkpoint, load_checkpoint from mindspore.common import set_seed +from mindspore import Tensor from src.config import GatConfig from src.dataset import load_and_process @@ -56,9 +57,7 @@ def train(): num_nodes = feature.shape[1] num_class = y_train.shape[2] - gat_net = GAT(feature, - biases, - feature_size, + gat_net = GAT(feature_size, num_class, num_nodes, hid_units, @@ -67,6 +66,9 @@ def train(): ftr_drop=GatConfig.feature_dropout) gat_net.add_flags_recursive(fp16=True) + feature = Tensor(feature) + biases = Tensor(biases) + eval_net = LossAccuracyWrapper(gat_net, num_class, y_val, @@ -84,11 +86,11 @@ def train(): val_acc_max = 0.0 val_loss_min = np.inf for _epoch in range(num_epochs): - train_result = train_net() + train_result = train_net(feature, biases) train_loss = train_result[0].asnumpy() train_acc = train_result[1].asnumpy() - eval_result = eval_net() + eval_result = eval_net(feature, biases) eval_loss = eval_result[0].asnumpy() eval_acc = eval_result[1].asnumpy() @@ -110,9 +112,7 @@ def train(): print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max)) print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model)) break - gat_net_test = GAT(feature, - biases, - feature_size, + gat_net_test = GAT(feature_size, num_class, num_nodes, hid_units, @@ -127,7 +127,7 @@ def train(): y_test, test_mask, l2_coeff) - test_result = test_net() + test_result = test_net(feature, biases) print("Test loss={}, test acc={}".format(test_result[0], test_result[1])) diff --git a/model_zoo/official/gnn/gcn/src/gcn.py b/model_zoo/official/gnn/gcn/src/gcn.py index 3e1c4324056..7da858a091c 100644 --- a/model_zoo/official/gnn/gcn/src/gcn.py +++ b/model_zoo/official/gnn/gcn/src/gcn.py @@ -92,15 +92,12 @@ class GCN(nn.Cell): output_dim (int): The number of output channels, equal to classes num. """ - def __init__(self, config, adj, feature, output_dim): + def __init__(self, config, input_dim, output_dim): super(GCN, self).__init__() - self.adj = Tensor(adj) - self.feature = Tensor(feature) - input_dim = feature.shape[1] self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout) self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None) - def construct(self): - output0 = self.layer0(self.adj, self.feature) - output1 = self.layer1(self.adj, output0) + def construct(self, adj, feature): + output0 = self.layer0(adj, feature) + output1 = self.layer1(adj, output0) return output1 diff --git a/model_zoo/official/gnn/gcn/src/metrics.py b/model_zoo/official/gnn/gcn/src/metrics.py index 11b4ab84bab..d11236fa73b 100644 --- a/model_zoo/official/gnn/gcn/src/metrics.py +++ b/model_zoo/official/gnn/gcn/src/metrics.py @@ -91,8 +91,8 @@ class LossAccuracyWrapper(nn.Cell): self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) self.accuracy = Accuracy(label, mask) - def construct(self): - preds = self.network() + def construct(self, adj, feature): + preds = self.network(adj, feature) loss = self.loss(preds) accuracy = self.accuracy(preds) return loss, accuracy @@ -114,8 +114,8 @@ class LossWrapper(nn.Cell): self.network = network self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) - def construct(self): - preds = self.network() + def construct(self, adj, feature): + preds = self.network(adj, feature) loss = self.loss(preds) return loss @@ -154,11 +154,11 @@ class TrainOneStepCell(nn.Cell): self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.sens = sens - def construct(self): + def construct(self, adj, feature): weights = self.weights - loss = self.network() + loss = self.network(adj, feature) sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) - grads = self.grad(self.network, weights)(sens) + grads = self.grad(self.network, weights)(adj, feature, sens) return F.depend(loss, self.optimizer(grads)) @@ -182,7 +182,7 @@ class TrainNetWrapper(nn.Cell): self.loss_train_net = TrainOneStepCell(loss_net, optimizer) self.accuracy = Accuracy(label, mask) - def construct(self): - loss = self.loss_train_net() - accuracy = self.accuracy(self.network()) + def construct(self, adj, feature): + loss = self.loss_train_net(adj, feature) + accuracy = self.accuracy(self.network(adj, feature)) return loss, accuracy diff --git a/model_zoo/official/gnn/gcn/train.py b/model_zoo/official/gnn/gcn/train.py index 33d928d8d01..2f010805810 100644 --- a/model_zoo/official/gnn/gcn/train.py +++ b/model_zoo/official/gnn/gcn/train.py @@ -26,6 +26,7 @@ from matplotlib import pyplot as plt from matplotlib import animation from sklearn import manifold from mindspore import context +from mindspore import Tensor from mindspore.common import set_seed from mindspore.train.serialization import save_checkpoint, load_checkpoint @@ -71,9 +72,13 @@ def train(): test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num) class_num = label_onehot.shape[1] - gcn_net = GCN(config, adj, feature, class_num) + input_dim = feature.shape[1] + gcn_net = GCN(config, input_dim, class_num) gcn_net.add_flags_recursive(fp16=True) + adj = Tensor(adj) + feature = Tensor(feature) + eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay) train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config) @@ -92,12 +97,12 @@ def train(): t = time.time() train_net.set_train() - train_result = train_net() + train_result = train_net(adj, feature) train_loss = train_result[0].asnumpy() train_accuracy = train_result[1].asnumpy() eval_net.set_train(False) - eval_result = eval_net() + eval_result = eval_net(adj, feature) eval_loss = eval_result[0].asnumpy() eval_accuracy = eval_result[1].asnumpy() @@ -115,14 +120,14 @@ def train(): print("Early stopping...") break save_checkpoint(gcn_net, "ckpts/gcn.ckpt") - gcn_net_test = GCN(config, adj, feature, class_num) + gcn_net_test = GCN(config, input_dim, class_num) load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test) gcn_net_test.add_flags_recursive(fp16=True) test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay) t_test = time.time() test_net.set_train(False) - test_result = test_net() + test_result = test_net(adj, feature) test_loss = test_result[0].asnumpy() test_accuracy = test_result[1].asnumpy() print("Test set results:", "loss=", "{:.5f}".format(test_loss), diff --git a/tests/st/gnn/gcn/test_gcn.py b/tests/st/gnn/gcn/test_gcn.py index 09789a00e57..4bcbf087395 100644 --- a/tests/st/gnn/gcn/test_gcn.py +++ b/tests/st/gnn/gcn/test_gcn.py @@ -17,6 +17,7 @@ import time import pytest import numpy as np from mindspore import context +from mindspore import Tensor from model_zoo.official.gnn.gcn.src.gcn import GCN from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper from model_zoo.official.gnn.gcn.src.config import ConfigGCN @@ -49,9 +50,13 @@ def test_gcn(): test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num) class_num = label_onehot.shape[1] - gcn_net = GCN(config, adj, feature, class_num) + input_dim = feature.shape[1] + gcn_net = GCN(config, input_dim, class_num) gcn_net.add_flags_recursive(fp16=True) + adj = Tensor(adj) + feature = Tensor(feature) + eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay) test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay) train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config) @@ -61,12 +66,12 @@ def test_gcn(): t = time.time() train_net.set_train() - train_result = train_net() + train_result = train_net(adj, feature) train_loss = train_result[0].asnumpy() train_accuracy = train_result[1].asnumpy() eval_net.set_train(False) - eval_result = eval_net() + eval_result = eval_net(adj, feature) eval_loss = eval_result[0].asnumpy() eval_accuracy = eval_result[1].asnumpy() @@ -80,7 +85,7 @@ def test_gcn(): break test_net.set_train(False) - test_result = test_net() + test_result = test_net(adj, feature) test_loss = test_result[0].asnumpy() test_accuracy = test_result[1].asnumpy() print("Test set results:", "loss=", "{:.5f}".format(test_loss),