Adapt the SparseSegmentSumWithNumSegments CPU kernel to the NativeCpuKernelMod interface (Init/Resize).

This commit is contained in:
y00451588 2022-11-03 14:20:51 +08:00
parent af9bf8fe71
commit ccedc1db01
2 changed files with 85 additions and 120 deletions

View File

@ -20,109 +20,59 @@
namespace mindspore {
namespace kernel {
namespace {
// Expected number of input tensors; compared against the actual count in CheckParam.
constexpr size_t kSparseSegmentSumWithNumSegmentsInputsNum = 4;
// Expected number of output tensors; compared against the actual count in CheckParam.
constexpr size_t kSparseSegmentSumWithNumSegmentsOutputsNum = 1;
// Expected number of input tensors; compared against inputs.size() in Init.
constexpr size_t kInputsNum = 4;
// Expected number of output tensors; compared against outputs.size() in Init.
constexpr size_t kOutputsNum = 1;
// Expands to a KernelAttr expression with four input dtypes (t1..t4) and one output dtype (t5).
// Each argument is a dtype suffix that is token-pasted onto the kNumberType prefix.
#define ADD_KERNEL(t1, t2, t3, t4, t5) \
KernelAttr() \
.AddInputAttr(kNumberType##t1) \
.AddInputAttr(kNumberType##t2) \
.AddInputAttr(kNumberType##t3) \
.AddInputAttr(kNumberType##t4) \
.AddOutputAttr(kNumberType##t5)
// Expands to a braced {KernelAttr, member-function pointer} initializer.
// T1..T4 are the input dtype suffixes and T5 the output dtype suffix (each token-pasted onto
// kNumberType); T6 and T7 are the template arguments of the LaunchKernel instantiation that is
// paired with that attribute.
#define ADD_KERNEL(T1, T2, T3, T4, T5, T6, T7) \
{ \
KernelAttr() \
.AddInputAttr(kNumberType##T1) \
.AddInputAttr(kNumberType##T2) \
.AddInputAttr(kNumberType##T3) \
.AddInputAttr(kNumberType##T4) \
.AddOutputAttr(kNumberType##T5), \
&SparseSegmentSumWithNumSegmentsCpuKernelMod::LaunchKernel<T6, T7> \
}
} // namespace
void SparseSegmentSumWithNumSegmentsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
bool SparseSegmentSumWithNumSegmentsCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_ERROR_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "The kernel '" << kernel_name_ << "' does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = f_list_[index].second;
return true;
}
// Delegates to KernelMod::Resize and, on success, caches the dtypes of inputs 0 and 1 together with
// the device shapes of input 0, input 2 and output 0.
// Returns the base-class status unchanged when it is not KRET_OK, otherwise KRET_OK.
int SparseSegmentSumWithNumSegmentsCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                                        const std::vector<KernelTensorPtr> &inputs,
                                                        const std::vector<KernelTensorPtr> &outputs,
                                                        const std::map<uint32_t, tensor::TensorPtr> &) {
  const auto base_status = KernelMod::Resize(base_operator, inputs, outputs);
  if (base_status != KRET_OK) {
    return base_status;
  }

  const auto &x_tensor = inputs[kIndex0];
  const auto &indices_tensor = inputs[kIndex1];
  const auto &segment_ids_tensor = inputs[kIndex2];
  const auto &y_tensor = outputs[kIndex0];

  x_dtype_ = x_tensor->GetDtype();
  indices_dtype_ = indices_tensor->GetDtype();
  x_shape_ = x_tensor->GetDeviceShapeAdaptively();
  segment_ids_shape_ = segment_ids_tensor->GetDeviceShapeAdaptively();
  y_shape_ = y_tensor->GetDeviceShapeAdaptively();
  return KRET_OK;
}
// Runs the kernel on the given input and output buffers and returns true.
// The second (workspace) parameter is unnamed and unused.
// NOTE(review): as listed, this body contains both a dtype switch that calls LaunchKernel directly
// and a later call through kernel_func_. This looks like a diff view that interleaves the
// pre-change and post-change dispatch; confirm which lines belong to the current file.
bool SparseSegmentSumWithNumSegmentsCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
// Dispatch on the dtype of x. Within each case, Int32 indices select the int32_t instantiation and
// any other indices dtype falls through to the int64_t instantiation.
switch (x_dtype_) {
case (kNumberTypeInt8):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int8_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int8_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeInt16):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int16_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int16_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeInt32):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int32_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int32_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeInt64):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int64_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int64_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeUInt8):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<uint8_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<uint8_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeUInt16):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<uint16_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<uint16_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeFloat16):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<float16, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<float16, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeFloat32):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<float, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<float, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeFloat64):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<double, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<double, int64_t>(inputs, outputs);
break;
}
// Any other dtype of x is reported through MS_EXCEPTION as a TypeError.
default:
MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_)
<< " which is not supported.";
}
// Table-driven dispatch: kernel_func_ is the function selected in Init; it is null-checked first
// and then invoked with this object as its first argument.
MS_ERROR_IF_NULL(kernel_func_);
kernel_func_(this, inputs, outputs);
return true;
}
@ -170,26 +120,26 @@ void SparseSegmentSumWithNumSegmentsCpuKernelMod::LaunchKernel(const std::vector
}
}
// Checks that kernel_node has the expected number of input tensors and output tensors, in that
// order, using the CHECK_KERNEL_*_NUM macros.
void SparseSegmentSumWithNumSegmentsCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
  // Inputs are validated before the output count is even queried.
  const size_t num_inputs = common::AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(num_inputs, kSparseSegmentSumWithNumSegmentsInputsNum, kernel_name_);

  const size_t num_outputs = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(num_outputs, kSparseSegmentSumWithNumSegmentsOutputsNum, kernel_name_);
}
std::vector<KernelAttr> SparseSegmentSumWithNumSegmentsCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
ADD_KERNEL(Int8, Int32, Int32, Int32, Int8), ADD_KERNEL(Float16, Int32, Int32, Int32, Float16),
ADD_KERNEL(Int16, Int32, Int32, Int32, Int16), ADD_KERNEL(Float32, Int32, Int32, Int32, Float32),
ADD_KERNEL(Int32, Int32, Int32, Int32, Int32), ADD_KERNEL(Float64, Int32, Int32, Int32, Float64),
ADD_KERNEL(Int64, Int32, Int32, Int32, Int64), ADD_KERNEL(UInt8, Int32, Int32, Int32, UInt8),
ADD_KERNEL(UInt16, Int32, Int32, Int32, UInt16), ADD_KERNEL(Int8, Int64, Int64, Int64, Int8),
ADD_KERNEL(Float16, Int64, Int64, Int64, Float16), ADD_KERNEL(Int16, Int64, Int64, Int64, Int16),
ADD_KERNEL(Float32, Int64, Int64, Int64, Float32), ADD_KERNEL(Int32, Int64, Int64, Int64, Int32),
ADD_KERNEL(Float64, Int64, Int64, Int64, Float64), ADD_KERNEL(Int64, Int64, Int64, Int64, Int64),
ADD_KERNEL(UInt8, Int64, Int64, Int64, UInt8), ADD_KERNEL(UInt16, Int64, Int64, Int64, UInt16)};
return kernel_attr_list;
}
// Dispatch table pairing each supported KernelAttr with a LaunchKernel instantiation.
// Each row lists: the dtype of x, the dtypes of inputs 1-3 (all Int32 or all Int64 in every row),
// the output dtype, and then the two C++ types used as LaunchKernel's template arguments.
// Init indexes this table with the position returned by MatchKernelAttr, so row order matters.
std::vector<std::pair<KernelAttr, SparseSegmentSumWithNumSegmentsCpuKernelMod::LaunchKernelFunc>>
SparseSegmentSumWithNumSegmentsCpuKernelMod::f_list_ = {
ADD_KERNEL(Int8, Int32, Int32, Int32, Int8, int8_t, int32_t),
ADD_KERNEL(Int8, Int64, Int64, Int64, Int8, int8_t, int64_t),
ADD_KERNEL(Int16, Int32, Int32, Int32, Int16, int16_t, int32_t),
ADD_KERNEL(Int16, Int64, Int64, Int64, Int16, int16_t, int64_t),
ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, int32_t, int32_t),
ADD_KERNEL(Int32, Int64, Int64, Int64, Int32, int32_t, int64_t),
ADD_KERNEL(Int64, Int32, Int32, Int32, Int64, int64_t, int32_t),
ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, int64_t, int64_t),
ADD_KERNEL(UInt8, Int32, Int32, Int32, UInt8, uint8_t, int32_t),
ADD_KERNEL(UInt8, Int64, Int64, Int64, UInt8, uint8_t, int64_t),
ADD_KERNEL(UInt16, Int32, Int32, Int32, UInt16, uint16_t, int32_t),
ADD_KERNEL(UInt16, Int64, Int64, Int64, UInt16, uint16_t, int64_t),
ADD_KERNEL(Float16, Int32, Int32, Int32, Float16, float16, int32_t),
ADD_KERNEL(Float16, Int64, Int64, Int64, Float16, float16, int64_t),
ADD_KERNEL(Float32, Int32, Int32, Int32, Float32, float, int32_t),
ADD_KERNEL(Float32, Int64, Int64, Int64, Float32, float, int64_t),
ADD_KERNEL(Float64, Int32, Int32, Int32, Float64, double, int32_t),
ADD_KERNEL(Float64, Int64, Int64, Int64, Float64, double, int64_t)};
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentSumWithNumSegments, SparseSegmentSumWithNumSegmentsCpuKernelMod);
} // namespace kernel

View File

@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SUM_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SUM_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
#include <map>
#include <utility>
#include <functional>
#include <numeric>
#include <algorithm>
@ -29,12 +31,16 @@
namespace mindspore {
namespace kernel {
class SparseSegmentSumWithNumSegmentsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SparseSegmentSumWithNumSegmentsCpuKernelMod : public NativeCpuKernelMod {
public:
SparseSegmentSumWithNumSegmentsCpuKernelMod() = default;
~SparseSegmentSumWithNumSegmentsCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@ -42,11 +48,20 @@ class SparseSegmentSumWithNumSegmentsCpuKernelMod : public DeprecatedNativeCpuKe
template <typename T1, typename T2>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
std::vector<KernelAttr> GetOpSupport() override {
std::vector<KernelAttr> kernel_attr_list;
(void)std::transform(f_list_.begin(), f_list_.end(), std::back_inserter(kernel_attr_list),
[](const std::pair<KernelAttr, LaunchKernelFunc> &pair) { return pair.first; });
return kernel_attr_list;
}
// Callable type stored in f_list_ and kernel_func_: takes the kernel object plus the input and
// output address vectors and returns nothing.
using LaunchKernelFunc =
std::function<void(SparseSegmentSumWithNumSegmentsCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
private:
void CheckParam(const CNodePtr &kernel_node);
static std::vector<std::pair<KernelAttr, LaunchKernelFunc>> f_list_;
LaunchKernelFunc kernel_func_;
ShapeVector x_shape_;
ShapeVector segment_ids_shape_;
ShapeVector y_shape_;