!37351 add testcases for tensor index operations
Merge pull request !37351 from zhengzuohe/testcase_tensorindex
Commit: 7d5d51ca90
@@ -14,183 +14,67 @@
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, ops, ParameterTuple
from mindspore.ops.composite import GradOperation
from mindspore.nn import Cell


class NumpyGetItem():
    def __init__(self, index1, index2):
        super(NumpyGetItem, self).__init__()
        self.index1 = index1
        self.index2 = index2

    def __call__(self, tensor1, tensor2):
        return tensor1[self.index1], tensor2[self.index2]


class _Grad(Cell):
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.real_inputs_count is None or self.sens_param is False:
            if self.wrt_params:
                return self.grad(self.network, self.params)(*inputs)
            return self.grad(self.network)(*inputs)

        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        if self.wrt_params:
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        return self.grad(self.network)(*real_inputs, sense_param_inputs)
class TensorGetItem(nn.Cell):
    def __init__(self, index1, index2):
        super(TensorGetItem, self).__init__()
        self.index1 = index1
        self.index2 = index2

    def construct(self, tensor1, tensor2):
        return tensor1[self.index1], tensor2[self.index2]


def common_func(ms_net, np_net):
    x = Tensor(shape=[8, None, 32], dtype=mindspore.float32)
    y = Tensor(shape=[None, 32, 32], dtype=mindspore.float32)
    ms_net.set_inputs(x, y)
    input_np1 = np.arange(8 * 16 * 32).reshape(8, 16, 32).astype(np.float32)
    input_np2 = np.arange(16 * 32 * 32).reshape(16, 32, 32).astype(np.float32)
    out0, out1 = ms_net(Tensor(input_np1), Tensor(input_np2))
    out_np0, out_np1 = np_net(input_np1, input_np2)
    assert np.all(out0.asnumpy() == out_np0)
    assert np.all(out1.asnumpy() == out_np1)
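common_func drives the dynamic-shape "feed" pattern these tests rely on: a placeholder Tensor whose shape contains None marks the axes that stay unknown until real data arrives, and set_inputs registers those placeholders before the net is called. A minimal standalone sketch of the same pattern (the Identity cell is hypothetical, not from this file):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor, nn

class Identity(nn.Cell):
    def construct(self, x):
        return x[...]

context.set_context(mode=context.GRAPH_MODE)
net = Identity()
# None marks the axis whose length is only known at run time.
net.set_inputs(Tensor(shape=[None, 3], dtype=mindspore.float32))
out = net(Tensor(np.ones((2, 3), np.float32)))  # concrete (2, 3) input is accepted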
class GradOfFirstInput(_Grad):
    """
    get grad of first input
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int_negative():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is negative int.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = -2
    index2 = -1
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
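GradOfFirstInput builds on GradOperation with sens_param=True, so the resulting grad function takes one extra trailing argument: a sensitivity tensor with the shape and dtype of the forward output, used as the initial output gradient. A hedged sketch of that contract (the Square cell is illustrative only, not from this file):

import numpy as np
from mindspore import Tensor, nn
from mindspore.ops.composite import GradOperation

class Square(nn.Cell):
    def construct(self, x):
        return x * x

grad_fn = GradOperation(sens_param=True)(Square())
x = Tensor(np.array([1.0, 2.0], np.float32))
sens = Tensor(np.array([1.0, 1.0], np.float32))  # gradient seeded at the output
print(grad_fn(x, sens))  # d(x*x)/dx = 2x, so roughly [2., 4.]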
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is int.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = 2
    index2 = 1
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
class CommonFunc():
    def __init__(self, ms_net, np_net, input_np):
        super(CommonFunc, self).__init__()
        self.ms_net = ms_net
        self.ms_net.set_grad()
        self.np_net = np_net

        self.input_np = input_np
        self.input_np_bp = input_np
        self.out_np = np.array(1).astype(input_np.dtype)

    def forward_cmp(self):
        out_ms = self.ms_net(Tensor(self.input_np))
        self.out_np = self.np_net(self.input_np)
        assert np.all(out_ms.asnumpy() == self.out_np)

    def grad_impl(self):
        grad_net = GradOfFirstInput(self.ms_net)
        grad_net.set_train()
        grad_net(Tensor(self.input_np_bp), Tensor(self.out_np))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_basic():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is basic tuple.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = (1, slice(0, 1, 1), ...)
    index2 = (slice(2, None, None), 1, slice(3, 4, None))
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_basic_neg():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is basic tuple(int is negative).
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = (slice(0, 1, 1), ..., -1)
    index2 = (-2, slice(2, None, None), slice(3, 4, None))
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is tuple.
    Expectation: Assert the result is equal the numpy result.
    """
    tensor_index = Tensor(np.array([[1, 2, 1], [0, 3, 2]]), mindspore.int32)
    index1 = (slice(2, None, None), (0, 2, 1), tensor_index)
    index2 = (-1, slice(0, 1, None), tensor_index)
    ms_net = TensorGetItem(index1, index2)
    index3 = (slice(2, None, None), (0, 2, 1), tensor_index.asnumpy())
    index4 = (-1, slice(0, 1, None), tensor_index.asnumpy())
    np_net = NumpyGetItem(index3, index4)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_bool():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is bool.
    Expectation: Assert the result is equal the numpy result.
    """
    index = True
    ms_net = TensorGetItem(index, index)
    np_net = NumpyGetItem(index, index)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_none():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is none.
    Expectation: Assert the result is equal the numpy result.
    """
    index = None
    ms_net = TensorGetItem(index, index)
    np_net = NumpyGetItem(index, index)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
@pytest.mark.level0

@@ -203,13 +87,191 @@ def test_dynamic_getitem_ellipsis():
    Description: The input shape is dynamic and the tensor index is ellipsis.
    Expectation: Assert the result is equal the numpy result.
    """
    index = ...
    ms_net = TensorGetItem(index, index)
    np_net = NumpyGetItem(index, index)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    class Net(Cell):
        def construct(self, x):
            x = x[...]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[...]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None,), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(4).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_bool():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is bool.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def construct(self, x):
            x = x[True]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[True]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 3), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(2, 3).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
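The bool and None cases lean on indexing semantics that NumPy and MindSpore share here: a scalar True (like None/np.newaxis) prepends a size-1 axis, while a scalar False produces a size-0 axis. The NumPy half of the comparison can be checked directly:

import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
assert x[True].shape == (1, 2, 3)   # scalar True adds a leading size-1 axis
assert x[None].shape == (1, 2, 3)   # None (np.newaxis) behaves the same
assert x[False].shape == (0, 2, 3)  # scalar False selects nothing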
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_none():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is None.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def construct(self, x):
            x = x[None]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[None]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 3), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(2, 3).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tensor():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is tensor of int.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.index = Tensor([0, 1])

        def construct(self, x):
            index = self.index
            x = x[index]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[[0, 1]]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(3, 4).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tensor_001():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is dynamic shape tensor.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.unique = ops.Unique()
            self.index = Tensor([1, 1, 1, 2])

        def construct(self, x):
            index = self.unique(self.index)[0]
            x = x[index]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            index = np.unique(np.array([1, 1, 1, 2]))
            x = x[index]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 3), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(3, 3).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
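test_dynamic_getitem_tensor_001 makes the index itself dynamic: ops.Unique returns a tensor whose length depends on the data, so the gather that follows cannot know its output shape ahead of time. The NumPy reference behaves the same way:

import numpy as np

index = np.unique(np.array([1, 1, 1, 2]))
assert (index == np.array([1, 2])).all()  # duplicates dropped, result sorted
x = np.arange(9).reshape(3, 3)
assert x[index].shape == (2, 3)           # gathered rows 1 and 2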
@pytest.mark.level0
@@ -222,54 +284,184 @@ def test_dynamic_getitem_slice():
    Description: The input shape is dynamic and the tensor index is slice.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = slice(1, 5, 1)
    index2 = slice(1, None, None)
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    class Net(Cell):
        def construct(self, x):
            x = x[2:4]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[2:4]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(6, 4).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_slice_neg():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is negative slice.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = slice(-3, -1, 1)
    index2 = slice(-1, None, None)
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)


def test_dynamic_getitem_slice_001():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is slice with negative int.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def construct(self, x):
            x = x[-3:-1]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[-3:-1]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(6, 4).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tensor():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is tensor.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = Tensor(np.array([[1, 2], [0, 3]]), mindspore.int32)
    index2 = Tensor(np.array([[1, 2]]), mindspore.int32)
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1.asnumpy(), index2.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)


def test_dynamic_getitem_int():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is int.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def construct(self, x):
            x = x[-3]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[-3]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 4), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(3, 4).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int_001():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is int with control flow.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.extra = 0

        def construct(self, x):
            index = 1 if self.extra > 1 else 2
            x = x[index]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[2]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 2), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(3, 2).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_int_002():
    """
    Feature: Test Tensor slice twice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is int.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def construct(self, x):
            x = x[3][4]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[3][4]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, None, 3), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(5, 5, 3).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
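x[3][4] applies two integer indices in sequence; each one removes the current leading axis, so for pure integer indexing the chained form matches the tuple form x[3, 4]. In NumPy:

import numpy as np

x = np.arange(5 * 5 * 3).reshape(5, 5, 3)
assert (x[3][4] == x[3, 4]).all()  # both yield the length-3 vector at row 3, column 4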
@pytest.mark.level0

@@ -279,54 +471,157 @@ def test_dynamic_getitem_tensor():
def test_dynamic_getitem_list():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is list.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = [True, 2, True]
    index2 = [1, 2, 0]
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)


def test_dynamic_getitem_list():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is list of bool and int.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def construct(self, x):
            index = [False, 1]
            x = x[index]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            index = [False, 1]
            x = x[index]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None,), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(5).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_slice_startoversize():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is slice and start is over size.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = slice(8, None, 1)
    index2 = slice(30, None, None)
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)


def test_dynamic_getitem_tuple():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is tuple of tensor and slice.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.extra = Tensor(0)
            self.extra2 = Tensor(2)

        def construct(self, x):
            x = x[self.extra, self.extra:self.extra2, ...]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[0, 0:2, ...]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(2, None, 3), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(2, 4, 3).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
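The startoversize case depends on a quiet corner of slice semantics: unlike integer indices, slice bounds past the end of an axis are clamped instead of raising, so an empty result comes back. NumPy shows the behaviour the test expects:

import numpy as np

x = np.arange(5)
assert x[8:].shape == (0,)  # start beyond the axis length gives an empty slice
assert x[30:].size == 0     # no IndexError, however large the start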
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_1():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is advanced tuple.
    Expectation: Assert the result is equal the numpy result.
    """
    index1 = (2, True, 0)
    index2 = (-2, True, slice(0, 2, None))
    ms_net = TensorGetItem(index1, index2)
    np_net = NumpyGetItem(index1, index2)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)


def test_dynamic_getitem_tuple_001():
    """
    Feature: Test Tensor slice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is tuple of advanced indices.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def construct(self, x):
            index = (..., True, 4, slice(0, 2), None)
            x = x[index]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            index = (..., True, 4, slice(0, 2), None)
            x = x[index]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(3, 4, None, 2), dtype=mstype.float32)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(3, 4, 5, 2).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_getitem_tuple_002():
    """
    Feature: Test Tensor slice twice for dynamic shape in feed mode.
    Description: The input shape is dynamic and the tensor index is tuple of advanced indices.
    Expectation: Assert the result is equal the numpy result.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.extra = Tensor([2, 3])

        def construct(self, x):
            x = x[True, [1, 2]][..., self.extra]
            return x

    class NumpyNet():
        @classmethod
        def __call__(cls, x):
            x = x[True, [1, 2]][..., [2, 3]]
            return x

    net_ms = Net()
    dynamic_input = Tensor(shape=(None, 4, 5, 2, None),
                           dtype=mstype.float32)  # (1,2,4,5,2,None)
    net_ms.set_inputs(dynamic_input)
    net_np = NumpyNet()
    input_np = np.random.randn(3, 4, 5, 2, 4).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(net_ms, net_np, input_np)
    fact.forward_cmp()
    fact.grad_impl()
@@ -15,10 +15,81 @@
import numpy as np
import pytest
from mindspore import context, nn
from mindspore import Tensor, ParameterTuple
from mindspore.ops.composite import GradOperation
from mindspore.nn import Cell
import mindspore.common.dtype as mstype


class _Grad(Cell):
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.real_inputs_count is None or self.sens_param is False:
            if self.wrt_params:
                return self.grad(self.network, self.params)(*inputs)
            return self.grad(self.network)(*inputs)

        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        if self.wrt_params:
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        return self.grad(self.network)(*real_inputs, sense_param_inputs)


class GradOfAllInputs(_Grad):
    """
    get grad of all inputs
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


class CommonFunc():
    def __init__(self, ms_net, np_net):
        super(CommonFunc, self).__init__()
        self.ms_net = ms_net
        self.ms_net.set_grad()
        self.np_net = np_net

        input_dyn0 = Tensor(shape=[8, None, 3], dtype=mstype.float32)
        input_dyn1 = Tensor(shape=[None, 32, 3], dtype=mstype.float32)
        ms_net.set_inputs(input_dyn0, input_dyn1)

        self.input_np0 = np.arange(
            8 * 16 * 3).reshape(8, 16, 3).astype(np.float32)
        self.input_np1 = np.arange(
            16 * 32 * 3).reshape(16, 32, 3).astype(np.float32)
        self.input_np0_bp = self.input_np0.copy()
        self.input_np1_bp = self.input_np1.copy()
        self.out_np0 = np.array(1).astype(self.input_np0.dtype)
        self.out_np1 = np.array(1).astype(self.input_np1.dtype)

    def forward_cmp(self):
        out_ms0, out_ms1 = self.ms_net(
            Tensor(self.input_np0), Tensor(self.input_np1))
        self.out_np0, self.out_np1 = self.np_net(
            self.input_np0, self.input_np1)
        assert np.all(out_ms0.asnumpy() == self.out_np0)
        assert np.all(out_ms1.asnumpy() == self.out_np1)

    def grad_impl(self):
        grad_net = GradOfAllInputs(self.ms_net)
        grad_net.set_train()
        grad_net(Tensor(self.input_np0_bp), Tensor(self.input_np1_bp),
                 (Tensor(self.out_np0), Tensor(self.out_np1)))


class NumpySetItem():
    def __init__(self, index, value):
        super(NumpySetItem, self).__init__()
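grad_impl above hands the grad net a tuple as its last argument because the forward net returns two tensors: with get_all=True and sens_param=True, the grad function takes the real inputs followed by a sensitivity matching the output structure and returns one gradient per input. A hedged sketch of the same contract (the TwoOut cell is illustrative only, not from this file):

import numpy as np
from mindspore import Tensor, nn
from mindspore.ops.composite import GradOperation

class TwoOut(nn.Cell):
    def construct(self, x, y):
        return x * 2, y * 3

grad_fn = GradOperation(get_all=True, sens_param=True)(TwoOut())
ones = np.ones((2,), np.float32)
sens = (Tensor(ones), Tensor(ones))           # one sensitivity per output
dx, dy = grad_fn(Tensor(ones), Tensor(ones), sens)
print(dx, dy)                                 # expected 2s and 3s respectively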
@@ -43,18 +114,6 @@ class TensorSetItem(nn.Cell):
        return tensor1, tensor2


def common_func(ms_net, np_net):
    x = Tensor(shape=[8, None, 3], dtype=mstype.float32)
    y = Tensor(shape=[None, 32, 3], dtype=mstype.float32)
    ms_net.set_inputs(x, y)
    input_np1 = np.arange(8 * 16 * 3).reshape(8, 16, 3).astype(np.float32)
    input_np2 = np.arange(16 * 32 * 3).reshape(16, 32, 3).astype(np.float32)
    out0, out1 = ms_net(Tensor(input_np1), Tensor(input_np2))
    out_np0, out_np1 = np_net(input_np1, input_np2)
    assert np.all(out0.asnumpy() == out_np0)
    assert np.all(out1.asnumpy() == out_np1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -69,10 +128,14 @@ def test_dynamic_setitem_int_number():
    value = 88.0
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
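Assigning the scalar 88.0 through an index broadcasts it over everything the index selects, and the NumPy reference implementation relies on exactly that rule:

import numpy as np

x = np.zeros((4, 3), np.float32)
x[2] = 88.0                                   # scalar broadcast across row 2
assert (x[2] == 88.0).all() and x[0].sum() == 0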
@pytest.mark.level0
@@ -90,10 +153,14 @@ def test_dynamic_setitem_int_tensor():
                   (1 * 3)).astype(np.float32), mstype.float32)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -110,10 +177,14 @@ def test_dynamic_setitem_int_sequence():
    value = (1.0, Tensor(5, mstype.float32), 8.0)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -131,10 +202,14 @@ def test_dynamic_setitem_tensor_number():
    value = 88.0
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index.asnumpy(), value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -153,10 +228,14 @@ def test_dynamic_setitem_tensor_tensor():
                   (1 * 3)).astype(np.float32), mstype.float32)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index.asnumpy(), value.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -174,10 +253,14 @@ def test_dynamic_setitem_tensor_sequence():
    value = (1.0, Tensor(5, mstype.float32), 8.0)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index.asnumpy(), value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -194,10 +277,10 @@ def test_dynamic_setitem_none_number():
    value = 88.0
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -215,10 +298,10 @@ def test_dynamic_setitem_none_tensor():
                   (1 * 3)).astype(np.float32), mstype.float32)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -235,10 +318,10 @@ def test_dynamic_setitem_none_sequence():
    value = (1.0, Tensor(5, mstype.float32), 8.0)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -255,10 +338,10 @@ def test_dynamic_setitem_ellipsis_number():
    value = 88.0
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -276,10 +359,10 @@ def test_dynamic_setitem_ellipsis_tensor():
                   (1 * 3)).astype(np.float32), mstype.float32)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -296,10 +379,10 @@ def test_dynamic_setitem_ellipsis_sequence():
    value = (1.0, Tensor(5, mstype.float32), 8.0)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -316,10 +399,10 @@ def test_dynamic_setitem_bool_number():
    value = 88.0
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -337,10 +420,10 @@ def test_dynamic_setitem_bool_tensor():
                   (1 * 3)).astype(np.float32), mstype.float32)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -357,10 +440,10 @@ def test_dynamic_setitem_bool_sequence():
    value = (1.0, Tensor(5, mstype.float32), 8.0)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -377,10 +460,14 @@ def test_dynamic_setitem_list_number():
    value = 88.0
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -398,10 +485,14 @@ def test_dynamic_setitem_list_tensor():
                   (1 * 3)).astype(np.float32), mstype.float32)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()


@pytest.mark.level0
@@ -418,7 +509,11 @@ def test_dynamic_setitem_list_sequence():
    value = (1.0, Tensor(5, mstype.float32), 8.0)
    ms_net = TensorSetItem(index, value)
    np_net = NumpySetItem(index, value)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    common_func(ms_net, np_net)
    context.set_context(mode=context.PYNATIVE_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()
    context.set_context(mode=context.GRAPH_MODE)
    fact = CommonFunc(ms_net, np_net)
    fact.forward_cmp()
    fact.grad_impl()