diff --git a/tests/ut/python/utils/test_export.py b/tests/ut/python/utils/test_export.py
new file mode 100644
index 00000000000..199e0838cc0
--- /dev/null
+++ b/tests/ut/python/utils/test_export.py
@@ -0,0 +1,97 @@
+import os
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.common.tensor import Tensor
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.common.parameter import ParameterTuple
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.train.serialization import export
+
+
+def weight_variable():
+    return TruncatedNormal(0.02)
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+class LeNet5(nn.Cell):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.batch_size = 32
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, 10)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.reshape = P.Reshape()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.reshape(x, (self.batch_size, -1))
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+class WithLossCell(nn.Cell):
+    def __init__(self, network):
+        super(WithLossCell, self).__init__(auto_prefix=False)
+        self.loss = nn.SoftmaxCrossEntropyWithLogits()
+        self.network = network
+
+    def construct(self, x, label):
+        predict = self.network(x)
+        return self.loss(predict, label)
+
+
+class TrainOneStepCell(nn.Cell):
+    def __init__(self, network):
+        super(TrainOneStepCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_train()
+        self.weights = ParameterTuple(network.trainable_params())
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
+        self.hyper_map = C.HyperMap()
+        self.grad = C.GradOperation(get_by_list=True)
+
+    def construct(self, x, label):
+        weights = self.weights
+        grads = self.grad(self.network, weights)(x, label)
+        return self.optimizer(grads)
+
+
+def test_export_lenet_grad_mindir():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    network = LeNet5()
+    network.set_train()
+    predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
+    label = Tensor(np.zeros([32, 10]).astype(np.float32))
+    net = TrainOneStepCell(WithLossCell(network))
+    file_name = "lenet_grad.mindir"
+    export(net, predict, label, file_name=file_name, file_format='MINDIR')
+    assert os.path.exists(file_name)
+    os.remove(file_name)
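
The test above only asserts that export() writes the MINDIR file to disk. A minimal round-trip sketch for the exported graph, assuming a MindSpore release that provides the mindspore.load and nn.GraphCell APIs; actually executing a training graph through GraphCell is an assumption that depends on backend support, while the load() call alone already checks that the file parses:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context, load

    # Assumes mindspore.load / nn.GraphCell are available; the device target
    # should match the backend the graph was exported for ("Ascend" above).
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    graph = load("lenet_grad.mindir")   # parse the exported MindIR protobuf
    net = nn.GraphCell(graph)           # wrap the loaded graph as a callable Cell
    predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([32, 10]).astype(np.float32))
    # Running one step of the exported training graph (backend-dependent).
    out = net(predict, label)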