From d8807e3e85c7cc1f1b3fcf63571c419656bd3d5f Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Mon, 13 Feb 2023 20:21:11 +0800 Subject: [PATCH] Add testcases for JIT Fallback gradient --- .../frontend/operator/composite/composite.cc | 18 +- .../fallback/test_graph_fallback_runtime.py | 227 +++++++++++++++++- 2 files changed, 238 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index 169e7644557..4e66d2b6769 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -565,8 +565,10 @@ FuncGraphPtr MakeDictGradient::GenerateFuncGraph(const AbstractBasePtrList &args FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { int64_t args_size = SizeToLong(args_spec_list.size()); constexpr auto py_execute_grad_input_count = 3; - constexpr auto op_name = "PyExecute"; - CheckArgsSize(op_name, args_spec_list, py_execute_grad_input_count); + constexpr auto op_name = "PyExecuteGradient"; + if (args_size < py_execute_grad_input_count) { + MS_LOG(EXCEPTION) << "The inputs size of " << op_name << " should not less than " << py_execute_grad_input_count; + } std::ostringstream ss; // â–¶PyExecute @@ -636,6 +638,18 @@ FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &arg } (void)grads.emplace_back(bprop->NewCNodeInOrder(values)); + // Add gradients for extra monad. + for (size_t i = py_execute_grad_input_count; i < args_spec_list.size(); ++i) { + if (args_spec_list[i]->isa()) { + (void)grads.emplace_back(NewValueNode(kUMonad)); + } else if (args_spec_list[i]->isa()) { + (void)grads.emplace_back(NewValueNode(kIOMonad)); + } else { + MS_LOG(EXCEPTION) << "The extra input of " << op_name << " should be UMonad or IOMonad, but got " + << args_spec_list[i]->ToString(); + } + } + bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true); bprop->set_output(bprop->NewCNodeInOrder(grads)); diff --git a/tests/st/fallback/test_graph_fallback_runtime.py b/tests/st/fallback/test_graph_fallback_runtime.py index 19894a5e2da..6f9ca18a0e5 100644 --- a/tests/st/fallback/test_graph_fallback_runtime.py +++ b/tests/st/fallback/test_graph_fallback_runtime.py @@ -18,6 +18,8 @@ import numpy as np import mindspore as ms from mindspore.common.initializer import TruncatedNormal +from mindspore import ops +from mindspore import mutable ms.set_context(mode=ms.GRAPH_MODE) @@ -63,6 +65,23 @@ def test_fallback_np(): np.testing.assert_almost_equal(output, const_output, 3) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_np_grad(): + """ + Feature: Support JIT Fallback runtime feature. + Description: Support JIT Fallback runtime feature. + Expectation: No exception. + """ + a = ms.Tensor(np.array(4), ms.int32) + b = ms.Tensor(np.array(5), ms.int32) + output = ops.grad(Net())(a, b) + assert output == 0 + + class Net1(ms.nn.Cell): def np_function(self, a, b): x = a.asnumpy() @@ -91,6 +110,23 @@ def test_fallback_np_asnumpy(): np.testing.assert_almost_equal(output, const_output, 3) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_np_asnumpy_grad(): + """ + Feature: Support JIT Fallback runtime feature. + Description: Support JIT Fallback runtime feature. + Expectation: No exception. + """ + a = ms.Tensor(np.array(4), ms.int32) + b = ms.Tensor(np.array(5), ms.int32) + output = ops.grad(Net1())(a, b) + assert output == 0 + + @ms.jit def tensor_asnumpy(): tensor = ms.Tensor(np.arange(0, 6).reshape(2, 3)) @@ -277,6 +313,26 @@ def test_multiple_return_contains_dict_2(): assert out[1][1] == (1, 2) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_multiple_return_contains_dict_2_grad(): + """ + Feature: Return multiple outputs including dict. + Description: Support grad for dict return. + Expectation: Get expected gradient. + """ + @ms.jit + def dict_net_2(a): + x = {'a': a, 'b': 2} + return a, (x, (1, 2)) + + out = ops.grad(dict_net_2)(ms.Tensor([1])) + assert out == 2 + + def weight_variable(): """weight initial""" return TruncatedNormal(0.02) @@ -284,17 +340,14 @@ def weight_variable(): def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): """weight initial for conv layer""" - weight = weight_variable() return ms.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, - weight_init=weight, has_bias=False, pad_mode="valid") + weight_init="ones", has_bias=False, pad_mode="valid") def fc_with_initialize(input_channels, out_channels): """weight initial for fc layer""" - weight = weight_variable() - bias = weight_variable() - return ms.nn.Dense(input_channels, out_channels, weight, bias) + return ms.nn.Dense(input_channels, out_channels, "ones", "ones") @pytest.mark.level0 @@ -347,6 +400,54 @@ def test_net_dict_1(): assert outputs['fc'].shape == (64, 10) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_net_dict_1_grad(): + """ + Feature: Return dict. + Description: Support grad for dict return. + Expectation: Get expected gradient. + """ + class DictLeNetNet(ms.nn.Cell): + def __init__(self, num_class=10): + super(DictLeNetNet, self).__init__() + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = ms.nn.ReLU() + self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = ms.nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + conv1_x = x + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + conv2_x = x + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + fc_x = x + outputs = dict(conv1=conv1_x, conv2=conv2_x, fc=fc_x) + return outputs + + net = DictLeNetNet() + x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32)) + outputs = ops.grad(net)(x) + assert np.all(outputs.asnumpy() == np.zeros((64, 1, 32, 32))) + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_arm_ascend_training @@ -397,6 +498,86 @@ def test_net_dict_2(): assert outputs['fc'].shape == (64, 10) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_net_dict_2_grad(): + """ + Feature: Return dict. + Description: Support grad for dict return. + Expectation: Get expected gradients. + """ + class LeNet(ms.nn.Cell): + def __init__(self, num_class=10): + super(LeNet, self).__init__() + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = ms.nn.ReLU() + self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = ms.nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + output_conv1 = x + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + output_conv2 = x + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + output_fc = x + return output_conv1, output_conv2, output_fc + + class DictLeNet(ms.nn.Cell): + def __init__(self, num_class=10): + super(DictLeNet, self).__init__() + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = ms.nn.ReLU() + self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = ms.nn.Flatten() + + def construct(self, x): + outputs = dict() + x = self.conv1(x) + outputs['conv1'] = x + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + outputs['conv2'] = x + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + outputs['fc'] = x + return outputs + + x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32)) + net = LeNet() + outputs1 = ops.grad(net)(x) + dict_lenet = DictLeNet() + outputs2 = ops.grad(dict_lenet)(x) + assert np.all(outputs1.asnumpy() == outputs2.asnumpy()) + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_arm_ascend_training @@ -656,3 +837,39 @@ def test_parser_fallback_nested_class_outer(): y = 4 net = NestedNet() assert net(x, y) == 12 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_parser_fallback_nested_class_outer_grad(): + """ + Feature: Syntax getattr. + Description: Graph syntax getattr support custom class input. + Expectation: AttributeError. + """ + class Inner: + def __init__(self): + self.number = ms.Tensor(2, dtype=ms.int32) + + def act(self, x, y): + return self.number * (x + y) + + @ms.jit_class + class InnerNet: + def __init__(self): + self.inner = Inner() + + class NestedNet(ms.nn.Cell): + @ms.jit + def construct(self, x, y): + out = InnerNet().inner.act(x, y) + return out + + x = 2 + y = 4 + net = NestedNet() + output = ops.grad(net)(mutable(x), y) + assert output == 0