fixed decorators in cpu tests

This commit is contained in:
huangbo77 2021-07-20 10:24:06 +08:00
parent 4ebb0a6dd8
commit b7ad898ce2
29 changed files with 119 additions and 86 deletions

View File

@ -20,6 +20,7 @@
namespace mindspore {
namespace kernel {
void MirrorPadCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::string mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "mode");
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (mode == "REFLECT") {
@ -47,7 +48,6 @@ void MirrorPadCPUKernel::InitKernel(const CNodePtr &kernel_node) {
tensor_size_ *= input_shape[i];
input_shape_.push_back(SizeToLong(input_shape[i]));
}
std::vector<size_t> padding_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
num_paddings_ = SizeToLong(padding_shape[0]);
@ -59,7 +59,6 @@ void MirrorPadCPUKernel::InitKernel(const CNodePtr &kernel_node) {
int64_t max_width = input_shape_[3];
int64_t max_height = input_shape_[2];
if (mode_ == 1) { // symmetric
max_width = max_width + (2 * max_width);
max_height = max_height + (2 * max_height);
@ -110,7 +109,6 @@ void MirrorPadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con
const int64_t padded_height = output_shape_[dim_offset];
const int64_t padded_width = output_shape_[dim_offset + 1];
const int64_t padd_dim = num_paddings_;
const int64_t mode = mode_;
int64_t paddings[MAX_PADDINGS * PADDING_SIZE]; // local and fixed size to keep in registers

View File

@ -77,6 +77,20 @@ MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL(
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MirrorPadCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_

View File

@ -94,6 +94,21 @@ MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MirrorPadGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_

View File

@ -801,7 +801,7 @@ class Pad(Cell):
if mode == "CONSTANT":
self.pad = P.Pad(self.paddings)
else:
self.paddings = Tensor(np.array(self.paddings))
self.paddings = Tensor(np.array(self.paddings), dtype=mstype.int64)
self.pad = P.MirrorPad(mode=mode)
def construct(self, x):

View File

@ -22,6 +22,9 @@ mirror_pad_op_info = CpuRegOp("MirrorPad") \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()

View File

@ -22,6 +22,9 @@ mirror_pad_grad_op_info = CpuRegOp("MirrorPadGrad") \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()

View File

@ -82,7 +82,7 @@ def test_sub():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_div():
prop = 1 if np.random.random() < 0.5 else -1
@ -175,7 +175,7 @@ def test_div():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_floor_div():
prop = 1 if np.random.random() < 0.5 else -1
@ -240,7 +240,7 @@ def test_floor_div():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mod():
prop = 1 if np.random.random() < 0.5 else -1
@ -334,7 +334,7 @@ def test_mod():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_floor_mod():
prop = 1 if np.random.random() < 0.5 else -1

View File

@ -22,7 +22,7 @@ from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_broadcast():
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -76,7 +76,7 @@ def test_broadcast():
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_broadcast_dyn_init():
"""
@ -105,7 +105,7 @@ def test_broadcast_dyn_init():
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_broadcast_dyn_invalid_init():
"""

View File

@ -55,61 +55,61 @@ def DepthToSpace(nptype, block_size=2, input_shape=(1, 12, 1, 1)):
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_float32():
DepthToSpace(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_float16():
DepthToSpace(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_int32():
DepthToSpace(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_int64():
DepthToSpace(np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_int8():
DepthToSpace(np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_int16():
DepthToSpace(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_uint8():
DepthToSpace(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_uint16():
DepthToSpace(np.uint16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_uint32():
DepthToSpace(np.uint32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_depthtospace_graph_uint64():
DepthToSpace(np.uint64)

View File

@ -34,7 +34,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
x0 = Tensor(np.array([np.log(-1), 0.4, np.log(0)]).astype(np.float16))

View File

@ -39,7 +39,7 @@ x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_nan():
ms_isnan = Netnan()

View File

@ -33,7 +33,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net01():
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

View File

@ -34,7 +34,7 @@ class Net(Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2normalize_float32():
x = np.arange(20*20*20*20).astype(np.float32).reshape(20, 20, 20, 20)
@ -50,7 +50,7 @@ def test_l2normalize_float32():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2normalize_float16():
x = np.arange(96).astype(np.float16).reshape(2, 3, 4, 4)
@ -66,7 +66,7 @@ def test_l2normalize_float16():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2normalize_axis():
axis = -2
@ -83,7 +83,7 @@ def test_l2normalize_axis():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2normalize_epsilon():
axis = -1

View File

@ -30,7 +30,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_fp32():
x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
@ -84,7 +84,7 @@ def test_net_fp32():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_fp16():
x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
@ -138,7 +138,7 @@ def test_net_fp16():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_int32():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32)
@ -156,7 +156,7 @@ def test_net_int32():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_int64():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int64)
@ -174,7 +174,7 @@ def test_net_int64():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_float64():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float64)
@ -192,7 +192,7 @@ def test_net_float64():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_int16():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int16)

View File

@ -30,7 +30,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)

View File

@ -33,7 +33,7 @@ class MaxmumGradNet(Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_grad_random():
np.random.seed(0)
@ -49,7 +49,7 @@ def test_maximum_grad_random():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_broadcast_grad_cpu():
x = np.array([[[[0.659578],

View File

@ -42,7 +42,7 @@ class TwoTensorsMaximum(Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_constScalar_tensor_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
@ -58,7 +58,7 @@ def test_maximum_constScalar_tensor_int():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_two_tensors_Not_Broadcast_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
@ -75,7 +75,7 @@ def test_maximum_two_tensors_Not_Broadcast_int():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_two_tensors_Broadcast_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
@ -92,7 +92,7 @@ def test_maximum_two_tensors_Broadcast_int():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_two_tensors_Broadcast_oneDimension_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
@ -109,7 +109,7 @@ def test_maximum_two_tensors_Broadcast_oneDimension_int():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_two_tensors_notBroadcast_all_oneDimension_int():
x = Tensor(np.array([[2]]).astype(np.int32))
@ -126,7 +126,7 @@ def test_maximum_two_tensors_notBroadcast_all_oneDimension_int():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_two_tensors_notBroadcast_float32():
x = Tensor(np.array([[2.0, 2.0], [-1, 100]]).astype(np.float32))
@ -143,7 +143,7 @@ def test_maximum_two_tensors_notBroadcast_float32():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maximum_two_tensors_notBroadcast_float64():
x = Tensor(np.array([[2.0, 2.0], [-1, 100]]).astype(np.float64))

View File

@ -24,7 +24,7 @@ from mindspore import Tensor
from mindspore.ops.composite import GradOperation
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mirror_pad():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -72,7 +72,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mirror_pad_backprop():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -88,7 +88,7 @@ def test_mirror_pad_backprop():
np.testing.assert_array_almost_equal(dx, expected_dx)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mirror_pad_fwd_back_4d_int32_reflect():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -129,7 +129,7 @@ def test_mirror_pad_fwd_back_4d_int32_reflect():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mirror_pad_fwd_back_4d_int32_symm():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

View File

@ -41,7 +41,7 @@ class NetNorm(nn.Cell):
self.norm_4(indices))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_norm():
norm = NetNorm()

View File

@ -45,7 +45,7 @@ class NetOneHot(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_one_hot():
one_hot = NetOneHot()

View File

@ -32,7 +32,7 @@ class NetOnesLike(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_OnesLike():
x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)

View File

@ -25,7 +25,7 @@ from mindspore.ops.composite import GradOperation
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_basic():
"""
@ -53,7 +53,7 @@ def test_pad_basic():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_row():
"""
@ -84,7 +84,7 @@ def test_pad_row():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_column():
"""
@ -115,7 +115,7 @@ def test_pad_column():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_3d_pad():
"""
@ -173,7 +173,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_3d_backprop():
"""
@ -212,7 +212,7 @@ def test_pad_3d_backprop():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_error_cases():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

View File

@ -32,7 +32,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)

View File

@ -32,7 +32,7 @@ class NetRealDiv(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_real_div():
x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)

View File

@ -82,7 +82,7 @@ def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment
@pytest.mark.level0
@pytest.mark.platform_cpu_training
@pytest.mark.platform_cpu
@pytest.mark.env_onecard
def test_rmsprop():
learning_rate, decay, momentum, epsilon, centered = [0.5, 0.8, 0.9, 1e-3, True]
@ -143,7 +143,7 @@ def test_rmsprop():
@pytest.mark.level0
@pytest.mark.platform_cpu_training
@pytest.mark.platform_cpu
@pytest.mark.env_onecard
def test_rmspropcenter():
learning_rate, decay, momentum, epsilon, centered = [0.1, 0.3, 0.9, 1.0, False]

View File

@ -49,61 +49,61 @@ def SpaceToDepth(nptype):
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_float32():
SpaceToDepth(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_float16():
SpaceToDepth(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_int32():
SpaceToDepth(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_int64():
SpaceToDepth(np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_int8():
SpaceToDepth(np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_int16():
SpaceToDepth(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint8():
SpaceToDepth(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint16():
SpaceToDepth(np.uint16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint32():
SpaceToDepth(np.uint32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint64():
SpaceToDepth(np.uint64)

View File

@ -35,7 +35,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net01():
net = Net()
@ -65,7 +65,7 @@ def test_net01():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net02():
net = Net()

View File

@ -132,84 +132,84 @@ def unpack_pynative(nptype):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_float32():
unpack(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_float16():
unpack(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_int32():
unpack(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_int16():
unpack(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_uint8():
unpack(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_bool():
unpack(np.bool)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_float32():
unpack_pynative(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_float16():
unpack_pynative(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_int32():
unpack_pynative(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_int16():
unpack_pynative(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_uint8():
unpack_pynative(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_bool():
unpack_pynative(np.bool)

View File

@ -32,7 +32,7 @@ class NetZerosLike(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ZerosLike():
x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)