!38695 add cpu maskedfill support type
Merge pull request !38695 from 范吉斌/maskedfill_cpu
commit 7dc9deaab4
@@ -18,11 +18,14 @@
 #include <algorithm>
 #include <utility>
 #include <functional>
+#include <complex>
 #include "mindspore/core/ops/masked_fill.h"
 
 namespace mindspore {
 namespace kernel {
 namespace {
+using complex64 = std::complex<float>;
+using complex128 = std::complex<double>;
 constexpr size_t kMaskedFillInputsNum = 3;
 constexpr size_t kMaskedFillOutputsNum = 1;
 }  // namespace
@@ -31,6 +34,7 @@ bool MaskedFillCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const st
                                   const std::vector<KernelTensorPtr> &outputs) {
   MS_EXCEPTION_IF_NULL(base_operator);
   kernel_name_ = base_operator->name();
+  batch_rank_ = base_operator->get_batch_rank();
   auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
   auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
   if (!is_match) {
@@ -56,6 +60,11 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
   (void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
   need_broadcast_ = (input_shape_ == mask_shape_) ? false : true;
   size_t batch_size = value_shape.size();
+  if (LongToSize(batch_rank_) != batch_size) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value shape size should equal to " << batch_rank_
+                  << ", but got " << batch_size;
+    return KRET_RESIZE_FAILED;
+  }
   if (input_shape.size() < batch_size || mask_shape.size() < batch_size) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_
                       << "', the dimension of input and mask should not be less than value's, but got input: "
@@ -70,9 +79,23 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
   }
 
   output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies<size_t>());
-  size_t rank_size =
+  value_size_ =
     LongToSize(std::accumulate(value_shape.begin(), value_shape.end(), int64_t(1), std::multiplies<int64_t>()));
-  inner_size_ = output_size_ / rank_size;
+  inner_size_ = output_size_ / value_size_;
+  mask_index_.clear();
+  input_index_.clear();
+  mask_index_.resize(output_size_);
+  input_index_.resize(output_size_);
+  if (need_broadcast_) {
+    BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
+    base_iter.SetPos(0);
+    for (size_t i = 0; i < output_size_; i++) {
+      mask_index_[i] = base_iter.GetInputPosB();
+      input_index_[i] = base_iter.GetInputPosA();
+      base_iter.GenNextPos();
+    }
+  }
+
   return ret;
 }
 
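The Resize change above precomputes, for every flat output position, which mask element and which input element it should read, so the broadcast bookkeeping is done once instead of per launch. As a rough standalone illustration of what such an index table contains (a stride-based sketch under my own assumptions, not MindSpore's BroadcastIterator), one could build the table for a single broadcast operand like this:

#include <cstddef>
#include <vector>

// Sketch only: build, for one broadcast operand, a table that maps every flat
// output index to the flat index it reads in that operand. Assumes in_shape is
// already padded to the same rank as out_shape (broadcast dims have size 1).
std::vector<size_t> BroadcastIndexTable(const std::vector<size_t> &in_shape,
                                        const std::vector<size_t> &out_shape) {
  const size_t rank = out_shape.size();
  std::vector<size_t> in_stride(rank, 0);
  size_t stride = 1;
  for (size_t d = rank; d-- > 0;) {
    in_stride[d] = (in_shape[d] == 1) ? 0 : stride;  // broadcast dim always reads index 0
    stride *= in_shape[d];
  }
  size_t out_size = 1;
  for (size_t extent : out_shape) out_size *= extent;
  std::vector<size_t> table(out_size);
  std::vector<size_t> coord(rank, 0);
  for (size_t i = 0; i < out_size; ++i) {
    size_t pos = 0;
    for (size_t d = 0; d < rank; ++d) pos += coord[d] * in_stride[d];
    table[i] = pos;
    for (size_t d = rank; d-- > 0;) {  // advance the row-major coordinate
      if (++coord[d] < out_shape[d]) break;
      coord[d] = 0;
    }
  }
  return table;
}

Resize fills two such tables (mask_index_ and input_index_) via BroadcastIterator, so the hot loop in LaunchKernel reduces to plain table lookups per element.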
@@ -88,26 +111,31 @@ bool MaskedFillCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
   auto output = reinterpret_cast<T *>(outputs[0]->addr);
 
   if (need_broadcast_) {
-    BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
-    auto task = [this, &base_iter, input, mask, output, value](size_t start, size_t end) {
-      auto iter = base_iter;
-      iter.SetPos(start);
+    auto task = [this, input, mask, output, value](size_t start, size_t end) {
       for (size_t i = start; i < end; i++) {
-        output[i] = mask[iter.GetInputPosB()] ? value[i / inner_size_] : input[iter.GetInputPosA()];
-        iter.GenNextPos();
+        output[i] = mask[mask_index_[i]] ? value[i / inner_size_] : input[input_index_[i]];
       }
     };
     ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
     return true;
   }
 
-  auto task = [this, input, mask, output, value](size_t start, size_t end) {
-    for (size_t i = start; i < end; i++) {
-      output[i] = mask[i] ? value[i / inner_size_] : input[i];
-    }
-  };
+  if (value_size_ == 1) {
+    auto task = [this, input, mask, output, value](size_t start, size_t end) {
+      for (size_t i = start; i < end; i++) {
+        output[i] = mask[i] ? value[0] : input[i];
+      }
+    };
+    ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
+  } else {
+    auto task = [this, input, mask, output, value](size_t start, size_t end) {
+      for (size_t i = start; i < end; i++) {
+        output[i] = mask[i] ? value[i / inner_size_] : input[i];
+      }
+    };
+    ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
+  }
 
-  ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
   return true;
 }
 
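For the non-broadcast case the kernel now distinguishes a scalar fill value (value_size_ == 1), which skips the per-element division, from the batched case where value holds one fill value per block of inner_size_ contiguous output elements. A minimal self-contained sketch of these two paths (hypothetical helper name, not the MindSpore kernel itself):

#include <cstddef>

// Sketch only: output_size == value_size * inner_size, i.e. each of the
// value_size fill values covers inner_size contiguous output elements.
template <typename T>
void MaskedFillNoBroadcast(const T *input, const bool *mask, const T *value, T *output,
                           size_t output_size, size_t value_size) {
  const size_t inner_size = output_size / value_size;
  if (value_size == 1) {
    for (size_t i = 0; i < output_size; ++i) {
      output[i] = mask[i] ? value[0] : input[i];  // scalar fill, no division per element
    }
  } else {
    for (size_t i = 0; i < output_size; ++i) {
      output[i] = mask[i] ? value[i / inner_size] : input[i];  // one fill value per batch
    }
  }
}

For example, input = {1, 2, 3, 4}, mask = {true, false, false, true} and a single fill value {9} yield {9, 2, 3, 9}.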
@@ -124,18 +152,78 @@ std::vector<std::pair<KernelAttr, MaskedFillCpuKernelMod::MaskedFillFunc>> Maske
      .AddInputAttr(kNumberTypeFloat32)
      .AddOutputAttr(kNumberTypeFloat32),
    &MaskedFillCpuKernelMod::LaunchKernel<float>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64),
+   &MaskedFillCpuKernelMod::LaunchKernel<double>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt8)
      .AddInputAttr(kNumberTypeBool)
      .AddInputAttr(kNumberTypeInt8)
      .AddOutputAttr(kNumberTypeInt8),
    &MaskedFillCpuKernelMod::LaunchKernel<int8_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt16)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeInt16)
+     .AddOutputAttr(kNumberTypeInt16),
+   &MaskedFillCpuKernelMod::LaunchKernel<int16_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt32)
      .AddInputAttr(kNumberTypeBool)
      .AddInputAttr(kNumberTypeInt32)
      .AddOutputAttr(kNumberTypeInt32),
    &MaskedFillCpuKernelMod::LaunchKernel<int32_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeInt64),
+   &MaskedFillCpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddOutputAttr(kNumberTypeUInt8),
+   &MaskedFillCpuKernelMod::LaunchKernel<uint8_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddOutputAttr(kNumberTypeUInt16),
+   &MaskedFillCpuKernelMod::LaunchKernel<uint16_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &MaskedFillCpuKernelMod::LaunchKernel<uint32_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddOutputAttr(kNumberTypeUInt64),
+   &MaskedFillCpuKernelMod::LaunchKernel<uint64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeBool)
+     .AddOutputAttr(kNumberTypeBool),
+   &MaskedFillCpuKernelMod::LaunchKernel<bool>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex64)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeComplex64)
+     .AddOutputAttr(kNumberTypeComplex64),
+   &MaskedFillCpuKernelMod::LaunchKernel<complex64>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddOutputAttr(kNumberTypeComplex128),
+   &MaskedFillCpuKernelMod::LaunchKernel<complex128>},
 };
 
 std::vector<KernelAttr> MaskedFillCpuKernelMod::GetOpSupport() {
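Since LaunchKernel<T> only ever copies either input[i] or a fill value, widening CPU dtype coverage comes down to the registrations above plus the <complex> include and aliases added earlier; the same select-by-mask expression works unchanged for complex element types. A small standalone sketch (illustrative only, outside MindSpore) of that expression over complex64 data:

#include <complex>
#include <cstddef>
#include <iostream>
#include <vector>

using complex64 = std::complex<float>;

int main() {
  std::vector<complex64> input = {{1.0f, 1.0f}, {2.0f, 2.0f}, {3.0f, 3.0f}};
  std::vector<bool> mask = {true, false, true};
  const complex64 fill(0.0f, -1.0f);
  std::vector<complex64> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    output[i] = mask[i] ? fill : input[i];  // same conditional copy as LaunchKernel<T>
  }
  for (const auto &c : output) std::cout << c << ' ';  // prints (0,-1) (2,2) (0,-1)
  std::cout << '\n';
  return 0;
}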
@@ -51,9 +51,13 @@ class MaskedFillCpuKernelMod : public NativeCpuKernelMod {
   MaskedFillFunc kernel_func_;
   size_t output_size_{1};
   size_t inner_size_{1};
+  size_t value_size_{1};
+  int64_t batch_rank_{0};
   std::vector<int64_t> input_shape_;
   std::vector<int64_t> mask_shape_;
   std::vector<int64_t> output_shape_;
+  std::vector<size_t> mask_index_;
+  std::vector<size_t> input_index_;
   bool need_broadcast_{false};
 };
 }  // namespace kernel

@@ -76,12 +76,12 @@ TypePtr MaskedFillInferType(const PrimitivePtr &prim, const std::vector<Abstract
   std::set<TypePtr> valid_types;
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
-  bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
-  if (is_gpu) {
+  bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
+  if (is_ascend) {
+    valid_types = {kFloat16, kFloat32, kInt8, kInt32};
+  } else {
     valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64,
                    kFloat16, kFloat32, kFloat64, kInt, kUInt, kFloat, kComplex64, kComplex128};
-  } else {
-    valid_types = {kFloat16, kFloat32, kInt8, kInt32};
   }
   if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
     std::map<std::string, TypePtr> types;