forked from mindspore-Ecosystem/mindspore
MaskedSelectedGradCpuKernel support more data type
This commit is contained in:
parent
7102e94755
commit
63a77ba345
|
@ -64,8 +64,9 @@ int MaskedSelectGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
input_shape_b_ = inputs[kIndexMask]->GetShapeVector();
|
||||
grad_shape_ = inputs[kIndexGrad]->GetShapeVector();
|
||||
output_shape_ = CPUKernelUtils::GetBroadcastShape(input_shape_a_, input_shape_b_);
|
||||
if (KernelMod::Resize(base_operator, inputs, outputs) != KRET_OK) {
|
||||
MS_LOG(EXCEPTION) << "MaskedSelectGradCpuKernelMod resize failed.";
|
||||
const auto ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
tensor_size_ = 1;
|
||||
tensor_size_ =
|
||||
|
@ -114,41 +115,89 @@ bool MaskedSelectGradCpuKernelMod::LaunchKernel(const std::vector<kernel::Addres
|
|||
|
||||
std::vector<std::pair<KernelAttr, MaskedSelectGradCpuKernelMod::MaskedSelectGradFunc>>
|
||||
MaskedSelectGradCpuKernelMod::func_list_ = {{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>}};
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<uint64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<complex64>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
&MaskedSelectGradCpuKernelMod::LaunchKernel<complex128>}};
|
||||
|
||||
std::vector<KernelAttr> MaskedSelectGradCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
@ -140,19 +141,59 @@ class Net(nn.Cell):
|
|||
return self.op(x, mask)
|
||||
|
||||
|
||||
def masked_select_grad():
|
||||
x = np.array([1, 2, 3, 4]).astype(np.int32)
|
||||
def masked_select_grad(data_type):
|
||||
x = np.array([1, 2, 3, 4]).astype(data_type)
|
||||
mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
|
||||
dy = np.array([i for i in range(8)]).astype(np.int32)
|
||||
dy = np.array([i for i in range(8)]).astype(data_type)
|
||||
grad = Grad(Net())
|
||||
return grad(Tensor(x), Tensor(mask), Tensor(dy))[0]
|
||||
|
||||
|
||||
def masked_select_grad_dynamic_shape():
|
||||
x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
|
||||
mask = Tensor(np.array([[0], [1], [0], [1]]).astype(np.bool))
|
||||
dy = Tensor(np.array([i for i in range(8)]).astype(np.int32))
|
||||
x_dynamic_shape = Tensor(shape=[None], dtype=mindspore.int32)
|
||||
grad = Grad(Net())
|
||||
grad.set_inputs(x_dynamic_shape, mask, dy)
|
||||
return grad(x, mask, dy)[0]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_masked_select_grad():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
dx = masked_select_grad()
|
||||
dx = masked_select_grad(np.int32)
|
||||
expect = [4, 6, 8, 10]
|
||||
assert (dx.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_masked_select_grad_float64():
|
||||
"""
|
||||
Feature: test MaskedSelectGrad complex64 type on CPU
|
||||
Description: the type of input is float64
|
||||
Expectation: the result match with expect
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
dx = masked_select_grad(np.float64)
|
||||
expect = [4, 6, 8, 10]
|
||||
assert (dx.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_masked_select_grad_dynamic_shape():
|
||||
"""
|
||||
Feature: test MaskedSelectGrad dynamic shape on CPU
|
||||
Description: the shape of input is dynamic
|
||||
Expectation: the result match with expect
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
dx = masked_select_grad_dynamic_shape()
|
||||
expect = [4, 6, 8, 10]
|
||||
assert (dx.asnumpy() == expect).all()
|
||||
|
|
Loading…
Reference in New Issue