!38695 Add support for more data types to the CPU MaskedFill kernel

Merge pull request !38695 from 范吉斌/maskedfill_cpu
This commit is contained in:
i-robot 2022-08-11 03:04:47 +00:00 committed by Gitee
commit 7dc9deaab4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 110 additions and 18 deletions

View File

@ -18,11 +18,14 @@
#include <algorithm>
#include <utility>
#include <functional>
#include <complex>
#include "mindspore/core/ops/masked_fill.h"
namespace mindspore {
namespace kernel {
namespace {
// Number of inputs and outputs of the MaskedFill kernel. Every KernelAttr
// registered in this file lists three inputs and one output.
// NOTE(review): the code that consumes these constants is not visible in this
// hunk; presumably they validate the Launch arguments — confirm.
constexpr size_t kMaskedFillInputsNum{3};
constexpr size_t kMaskedFillOutputsNum{1};

// Short aliases for the complex element types the kernel is instantiated with.
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
}  // namespace
@ -31,6 +34,7 @@ bool MaskedFillCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const st
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
batch_rank_ = base_operator->get_batch_rank();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
@ -56,6 +60,11 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
(void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
need_broadcast_ = (input_shape_ == mask_shape_) ? false : true;
size_t batch_size = value_shape.size();
if (LongToSize(batch_rank_) != batch_size) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value shape size should equal to " << batch_rank_
<< ", but got " << batch_size;
return KRET_RESIZE_FAILED;
}
if (input_shape.size() < batch_size || mask_shape.size() < batch_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of input and mask should not be less than value's, but got input: "
@ -70,9 +79,23 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
}
output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies<size_t>());
size_t rank_size =
value_size_ =
LongToSize(std::accumulate(value_shape.begin(), value_shape.end(), int64_t(1), std::multiplies<int64_t>()));
inner_size_ = output_size_ / rank_size;
inner_size_ = output_size_ / value_size_;
mask_index_.clear();
input_index_.clear();
mask_index_.resize(output_size_);
input_index_.resize(output_size_);
if (need_broadcast_) {
BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
base_iter.SetPos(0);
for (size_t i = 0; i < output_size_; i++) {
mask_index_[i] = base_iter.GetInputPosB();
input_index_[i] = base_iter.GetInputPosA();
base_iter.GenNextPos();
}
}
return ret;
}
@ -88,26 +111,31 @@ bool MaskedFillCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
auto output = reinterpret_cast<T *>(outputs[0]->addr);
if (need_broadcast_) {
BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
auto task = [this, &base_iter, input, mask, output, value](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
auto task = [this, input, mask, output, value](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = mask[iter.GetInputPosB()] ? value[i / inner_size_] : input[iter.GetInputPosA()];
iter.GenNextPos();
output[i] = mask[mask_index_[i]] ? value[i / inner_size_] : input[input_index_[i]];
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return true;
}
auto task = [this, input, mask, output, value](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = mask[i] ? value[i / inner_size_] : input[i];
}
};
if (value_size_ == 1) {
auto task = [this, input, mask, output, value](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = mask[i] ? value[0] : input[i];
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
} else {
auto task = [this, input, mask, output, value](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = mask[i] ? value[i / inner_size_] : input[i];
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
}
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return true;
}
@ -124,18 +152,78 @@ std::vector<std::pair<KernelAttr, MaskedFillCpuKernelMod::MaskedFillFunc>> Maske
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&MaskedFillCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&MaskedFillCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&MaskedFillCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&MaskedFillCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MaskedFillCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MaskedFillCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&MaskedFillCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&MaskedFillCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
&MaskedFillCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
&MaskedFillCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
&MaskedFillCpuKernelMod::LaunchKernel<bool>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&MaskedFillCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&MaskedFillCpuKernelMod::LaunchKernel<complex128>},
};
std::vector<KernelAttr> MaskedFillCpuKernelMod::GetOpSupport() {

View File

@ -51,9 +51,13 @@ class MaskedFillCpuKernelMod : public NativeCpuKernelMod {
MaskedFillFunc kernel_func_;
size_t output_size_{1};
size_t inner_size_{1};
size_t value_size_{1};
int64_t batch_rank_{0};
std::vector<int64_t> input_shape_;
std::vector<int64_t> mask_shape_;
std::vector<int64_t> output_shape_;
std::vector<size_t> mask_index_;
std::vector<size_t> input_index_;
bool need_broadcast_{false};
};
} // namespace kernel

View File

@ -76,12 +76,12 @@ TypePtr MaskedFillInferType(const PrimitivePtr &prim, const std::vector<Abstract
std::set<TypePtr> valid_types;
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
if (is_gpu) {
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
if (is_ascend) {
valid_types = {kFloat16, kFloat32, kInt8, kInt32};
} else {
valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64,
kFloat16, kFloat32, kFloat64, kInt, kUInt, kFloat, kComplex64, kComplex128};
} else {
valid_types = {kFloat16, kFloat32, kInt8, kInt32};
}
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
std::map<std::string, TypePtr> types;