MaskedSelectedGradCpuKernel support more data type

This commit is contained in:
tanghuikang 2022-09-22 09:44:42 +08:00
parent 7102e94755
commit 63a77ba345
2 changed files with 109 additions and 19 deletions

View File

@ -64,8 +64,9 @@ int MaskedSelectGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
input_shape_b_ = inputs[kIndexMask]->GetShapeVector();
grad_shape_ = inputs[kIndexGrad]->GetShapeVector();
output_shape_ = CPUKernelUtils::GetBroadcastShape(input_shape_a_, input_shape_b_);
if (KernelMod::Resize(base_operator, inputs, outputs) != KRET_OK) {
MS_LOG(EXCEPTION) << "MaskedSelectGradCpuKernelMod resize failed.";
const auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
tensor_size_ = 1;
tensor_size_ =
@ -114,41 +115,89 @@ bool MaskedSelectGradCpuKernelMod::LaunchKernel(const std::vector<kernel::Addres
std::vector<std::pair<KernelAttr, MaskedSelectGradCpuKernelMod::MaskedSelectGradFunc>>
MaskedSelectGradCpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&MaskedSelectGradCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MaskedSelectGradCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&MaskedSelectGradCpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&MaskedSelectGradCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&MaskedSelectGradCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&MaskedSelectGradCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&MaskedSelectGradCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MaskedSelectGradCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>}};
&MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
&MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&MaskedSelectGradCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&MaskedSelectGradCpuKernelMod::LaunchKernel<complex128>}};
std::vector<KernelAttr> MaskedSelectGradCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@ -140,19 +141,59 @@ class Net(nn.Cell):
return self.op(x, mask)
def masked_select_grad():
x = np.array([1, 2, 3, 4]).astype(np.int32)
def masked_select_grad(data_type):
x = np.array([1, 2, 3, 4]).astype(data_type)
mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
dy = np.array([i for i in range(8)]).astype(np.int32)
dy = np.array([i for i in range(8)]).astype(data_type)
grad = Grad(Net())
return grad(Tensor(x), Tensor(mask), Tensor(dy))[0]
def masked_select_grad_dynamic_shape():
x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
mask = Tensor(np.array([[0], [1], [0], [1]]).astype(np.bool))
dy = Tensor(np.array([i for i in range(8)]).astype(np.int32))
x_dynamic_shape = Tensor(shape=[None], dtype=mindspore.int32)
grad = Grad(Net())
grad.set_inputs(x_dynamic_shape, mask, dy)
return grad(x, mask, dy)[0]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_masked_select_grad():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
dx = masked_select_grad()
dx = masked_select_grad(np.int32)
expect = [4, 6, 8, 10]
assert (dx.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_masked_select_grad_float64():
"""
Feature: test MaskedSelectGrad complex64 type on CPU
Description: the type of input is float64
Expectation: the result match with expect
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
dx = masked_select_grad(np.float64)
expect = [4, 6, 8, 10]
assert (dx.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_masked_select_grad_dynamic_shape():
"""
Feature: test MaskedSelectGrad dynamic shape on CPU
Description: the shape of input is dynamic
Expectation: the result match with expect
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
dx = masked_select_grad_dynamic_shape()
expect = [4, 6, 8, 10]
assert (dx.asnumpy() == expect).all()