!37351 add testcases for tensor index operations

Merge pull request !37351 from zhengzuohe/testcase_tensorindex
This commit is contained in:
i-robot 2022-07-11 02:08:26 +00:00 committed by Gitee
commit 7d5d51ca90
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 703 additions and 313 deletions

View File

@ -14,183 +14,67 @@
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore import Tensor, ops, ParameterTuple
from mindspore.ops.composite import GradOperation
from mindspore.nn import Cell
class NumpyGetItem():
def __init__(self, index1, index2):
super(NumpyGetItem, self).__init__()
self.index1 = index1
self.index2 = index2
class _Grad(Cell):
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
super().__init__()
self.network = network
self.grad = grad
self.sens_param = self.grad.sens_param
self.wrt_params = wrt_params
self.real_inputs_count = real_inputs_count
if self.wrt_params:
self.params = ParameterTuple(self.network.trainable_params())
def __call__(self, tensor1, tensor2):
return tensor1[self.index1], tensor2[self.index2]
def construct(self, *inputs):
if self.real_inputs_count is None or self.sens_param is False:
if self.wrt_params:
return self.grad(self.network, self.params)(*inputs)
return self.grad(self.network)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
if self.wrt_params:
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
return self.grad(self.network)(*real_inputs, sense_param_inputs)
class TensorGetItem(nn.Cell):
def __init__(self, index1, index2):
super(TensorGetItem, self).__init__()
self.index1 = index1
self.index2 = index2
def construct(self, tensor1, tensor2):
return tensor1[self.index1], tensor2[self.index2]
def common_func(ms_net, np_net):
x = Tensor(shape=[8, None, 32], dtype=mindspore.float32)
y = Tensor(shape=[None, 32, 32], dtype=mindspore.float32)
ms_net.set_inputs(x, y)
input_np1 = np.arange(8 * 16 * 32).reshape(8, 16, 32).astype(np.float32)
input_np2 = np.arange(16 * 32 * 32).reshape(16, 32, 32).astype(np.float32)
out0, out1 = ms_net(Tensor(input_np1), Tensor(input_np2))
out_np0, out_np1 = np_net(input_np1, input_np2)
assert np.all(out0.asnumpy() == out_np0)
assert np.all(out1.asnumpy() == out_np1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int_negative():
class GradOfFirstInput(_Grad):
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is negative int.
Expectation: Assert the result is equal the numpy result.
get grad of first input
"""
index1 = -2
index2 = -1
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=GradOperation(sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is int.
Expectation: Assert the result is equal the numpy result.
"""
index1 = 2
index2 = 1
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class CommonFunc():
def __init__(self, ms_net, np_net, input_np):
super(CommonFunc, self).__init__()
self.ms_net = ms_net
self.ms_net.set_grad()
self.np_net = np_net
self.input_np = input_np
self.input_np_bp = input_np
self.out_np = np.array(1).astype(input_np.dtype)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_basic():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is basic tuple.
Expectation: Assert the result is equal the numpy result.
"""
index1 = (1, slice(0, 1, 1), ...)
index2 = (slice(2, None, None), 1, slice(3, 4, None))
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
def forward_cmp(self):
out_ms = self.ms_net(Tensor(self.input_np))
self.out_np = self.np_net(self.input_np)
assert np.all(out_ms.asnumpy() == self.out_np)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_basic_neg():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is basic tuple(int is negative).
Expectation: Assert the result is equal the numpy result.
"""
index1 = (slice(0, 1, 1), ..., -1)
index2 = (-2, slice(2, None, None), slice(3, 4, None))
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is tuple.
Expectation: Assert the result is equal the numpy result.
"""
tensor_index = Tensor(np.array([[1, 2, 1], [0, 3, 2]]), mindspore.int32)
index1 = (slice(2, None, None), (0, 2, 1), tensor_index)
index2 = (-1, slice(0, 1, None), tensor_index)
ms_net = TensorGetItem(index1, index2)
index3 = (slice(2, None, None), (0, 2, 1), tensor_index.asnumpy())
index4 = (-1, slice(0, 1, None), tensor_index.asnumpy())
np_net = NumpyGetItem(index3, index4)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_bool():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is bool.
Expectation: Assert the result is equal the numpy result.
"""
index = True
ms_net = TensorGetItem(index, index)
np_net = NumpyGetItem(index, index)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_none():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is none.
Expectation: Assert the result is equal the numpy result.
"""
index = None
ms_net = TensorGetItem(index, index)
np_net = NumpyGetItem(index, index)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
def grad_impl(self):
grad_net = GradOfFirstInput(self.ms_net)
grad_net.set_train()
grad_net(Tensor(self.input_np_bp), Tensor(self.out_np))
@pytest.mark.level0
@ -203,13 +87,191 @@ def test_dynamic_getitem_ellipsis():
Description: The input shape is dynamic and the tensor index is ellipsis.
Expectation: Assert the result is equal the numpy result.
"""
index = ...
ms_net = TensorGetItem(index, index)
np_net = NumpyGetItem(index, index)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class Net(Cell):
def construct(self, x):
x = x[...]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[...]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None,), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(4).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_bool():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is bool.
Expectation: Assert the result is equal the numpy result.
"""
class Net(Cell):
def construct(self, x):
x = x[True]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[True]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 3), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(2, 3).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_none():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is None.
Expectation: Assert the result is equal the numpy result.
"""
class Net(Cell):
def construct(self, x):
x = x[None]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[None]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 3), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(2, 3).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tensor():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is tensor of int.
Expectation: Assert the result is equal the numpy result.
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.index = Tensor([0, 1])
def construct(self, x):
index = self.index
x = x[index]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[[0, 1]]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(3, 4).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tensor_001():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is dynamic shape tensor.
Expectation: Assert the result is equal the numpy result.
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.unique = ops.Unique()
self.index = Tensor([1, 1, 1, 2])
def construct(self, x):
index = self.unique(self.index)[0]
x = x[index]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
index = np.unique(np.array([1, 1, 1, 2]))
x = x[index]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 3), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(3, 3).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -222,54 +284,184 @@ def test_dynamic_getitem_slice():
Description: The input shape is dynamic and the tensor index is slice.
Expectation: Assert the result is equal the numpy result.
"""
index1 = slice(1, 5, 1)
index2 = slice(1, None, None)
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class Net(Cell):
def construct(self, x):
x = x[2:4]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[2:4]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(6, 4).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_slice_neg():
def test_dynamic_getitem_slice_001():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is negative slice.
Description: The input shape is dynamic and the tensor index is slice with negative int.
Expectation: Assert the result is equal the numpy result.
"""
index1 = slice(-3, -1, 1)
index2 = slice(-1, None, None)
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class Net(Cell):
def construct(self, x):
x = x[-3:-1]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[-3:-1]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(6, 4).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tensor():
def test_dynamic_getitem_int():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is tensor.
Description: The input shape is dynamic and the tensor index is int.
Expectation: Assert the result is equal the numpy result.
"""
index1 = Tensor(np.array([[1, 2], [0, 3]]), mindspore.int32)
index2 = Tensor(np.array([[1, 2]]), mindspore.int32)
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1.asnumpy(), index2.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class Net(Cell):
def construct(self, x):
x = x[-3]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[-3]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(3, 4).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int_001():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is int with control flow.
Expectation: Assert the result is equal the numpy result.
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.extra = 0
def construct(self, x):
index = 1 if self.extra > 1 else 2
x = x[index]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[2]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 2), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(3, 2).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int_002():
"""
Feature: Test Tensor slice for twice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is int.
Expectation: Assert the result is equal the numpy result.
"""
class Net(Cell):
def construct(self, x):
x = x[3][4]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[3][4]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, None, 3), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(5, 5, 3).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -279,54 +471,157 @@ def test_dynamic_getitem_tensor():
def test_dynamic_getitem_list():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is list.
Description: The input shape is dynamic and the tensor index is list of bool and int.
Expectation: Assert the result is equal the numpy result.
"""
index1 = [True, 2, True]
index2 = [1, 2, 0]
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class Net(Cell):
def construct(self, x):
index = [False, 1]
x = x[index]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
index = [False, 1]
x = x[index]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None,), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(5).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_slice_startoversize():
def test_dynamic_getitem_tuple():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is slice and start is over size.
Description: The input shape is dynamic and the tensor index is tuple of tensor and slice.
Expectation: Assert the result is equal the numpy result.
"""
index1 = slice(8, None, 1)
index2 = slice(30, None, None)
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class Net(Cell):
def __init__(self):
super().__init__()
self.extra = Tensor(0)
self.extra2 = Tensor(2)
def construct(self, x):
x = x[self.extra, self.extra:self.extra2, ...]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[0, 0:2, ...]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(2, None, 3), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(2, 4, 3).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_1():
def test_dynamic_getitem_tuple_001():
"""
Feature: Test Tensor slice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is advanced tuple.
Description: The input shape is dynamic and the tensor index is tuple of advanced indices.
Expectation: Assert the result is equal the numpy result.
"""
index1 = (2, True, 0)
index2 = (-2, True, slice(0, 2, None))
ms_net = TensorGetItem(index1, index2)
np_net = NumpyGetItem(index1, index2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
class Net(Cell):
def construct(self, x):
index = (..., True, 4, slice(0, 2), None)
x = x[index]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
index = (..., True, 4, slice(0, 2), None)
x = x[index]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(3, 4, None, 2), dtype=mstype.float32)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(3, 4, 5, 2).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_002():
"""
Feature: Test Tensor slice for twice for dynamic shape in feed mode.
Description: The input shape is dynamic and the tensor index is tuple of advanced indices.
Expectation: Assert the result is equal the numpy result.
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.extra = Tensor([2, 3])
def construct(self, x):
x = x[True, [1, 2]][..., self.extra]
return x
class NumpyNet():
@classmethod
def __call__(cls, x):
x = x[True, [1, 2]][..., [2, 3]]
return x
net_ms = Net()
dynamic_input = Tensor(shape=(None, 4, 5, 2, None),
dtype=mstype.float32) # (1,2,4,5,2,None)
net_ms.set_inputs(dynamic_input)
net_np = NumpyNet()
input_np = np.random.randn(3, 4, 5, 2, 4).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(net_ms, net_np, input_np)
fact.forward_cmp()
fact.grad_impl()

View File

@ -15,10 +15,81 @@
import numpy as np
import pytest
from mindspore import context, nn
from mindspore import Tensor
from mindspore import Tensor, ParameterTuple
from mindspore.ops.composite import GradOperation
from mindspore.nn import Cell
import mindspore.common.dtype as mstype
class _Grad(Cell):
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
super().__init__()
self.network = network
self.grad = grad
self.sens_param = self.grad.sens_param
self.wrt_params = wrt_params
self.real_inputs_count = real_inputs_count
if self.wrt_params:
self.params = ParameterTuple(self.network.trainable_params())
def construct(self, *inputs):
if self.real_inputs_count is None or self.sens_param is False:
if self.wrt_params:
return self.grad(self.network, self.params)(*inputs)
return self.grad(self.network)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
if self.wrt_params:
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
return self.grad(self.network)(*real_inputs, sense_param_inputs)
class GradOfAllInputs(_Grad):
"""
get grad of all inputs
"""
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
class CommonFunc():
def __init__(self, ms_net, np_net):
super(CommonFunc, self).__init__()
self.ms_net = ms_net
self.ms_net.set_grad()
self.np_net = np_net
input_dyn0 = Tensor(shape=[8, None, 3], dtype=mstype.float32)
input_dyn1 = Tensor(shape=[None, 32, 3], dtype=mstype.float32)
ms_net.set_inputs(input_dyn0, input_dyn1)
self.input_np0 = np.arange(
8 * 16 * 3).reshape(8, 16, 3).astype(np.float32)
self.input_np1 = np.arange(
16 * 32 * 3).reshape(16, 32, 3).astype(np.float32)
self.input_np0_bp = self.input_np0.copy()
self.input_np1_bp = self.input_np1.copy()
self.out_np0 = np.array(1).astype(self.input_np0.dtype)
self.out_np1 = np.array(1).astype(self.input_np1.dtype)
def forward_cmp(self):
out_ms0, out_ms1 = self.ms_net(
Tensor(self.input_np0), Tensor(self.input_np1))
self.out_np0, self.out_np1 = self. np_net(
self.input_np0, self.input_np1)
assert np.all(out_ms0.asnumpy() == self.out_np0)
assert np.all(out_ms1.asnumpy() == self.out_np1)
def grad_impl(self):
grad_net = GradOfAllInputs(self.ms_net)
grad_net.set_train()
grad_net(Tensor(self.input_np0_bp), Tensor(self.input_np1_bp),
(Tensor(self.out_np0), Tensor(self.out_np1)))
class NumpySetItem():
def __init__(self, index, value):
super(NumpySetItem, self).__init__()
@ -43,18 +114,6 @@ class TensorSetItem(nn.Cell):
return tensor1, tensor2
def common_func(ms_net, np_net):
x = Tensor(shape=[8, None, 3], dtype=mstype.float32)
y = Tensor(shape=[None, 32, 3], dtype=mstype.float32)
ms_net.set_inputs(x, y)
input_np1 = np.arange(8 * 16 * 3).reshape(8, 16, 3).astype(np.float32)
input_np2 = np.arange(16 * 32 * 3).reshape(16, 32, 3).astype(np.float32)
out0, out1 = ms_net(Tensor(input_np1), Tensor(input_np2))
out_np0, out_np1 = np_net(input_np1, input_np2)
assert np.all(out0.asnumpy() == out_np0)
assert np.all(out1.asnumpy() == out_np1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -69,10 +128,14 @@ def test_dynamic_setitem_int_number():
value = 88.0
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -90,10 +153,14 @@ def test_dynamic_setitem_int_tensor():
(1 * 3)).astype(np.float32), mstype.float32)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -110,10 +177,14 @@ def test_dynamic_setitem_int_sequence():
value = (1.0, Tensor(5, mstype.float32), 8.0)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -131,10 +202,14 @@ def test_dynamic_setitem_tensor_number():
value = 88.0
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index.asnumpy(), value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -153,10 +228,14 @@ def test_dynamic_setitem_tensor_tensor():
(1 * 3)).astype(np.float32), mstype.float32)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index.asnumpy(), value.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -174,10 +253,14 @@ def test_dynamic_setitem_tensor_sequence():
value = (1.0, Tensor(5, mstype.float32), 8.0)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index.asnumpy(), value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -194,10 +277,10 @@ def test_dynamic_setitem_none_number():
value = 88.0
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -215,10 +298,10 @@ def test_dynamic_setitem_none_tensor():
(1 * 3)).astype(np.float32), mstype.float32)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -235,10 +318,10 @@ def test_dynamic_setitem_none_sequence():
value = (1.0, Tensor(5, mstype.float32), 8.0)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -255,10 +338,10 @@ def test_dynamic_setitem_ellipsis_number():
value = 88.0
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -276,10 +359,10 @@ def test_dynamic_setitem_ellipsis_tensor():
(1 * 3)).astype(np.float32), mstype.float32)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -296,10 +379,10 @@ def test_dynamic_setitem_ellipsis_sequence():
value = (1.0, Tensor(5, mstype.float32), 8.0)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -316,10 +399,10 @@ def test_dynamic_setitem_bool_number():
value = 88.0
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -337,10 +420,10 @@ def test_dynamic_setitem_bool_tensor():
(1 * 3)).astype(np.float32), mstype.float32)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -357,10 +440,10 @@ def test_dynamic_setitem_bool_sequence():
value = (1.0, Tensor(5, mstype.float32), 8.0)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -377,10 +460,14 @@ def test_dynamic_setitem_list_number():
value = 88.0
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -398,10 +485,14 @@ def test_dynamic_setitem_list_tensor():
(1 * 3)).astype(np.float32), mstype.float32)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
@pytest.mark.level0
@ -418,7 +509,11 @@ def test_dynamic_setitem_list_sequence():
value = (1.0, Tensor(5, mstype.float32), 8.0)
ms_net = TensorSetItem(index, value)
np_net = NumpySetItem(index, value)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
common_func(ms_net, np_net)
context.set_context(mode=context.PYNATIVE_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()
context.set_context(mode=context.GRAPH_MODE)
fact = CommonFunc(ms_net, np_net)
fact.forward_cmp()
fact.grad_impl()