diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/masked_select_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/masked_select_grad_cpu_kernel.cc
index aa7ad60a560..502c71e32e7 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/masked_select_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/masked_select_grad_cpu_kernel.cc
@@ -64,8 +64,9 @@ int MaskedSelectGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
   input_shape_b_ = inputs[kIndexMask]->GetShapeVector();
   grad_shape_ = inputs[kIndexGrad]->GetShapeVector();
   output_shape_ = CPUKernelUtils::GetBroadcastShape(input_shape_a_, input_shape_b_);
-  if (KernelMod::Resize(base_operator, inputs, outputs) != KRET_OK) {
-    MS_LOG(EXCEPTION) << "MaskedSelectGradCpuKernelMod resize failed.";
+  const auto ret = KernelMod::Resize(base_operator, inputs, outputs);
+  if (ret != KRET_OK) {
+    return ret;
   }
   tensor_size_ = 1;
   tensor_size_ =
@@ -114,41 +115,89 @@ bool MaskedSelectGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
 std::vector<std::pair<KernelAttr, MaskedSelectGradCpuKernelMod::MaskedSelectGradFunc>> MaskedSelectGradCpuKernelMod::func_list_ = {{KernelAttr()
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddInputAttr(kNumberTypeBool)
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddOutputAttr(kNumberTypeFloat32),
-   &MaskedSelectGradCpuKernelMod::LaunchKernel<float>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeBool)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeInt32),
-   &MaskedSelectGradCpuKernelMod::LaunchKernel<int32_t>},
-  {KernelAttr()
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeBool)
      .AddInputAttr(kNumberTypeFloat16)
      .AddOutputAttr(kNumberTypeFloat16),
    &MaskedSelectGradCpuKernelMod::LaunchKernel<float16>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<float>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeFloat64)
      .AddInputAttr(kNumberTypeBool)
      .AddInputAttr(kNumberTypeFloat64)
      .AddOutputAttr(kNumberTypeFloat64),
    &MaskedSelectGradCpuKernelMod::LaunchKernel<double>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt8)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeInt8)
+     .AddOutputAttr(kNumberTypeInt8),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<int8_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt16)
      .AddInputAttr(kNumberTypeBool)
      .AddInputAttr(kNumberTypeInt16)
      .AddOutputAttr(kNumberTypeInt16),
    &MaskedSelectGradCpuKernelMod::LaunchKernel<int16_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt32)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeInt32),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<int32_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeBool)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64),
-   &MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>}};
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddOutputAttr(kNumberTypeUInt8),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<uint8_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddOutputAttr(kNumberTypeUInt16),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<uint16_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<uint32_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddOutputAttr(kNumberTypeUInt64),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<uint64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeBool)
+     .AddOutputAttr(kNumberTypeBool),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<bool>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex64)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeComplex64)
+     .AddOutputAttr(kNumberTypeComplex64),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<complex64>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddOutputAttr(kNumberTypeComplex128),
+   &MaskedSelectGradCpuKernelMod::LaunchKernel<complex128>}};
 
 std::vector<KernelAttr> MaskedSelectGradCpuKernelMod::GetOpSupport() {
   std::vector<KernelAttr> support_list;
diff --git a/tests/st/ops/cpu/test_masked_select_op.py b/tests/st/ops/cpu/test_masked_select_op.py
index fcc065953c5..b6277dc89eb 100644
--- a/tests/st/ops/cpu/test_masked_select_op.py
+++ b/tests/st/ops/cpu/test_masked_select_op.py
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
+import mindspore
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -140,19 +141,59 @@ class Net(nn.Cell):
         return self.op(x, mask)
 
 
-def masked_select_grad():
-    x = np.array([1, 2, 3, 4]).astype(np.int32)
+def masked_select_grad(data_type):
+    x = np.array([1, 2, 3, 4]).astype(data_type)
     mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
-    dy = np.array([i for i in range(8)]).astype(np.int32)
+    dy = np.array([i for i in range(8)]).astype(data_type)
     grad = Grad(Net())
     return grad(Tensor(x), Tensor(mask), Tensor(dy))[0]
 
 
+def masked_select_grad_dynamic_shape():
+    x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
+    mask = Tensor(np.array([[0], [1], [0], [1]]).astype(np.bool))
+    dy = Tensor(np.array([i for i in range(8)]).astype(np.int32))
+    x_dynamic_shape = Tensor(shape=[None], dtype=mindspore.int32)
+    grad = Grad(Net())
+    grad.set_inputs(x_dynamic_shape, mask, dy)
+    return grad(x, mask, dy)[0]
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
 def test_masked_select_grad():
     context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
-    dx = masked_select_grad()
+    dx = masked_select_grad(np.int32)
     expect = [4, 6, 8, 10]
     assert (dx.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_masked_select_grad_float64():
+    """
+    Feature: test MaskedSelectGrad float64 type on CPU
+    Description: the type of input is float64
+    Expectation: the result matches the expected values
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    dx = masked_select_grad(np.float64)
+    expect = [4, 6, 8, 10]
+    assert (dx.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_masked_select_grad_dynamic_shape():
+    """
+    Feature: test MaskedSelectGrad dynamic shape on CPU
+    Description: the shape of input is dynamic
+    Expectation: the result matches the expected values
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    dx = masked_select_grad_dynamic_shape()
+    expect = [4, 6, 8, 10]
+    assert (dx.asnumpy() == expect).all()
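To make the new registrations concrete, the following is a minimal, illustrative Python sketch, not part of the patch, that drives the MaskedSelect backward pass on CPU with uint8, one of the dtypes this change adds to func_list_. It mirrors the Grad/Net pattern used in tests/st/ops/cpu/test_masked_select_op.py; MaskedSelectNet and GradNet are placeholder cell names introduced here, and it assumes the forward MaskedSelect CPU kernel also registers uint8.

# Illustrative sketch only -- not part of the patch. MaskedSelectNet and GradNet
# are hypothetical helper cells; uint8 is one of the dtypes newly registered in
# func_list_ by this change.
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class MaskedSelectNet(nn.Cell):
    """Forward network: select the elements of x where the (broadcast) mask is True."""
    def __init__(self):
        super(MaskedSelectNet, self).__init__()
        self.masked_select = ops.MaskedSelect()

    def construct(self, x, mask):
        return self.masked_select(x, mask)


class GradNet(nn.Cell):
    """Backward helper: feeds an explicit sens (dy) and returns grads w.r.t. all inputs."""
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(get_all=True, sens_param=True)

    def construct(self, x, mask, dy):
        return self.grad_op(self.net)(x, mask, dy)


x = Tensor(np.array([1, 2, 3, 4]).astype(np.uint8))
mask = Tensor(np.array([[0], [1], [0], [1]]).astype(np.bool_))
dy = Tensor(np.arange(8).astype(np.uint8))  # one incoming-gradient value per selected element
dx = GradNet(MaskedSelectNet())(x, mask, dy)[0]  # gradient w.r.t. x, reduced back to shape (4,)
print(dx)

Because the mask broadcasts x from shape (4,) to (4, 4), the gradient sums the incoming values back over the broadcast dimension, which is the same reduction the test's expected value [4, 6, 8, 10] checks for int32 and float64.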