forked from mindspore-Ecosystem/mindspore
Add testcases for JIT Fallback gradient
This commit is contained in:
parent
5157591443
commit
d8807e3e85
|
@ -565,8 +565,10 @@ FuncGraphPtr MakeDictGradient::GenerateFuncGraph(const AbstractBasePtrList &args
|
|||
FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
int64_t args_size = SizeToLong(args_spec_list.size());
|
||||
constexpr auto py_execute_grad_input_count = 3;
|
||||
constexpr auto op_name = "PyExecute";
|
||||
CheckArgsSize(op_name, args_spec_list, py_execute_grad_input_count);
|
||||
constexpr auto op_name = "PyExecuteGradient";
|
||||
if (args_size < py_execute_grad_input_count) {
|
||||
MS_LOG(EXCEPTION) << "The inputs size of " << op_name << " should not less than " << py_execute_grad_input_count;
|
||||
}
|
||||
|
||||
std::ostringstream ss;
|
||||
// ▶PyExecute
|
||||
|
@ -636,6 +638,18 @@ FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
|
|||
}
|
||||
(void)grads.emplace_back(bprop->NewCNodeInOrder(values));
|
||||
|
||||
// Add gradients for extra monad.
|
||||
for (size_t i = py_execute_grad_input_count; i < args_spec_list.size(); ++i) {
|
||||
if (args_spec_list[i]->isa<abstract::AbstractUMonad>()) {
|
||||
(void)grads.emplace_back(NewValueNode(kUMonad));
|
||||
} else if (args_spec_list[i]->isa<abstract::AbstractIOMonad>()) {
|
||||
(void)grads.emplace_back(NewValueNode(kIOMonad));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The extra input of " << op_name << " should be UMonad or IOMonad, but got "
|
||||
<< args_spec_list[i]->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
bprop->set_output(bprop->NewCNodeInOrder(grads));
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@ import numpy as np
|
|||
|
||||
import mindspore as ms
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore import ops
|
||||
from mindspore import mutable
|
||||
|
||||
ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
||||
|
@ -63,6 +65,23 @@ def test_fallback_np():
|
|||
np.testing.assert_almost_equal(output, const_output, 3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_np_grad():
|
||||
"""
|
||||
Feature: Support JIT Fallback runtime feature.
|
||||
Description: Support JIT Fallback runtime feature.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = ms.Tensor(np.array(4), ms.int32)
|
||||
b = ms.Tensor(np.array(5), ms.int32)
|
||||
output = ops.grad(Net())(a, b)
|
||||
assert output == 0
|
||||
|
||||
|
||||
class Net1(ms.nn.Cell):
|
||||
def np_function(self, a, b):
|
||||
x = a.asnumpy()
|
||||
|
@ -91,6 +110,23 @@ def test_fallback_np_asnumpy():
|
|||
np.testing.assert_almost_equal(output, const_output, 3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_np_asnumpy_grad():
|
||||
"""
|
||||
Feature: Support JIT Fallback runtime feature.
|
||||
Description: Support JIT Fallback runtime feature.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = ms.Tensor(np.array(4), ms.int32)
|
||||
b = ms.Tensor(np.array(5), ms.int32)
|
||||
output = ops.grad(Net1())(a, b)
|
||||
assert output == 0
|
||||
|
||||
|
||||
@ms.jit
|
||||
def tensor_asnumpy():
|
||||
tensor = ms.Tensor(np.arange(0, 6).reshape(2, 3))
|
||||
|
@ -277,6 +313,26 @@ def test_multiple_return_contains_dict_2():
|
|||
assert out[1][1] == (1, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multiple_return_contains_dict_2_grad():
|
||||
"""
|
||||
Feature: Return multiple outputs including dict.
|
||||
Description: Support grad for dict return.
|
||||
Expectation: Get expected gradient.
|
||||
"""
|
||||
@ms.jit
|
||||
def dict_net_2(a):
|
||||
x = {'a': a, 'b': 2}
|
||||
return a, (x, (1, 2))
|
||||
|
||||
out = ops.grad(dict_net_2)(ms.Tensor([1]))
|
||||
assert out == 2
|
||||
|
||||
|
||||
def weight_variable():
|
||||
"""weight initial"""
|
||||
return TruncatedNormal(0.02)
|
||||
|
@ -284,17 +340,14 @@ def weight_variable():
|
|||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
"""weight initial for conv layer"""
|
||||
weight = weight_variable()
|
||||
return ms.nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
weight_init=weight, has_bias=False, pad_mode="valid")
|
||||
weight_init="ones", has_bias=False, pad_mode="valid")
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
"""weight initial for fc layer"""
|
||||
weight = weight_variable()
|
||||
bias = weight_variable()
|
||||
return ms.nn.Dense(input_channels, out_channels, weight, bias)
|
||||
return ms.nn.Dense(input_channels, out_channels, "ones", "ones")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -347,6 +400,54 @@ def test_net_dict_1():
|
|||
assert outputs['fc'].shape == (64, 10)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_dict_1_grad():
|
||||
"""
|
||||
Feature: Return dict.
|
||||
Description: Support grad for dict return.
|
||||
Expectation: Get expected gradient.
|
||||
"""
|
||||
class DictLeNetNet(ms.nn.Cell):
|
||||
def __init__(self, num_class=10):
|
||||
super(DictLeNetNet, self).__init__()
|
||||
self.conv1 = conv(1, 6, 5)
|
||||
self.conv2 = conv(6, 16, 5)
|
||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||
self.fc2 = fc_with_initialize(120, 84)
|
||||
self.fc3 = fc_with_initialize(84, 10)
|
||||
self.relu = ms.nn.ReLU()
|
||||
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = ms.nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
conv1_x = x
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
conv2_x = x
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
fc_x = x
|
||||
outputs = dict(conv1=conv1_x, conv2=conv2_x, fc=fc_x)
|
||||
return outputs
|
||||
|
||||
net = DictLeNetNet()
|
||||
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
|
||||
outputs = ops.grad(net)(x)
|
||||
assert np.all(outputs.asnumpy() == np.zeros((64, 1, 32, 32)))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -397,6 +498,86 @@ def test_net_dict_2():
|
|||
assert outputs['fc'].shape == (64, 10)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_dict_2_grad():
|
||||
"""
|
||||
Feature: Return dict.
|
||||
Description: Support grad for dict return.
|
||||
Expectation: Get expected gradients.
|
||||
"""
|
||||
class LeNet(ms.nn.Cell):
|
||||
def __init__(self, num_class=10):
|
||||
super(LeNet, self).__init__()
|
||||
self.conv1 = conv(1, 6, 5)
|
||||
self.conv2 = conv(6, 16, 5)
|
||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||
self.fc2 = fc_with_initialize(120, 84)
|
||||
self.fc3 = fc_with_initialize(84, 10)
|
||||
self.relu = ms.nn.ReLU()
|
||||
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = ms.nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
output_conv1 = x
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
output_conv2 = x
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
output_fc = x
|
||||
return output_conv1, output_conv2, output_fc
|
||||
|
||||
class DictLeNet(ms.nn.Cell):
|
||||
def __init__(self, num_class=10):
|
||||
super(DictLeNet, self).__init__()
|
||||
self.conv1 = conv(1, 6, 5)
|
||||
self.conv2 = conv(6, 16, 5)
|
||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||
self.fc2 = fc_with_initialize(120, 84)
|
||||
self.fc3 = fc_with_initialize(84, 10)
|
||||
self.relu = ms.nn.ReLU()
|
||||
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = ms.nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
outputs = dict()
|
||||
x = self.conv1(x)
|
||||
outputs['conv1'] = x
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
outputs['conv2'] = x
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
outputs['fc'] = x
|
||||
return outputs
|
||||
|
||||
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
|
||||
net = LeNet()
|
||||
outputs1 = ops.grad(net)(x)
|
||||
dict_lenet = DictLeNet()
|
||||
outputs2 = ops.grad(dict_lenet)(x)
|
||||
assert np.all(outputs1.asnumpy() == outputs2.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -656,3 +837,39 @@ def test_parser_fallback_nested_class_outer():
|
|||
y = 4
|
||||
net = NestedNet()
|
||||
assert net(x, y) == 12
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parser_fallback_nested_class_outer_grad():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support custom class input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
class Inner:
|
||||
def __init__(self):
|
||||
self.number = ms.Tensor(2, dtype=ms.int32)
|
||||
|
||||
def act(self, x, y):
|
||||
return self.number * (x + y)
|
||||
|
||||
@ms.jit_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.inner = Inner()
|
||||
|
||||
class NestedNet(ms.nn.Cell):
|
||||
@ms.jit
|
||||
def construct(self, x, y):
|
||||
out = InnerNet().inner.act(x, y)
|
||||
return out
|
||||
|
||||
x = 2
|
||||
y = 4
|
||||
net = NestedNet()
|
||||
output = ops.grad(net)(mutable(x), y)
|
||||
assert output == 0
|
||||
|
|
Loading…
Reference in New Issue