Add testcases for JIT Fallback gradient

This commit is contained in:
yujianfeng 2023-02-13 20:21:11 +08:00
parent 5157591443
commit d8807e3e85
2 changed files with 238 additions and 7 deletions

View File

@ -565,8 +565,10 @@ FuncGraphPtr MakeDictGradient::GenerateFuncGraph(const AbstractBasePtrList &args
FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
int64_t args_size = SizeToLong(args_spec_list.size()); int64_t args_size = SizeToLong(args_spec_list.size());
constexpr auto py_execute_grad_input_count = 3; constexpr auto py_execute_grad_input_count = 3;
constexpr auto op_name = "PyExecute"; constexpr auto op_name = "PyExecuteGradient";
CheckArgsSize(op_name, args_spec_list, py_execute_grad_input_count); if (args_size < py_execute_grad_input_count) {
MS_LOG(EXCEPTION) << "The inputs size of " << op_name << " should not less than " << py_execute_grad_input_count;
}
std::ostringstream ss; std::ostringstream ss;
// ▶PyExecute // ▶PyExecute
@ -636,6 +638,18 @@ FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
} }
(void)grads.emplace_back(bprop->NewCNodeInOrder(values)); (void)grads.emplace_back(bprop->NewCNodeInOrder(values));
// Add gradients for extra monad.
for (size_t i = py_execute_grad_input_count; i < args_spec_list.size(); ++i) {
if (args_spec_list[i]->isa<abstract::AbstractUMonad>()) {
(void)grads.emplace_back(NewValueNode(kUMonad));
} else if (args_spec_list[i]->isa<abstract::AbstractIOMonad>()) {
(void)grads.emplace_back(NewValueNode(kIOMonad));
} else {
MS_LOG(EXCEPTION) << "The extra input of " << op_name << " should be UMonad or IOMonad, but got "
<< args_spec_list[i]->ToString();
}
}
bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true); bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
bprop->set_output(bprop->NewCNodeInOrder(grads)); bprop->set_output(bprop->NewCNodeInOrder(grads));

View File

@ -18,6 +18,8 @@ import numpy as np
import mindspore as ms import mindspore as ms
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
from mindspore import ops
from mindspore import mutable
ms.set_context(mode=ms.GRAPH_MODE) ms.set_context(mode=ms.GRAPH_MODE)
@ -63,6 +65,23 @@ def test_fallback_np():
np.testing.assert_almost_equal(output, const_output, 3) np.testing.assert_almost_equal(output, const_output, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_np_grad():
"""
Feature: Support JIT Fallback runtime feature.
Description: Support JIT Fallback runtime feature.
Expectation: No exception.
"""
a = ms.Tensor(np.array(4), ms.int32)
b = ms.Tensor(np.array(5), ms.int32)
output = ops.grad(Net())(a, b)
assert output == 0
class Net1(ms.nn.Cell): class Net1(ms.nn.Cell):
def np_function(self, a, b): def np_function(self, a, b):
x = a.asnumpy() x = a.asnumpy()
@ -91,6 +110,23 @@ def test_fallback_np_asnumpy():
np.testing.assert_almost_equal(output, const_output, 3) np.testing.assert_almost_equal(output, const_output, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_np_asnumpy_grad():
"""
Feature: Support JIT Fallback runtime feature.
Description: Support JIT Fallback runtime feature.
Expectation: No exception.
"""
a = ms.Tensor(np.array(4), ms.int32)
b = ms.Tensor(np.array(5), ms.int32)
output = ops.grad(Net1())(a, b)
assert output == 0
@ms.jit @ms.jit
def tensor_asnumpy(): def tensor_asnumpy():
tensor = ms.Tensor(np.arange(0, 6).reshape(2, 3)) tensor = ms.Tensor(np.arange(0, 6).reshape(2, 3))
@ -277,6 +313,26 @@ def test_multiple_return_contains_dict_2():
assert out[1][1] == (1, 2) assert out[1][1] == (1, 2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multiple_return_contains_dict_2_grad():
"""
Feature: Return multiple outputs including dict.
Description: Support grad for dict return.
Expectation: Get expected gradient.
"""
@ms.jit
def dict_net_2(a):
x = {'a': a, 'b': 2}
return a, (x, (1, 2))
out = ops.grad(dict_net_2)(ms.Tensor([1]))
assert out == 2
def weight_variable(): def weight_variable():
"""weight initial""" """weight initial"""
return TruncatedNormal(0.02) return TruncatedNormal(0.02)
@ -284,17 +340,14 @@ def weight_variable():
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer""" """weight initial for conv layer"""
weight = weight_variable()
return ms.nn.Conv2d(in_channels, out_channels, return ms.nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding, kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid") weight_init="ones", has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels): def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer""" """weight initial for fc layer"""
weight = weight_variable() return ms.nn.Dense(input_channels, out_channels, "ones", "ones")
bias = weight_variable()
return ms.nn.Dense(input_channels, out_channels, weight, bias)
@pytest.mark.level0 @pytest.mark.level0
@ -347,6 +400,54 @@ def test_net_dict_1():
assert outputs['fc'].shape == (64, 10) assert outputs['fc'].shape == (64, 10)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net_dict_1_grad():
"""
Feature: Return dict.
Description: Support grad for dict return.
Expectation: Get expected gradient.
"""
class DictLeNetNet(ms.nn.Cell):
def __init__(self, num_class=10):
super(DictLeNetNet, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = ms.nn.ReLU()
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = ms.nn.Flatten()
def construct(self, x):
x = self.conv1(x)
conv1_x = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
conv2_x = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
fc_x = x
outputs = dict(conv1=conv1_x, conv2=conv2_x, fc=fc_x)
return outputs
net = DictLeNetNet()
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
outputs = ops.grad(net)(x)
assert np.all(outputs.asnumpy() == np.zeros((64, 1, 32, 32)))
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@ -397,6 +498,86 @@ def test_net_dict_2():
assert outputs['fc'].shape == (64, 10) assert outputs['fc'].shape == (64, 10)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net_dict_2_grad():
"""
Feature: Return dict.
Description: Support grad for dict return.
Expectation: Get expected gradients.
"""
class LeNet(ms.nn.Cell):
def __init__(self, num_class=10):
super(LeNet, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = ms.nn.ReLU()
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = ms.nn.Flatten()
def construct(self, x):
x = self.conv1(x)
output_conv1 = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
output_conv2 = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
output_fc = x
return output_conv1, output_conv2, output_fc
class DictLeNet(ms.nn.Cell):
def __init__(self, num_class=10):
super(DictLeNet, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = ms.nn.ReLU()
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = ms.nn.Flatten()
def construct(self, x):
outputs = dict()
x = self.conv1(x)
outputs['conv1'] = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
outputs['conv2'] = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
outputs['fc'] = x
return outputs
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
net = LeNet()
outputs1 = ops.grad(net)(x)
dict_lenet = DictLeNet()
outputs2 = ops.grad(dict_lenet)(x)
assert np.all(outputs1.asnumpy() == outputs2.asnumpy())
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@ -656,3 +837,39 @@ def test_parser_fallback_nested_class_outer():
y = 4 y = 4
net = NestedNet() net = NestedNet()
assert net(x, y) == 12 assert net(x, y) == 12
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parser_fallback_nested_class_outer_grad():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support custom class input.
Expectation: AttributeError.
"""
class Inner:
def __init__(self):
self.number = ms.Tensor(2, dtype=ms.int32)
def act(self, x, y):
return self.number * (x + y)
@ms.jit_class
class InnerNet:
def __init__(self):
self.inner = Inner()
class NestedNet(ms.nn.Cell):
@ms.jit
def construct(self, x, y):
out = InnerNet().inner.act(x, y)
return out
x = 2
y = 4
net = NestedNet()
output = ops.grad(net)(mutable(x), y)
assert output == 0