!38695 add cpu maskedfill support type
Merge pull request !38695 from 范吉斌/maskedfill_cpu
This commit is contained in:
commit
7dc9deaab4
|
@ -18,11 +18,14 @@
|
|||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include <complex>
|
||||
#include "mindspore/core/ops/masked_fill.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
using complex64 = std::complex<float>;
|
||||
using complex128 = std::complex<double>;
|
||||
constexpr size_t kMaskedFillInputsNum = 3;
|
||||
constexpr size_t kMaskedFillOutputsNum = 1;
|
||||
} // namespace
|
||||
|
@ -31,6 +34,7 @@ bool MaskedFillCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const st
|
|||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
batch_rank_ = base_operator->get_batch_rank();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
|
@ -56,6 +60,11 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
|
|||
(void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
|
||||
need_broadcast_ = (input_shape_ == mask_shape_) ? false : true;
|
||||
size_t batch_size = value_shape.size();
|
||||
if (LongToSize(batch_rank_) != batch_size) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value shape size should equal to " << batch_rank_
|
||||
<< ", but got " << batch_size;
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (input_shape.size() < batch_size || mask_shape.size() < batch_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dimension of input and mask should not be less than value's, but got input: "
|
||||
|
@ -70,9 +79,23 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
|
|||
}
|
||||
|
||||
output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies<size_t>());
|
||||
size_t rank_size =
|
||||
value_size_ =
|
||||
LongToSize(std::accumulate(value_shape.begin(), value_shape.end(), int64_t(1), std::multiplies<int64_t>()));
|
||||
inner_size_ = output_size_ / rank_size;
|
||||
inner_size_ = output_size_ / value_size_;
|
||||
mask_index_.clear();
|
||||
input_index_.clear();
|
||||
mask_index_.resize(output_size_);
|
||||
input_index_.resize(output_size_);
|
||||
if (need_broadcast_) {
|
||||
BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
|
||||
base_iter.SetPos(0);
|
||||
for (size_t i = 0; i < output_size_; i++) {
|
||||
mask_index_[i] = base_iter.GetInputPosB();
|
||||
input_index_[i] = base_iter.GetInputPosA();
|
||||
base_iter.GenNextPos();
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -88,26 +111,31 @@ bool MaskedFillCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
if (need_broadcast_) {
|
||||
BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
|
||||
auto task = [this, &base_iter, input, mask, output, value](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
auto task = [this, input, mask, output, value](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
output[i] = mask[iter.GetInputPosB()] ? value[i / inner_size_] : input[iter.GetInputPosA()];
|
||||
iter.GenNextPos();
|
||||
output[i] = mask[mask_index_[i]] ? value[i / inner_size_] : input[input_index_[i]];
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto task = [this, input, mask, output, value](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
output[i] = mask[i] ? value[i / inner_size_] : input[i];
|
||||
}
|
||||
};
|
||||
if (value_size_ == 1) {
|
||||
auto task = [this, input, mask, output, value](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
output[i] = mask[i] ? value[0] : input[i];
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
||||
} else {
|
||||
auto task = [this, input, mask, output, value](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
output[i] = mask[i] ? value[i / inner_size_] : input[i];
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
||||
}
|
||||
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -124,18 +152,78 @@ std::vector<std::pair<KernelAttr, MaskedFillCpuKernelMod::MaskedFillFunc>> Maske
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<uint16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<uint32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<uint64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<bool>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<complex64>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
&MaskedFillCpuKernelMod::LaunchKernel<complex128>},
|
||||
};
|
||||
|
||||
std::vector<KernelAttr> MaskedFillCpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -51,9 +51,13 @@ class MaskedFillCpuKernelMod : public NativeCpuKernelMod {
|
|||
MaskedFillFunc kernel_func_;
|
||||
size_t output_size_{1};
|
||||
size_t inner_size_{1};
|
||||
size_t value_size_{1};
|
||||
int64_t batch_rank_{0};
|
||||
std::vector<int64_t> input_shape_;
|
||||
std::vector<int64_t> mask_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
std::vector<size_t> mask_index_;
|
||||
std::vector<size_t> input_index_;
|
||||
bool need_broadcast_{false};
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -76,12 +76,12 @@ TypePtr MaskedFillInferType(const PrimitivePtr &prim, const std::vector<Abstract
|
|||
std::set<TypePtr> valid_types;
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||
if (is_gpu) {
|
||||
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
|
||||
if (is_ascend) {
|
||||
valid_types = {kFloat16, kFloat32, kInt8, kInt32};
|
||||
} else {
|
||||
valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64,
|
||||
kFloat16, kFloat32, kFloat64, kInt, kUInt, kFloat, kComplex64, kComplex128};
|
||||
} else {
|
||||
valid_types = {kFloat16, kFloat32, kInt8, kInt32};
|
||||
}
|
||||
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
|
|
Loading…
Reference in New Issue