adapt SparseSegmentSumWithNumSegments KernelMod.
parent af9bf8fe71
commit ccedc1db01
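At a high level, the diff below ports the SparseSegmentSumWithNumSegments CPU kernel from the deprecated CNode-based interface (InitKernel/CheckParam plus a long dtype switch in Launch) to the newer Init/Resize interface: every supported (data type, index type) combination is registered once in a static f_list_ table, Init picks the matching entry and caches it in kernel_func_, and Launch simply invokes it. The sketch below is a minimal, self-contained illustration of that table-driven dispatch; all names and types in it are stand-ins (FakeAttr, FakeKernelMod, the integer type codes), not MindSpore's real KernelAttr/MatchKernelAttr/AddressPtr APIs — only the shape of the pattern mirrors this commit.

// Illustrative sketch of the {attr, launcher} table dispatch adopted by this commit.
// Init() matches the runtime dtypes against a static list and caches the chosen
// launcher in kernel_func_; Launch() then calls it without any dtype switch.
#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

struct FakeAttr {   // stand-in for KernelAttr: just the dtypes of x and indices
  int x_type;
  int indices_type;
};

struct FakeKernelMod {  // stand-in for the CPU kernel mod
  using LaunchKernelFunc = std::function<void(FakeKernelMod *)>;

  template <typename T, typename Index>
  void LaunchKernel() {   // the real kernel is templated the same way on <data, index>
    std::cout << "compute with sizeof(T)=" << sizeof(T) << ", sizeof(Index)=" << sizeof(Index) << "\n";
  }

  bool Init(const FakeAttr &runtime_attr) {   // plays the role of MatchKernelAttr + f_list_[index]
    for (const auto &entry : f_list_) {
      if (entry.first.x_type == runtime_attr.x_type && entry.first.indices_type == runtime_attr.indices_type) {
        kernel_func_ = entry.second;
        return true;
      }
    }
    return false;   // unsupported dtype combination
  }

  void Launch() { kernel_func_(this); }   // no switch: just invoke the bound launcher

  static const std::vector<std::pair<FakeAttr, LaunchKernelFunc>> f_list_;
  LaunchKernelFunc kernel_func_;
};

// Toy encoding: 0 = Float32, 1 = Int32, 2 = Int64.
const std::vector<std::pair<FakeAttr, FakeKernelMod::LaunchKernelFunc>> FakeKernelMod::f_list_ = {
    {{0, 1}, &FakeKernelMod::LaunchKernel<float, int32_t>},
    {{0, 2}, &FakeKernelMod::LaunchKernel<float, int64_t>},
};

int main() {
  FakeKernelMod mod;
  if (mod.Init({0, 2})) {  // float data with int64_t indices
    mod.Launch();
  }
  return 0;
}

Registered this way, supporting a new dtype pair means appending one table entry rather than adding another case to a switch, and GetOpSupport can be derived from the same table, which is what the header change at the bottom of this commit does with std::transform.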
@@ -20,109 +20,59 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseSegmentSumWithNumSegmentsInputsNum = 4;
constexpr size_t kSparseSegmentSumWithNumSegmentsOutputsNum = 1;
constexpr size_t kInputsNum = 4;
constexpr size_t kOutputsNum = 1;

#define ADD_KERNEL(t1, t2, t3, t4, t5) \
  KernelAttr() \
    .AddInputAttr(kNumberType##t1) \
    .AddInputAttr(kNumberType##t2) \
    .AddInputAttr(kNumberType##t3) \
    .AddInputAttr(kNumberType##t4) \
    .AddOutputAttr(kNumberType##t5)
#define ADD_KERNEL(T1, T2, T3, T4, T5, T6, T7) \
  { \
    KernelAttr() \
      .AddInputAttr(kNumberType##T1) \
      .AddInputAttr(kNumberType##T2) \
      .AddInputAttr(kNumberType##T3) \
      .AddInputAttr(kNumberType##T4) \
      .AddOutputAttr(kNumberType##T5), \
    &SparseSegmentSumWithNumSegmentsCpuKernelMod::LaunchKernel<T6, T7> \
  }
} // namespace

void SparseSegmentSumWithNumSegmentsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
  indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
  segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
  y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
bool SparseSegmentSumWithNumSegmentsCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                                       const std::vector<KernelTensorPtr> &inputs,
                                                       const std::vector<KernelTensorPtr> &outputs) {
  MS_ERROR_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "The kernel '" << kernel_name_ << "' does not support this kernel data type: " << kernel_attr;
    return false;
  }
  kernel_func_ = f_list_[index].second;
  return true;
}

int SparseSegmentSumWithNumSegmentsCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                                        const std::vector<KernelTensorPtr> &inputs,
                                                        const std::vector<KernelTensorPtr> &outputs,
                                                        const std::map<uint32_t, tensor::TensorPtr> &) {
  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }
  x_dtype_ = inputs[kIndex0]->GetDtype();
  indices_dtype_ = inputs[kIndex1]->GetDtype();
  x_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
  segment_ids_shape_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
  y_shape_ = outputs[kIndex0]->GetDeviceShapeAdaptively();
  return KRET_OK;
}

bool SparseSegmentSumWithNumSegmentsCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                                         const std::vector<kernel::AddressPtr> &,
                                                         const std::vector<kernel::AddressPtr> &outputs) {
  switch (x_dtype_) {
    case (kNumberTypeInt8):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<int8_t, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<int8_t, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeInt16):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<int16_t, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<int16_t, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeInt32):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<int32_t, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<int32_t, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeInt64):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<int64_t, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<int64_t, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeUInt8):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<uint8_t, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<uint8_t, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeUInt16):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<uint16_t, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<uint16_t, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeFloat16):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<float16, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<float16, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeFloat32):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<float, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<float, int64_t>(inputs, outputs);
        break;
      }
    case (kNumberTypeFloat64):
      if (indices_dtype_ == kNumberTypeInt32) {
        LaunchKernel<double, int32_t>(inputs, outputs);
        break;
      } else {
        LaunchKernel<double, int64_t>(inputs, outputs);
        break;
      }
    default:
      MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_)
                              << " which is not supported.";
  }
  MS_ERROR_IF_NULL(kernel_func_);
  kernel_func_(this, inputs, outputs);
  return true;
}
@@ -170,26 +120,26 @@ void SparseSegmentSumWithNumSegmentsCpuKernelMod::LaunchKernel(const std::vector
  }
}

void SparseSegmentSumWithNumSegmentsCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentSumWithNumSegmentsInputsNum, kernel_name_);
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentSumWithNumSegmentsOutputsNum, kernel_name_);
}

std::vector<KernelAttr> SparseSegmentSumWithNumSegmentsCpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> kernel_attr_list = {
    ADD_KERNEL(Int8, Int32, Int32, Int32, Int8), ADD_KERNEL(Float16, Int32, Int32, Int32, Float16),
    ADD_KERNEL(Int16, Int32, Int32, Int32, Int16), ADD_KERNEL(Float32, Int32, Int32, Int32, Float32),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int32), ADD_KERNEL(Float64, Int32, Int32, Int32, Float64),
    ADD_KERNEL(Int64, Int32, Int32, Int32, Int64), ADD_KERNEL(UInt8, Int32, Int32, Int32, UInt8),
    ADD_KERNEL(UInt16, Int32, Int32, Int32, UInt16), ADD_KERNEL(Int8, Int64, Int64, Int64, Int8),
    ADD_KERNEL(Float16, Int64, Int64, Int64, Float16), ADD_KERNEL(Int16, Int64, Int64, Int64, Int16),
    ADD_KERNEL(Float32, Int64, Int64, Int64, Float32), ADD_KERNEL(Int32, Int64, Int64, Int64, Int32),
    ADD_KERNEL(Float64, Int64, Int64, Int64, Float64), ADD_KERNEL(Int64, Int64, Int64, Int64, Int64),
    ADD_KERNEL(UInt8, Int64, Int64, Int64, UInt8), ADD_KERNEL(UInt16, Int64, Int64, Int64, UInt16)};
  return kernel_attr_list;
}
std::vector<std::pair<KernelAttr, SparseSegmentSumWithNumSegmentsCpuKernelMod::LaunchKernelFunc>>
  SparseSegmentSumWithNumSegmentsCpuKernelMod::f_list_ = {
    ADD_KERNEL(Int8, Int32, Int32, Int32, Int8, int8_t, int32_t),
    ADD_KERNEL(Int8, Int64, Int64, Int64, Int8, int8_t, int64_t),
    ADD_KERNEL(Int16, Int32, Int32, Int32, Int16, int16_t, int32_t),
    ADD_KERNEL(Int16, Int64, Int64, Int64, Int16, int16_t, int64_t),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, int32_t, int32_t),
    ADD_KERNEL(Int32, Int64, Int64, Int64, Int32, int32_t, int64_t),
    ADD_KERNEL(Int64, Int32, Int32, Int32, Int64, int64_t, int32_t),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, int64_t, int64_t),
    ADD_KERNEL(UInt8, Int32, Int32, Int32, UInt8, uint8_t, int32_t),
    ADD_KERNEL(UInt8, Int64, Int64, Int64, UInt8, uint8_t, int64_t),
    ADD_KERNEL(UInt16, Int32, Int32, Int32, UInt16, uint16_t, int32_t),
    ADD_KERNEL(UInt16, Int64, Int64, Int64, UInt16, uint16_t, int64_t),
    ADD_KERNEL(Float16, Int32, Int32, Int32, Float16, float16, int32_t),
    ADD_KERNEL(Float16, Int64, Int64, Int64, Float16, float16, int64_t),
    ADD_KERNEL(Float32, Int32, Int32, Int32, Float32, float, int32_t),
    ADD_KERNEL(Float32, Int64, Int64, Int64, Float32, float, int64_t),
    ADD_KERNEL(Float64, Int32, Int32, Int32, Float64, double, int32_t),
    ADD_KERNEL(Float64, Int64, Int64, Int64, Float64, double, int64_t)};

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentSumWithNumSegments, SparseSegmentSumWithNumSegmentsCpuKernelMod);
} // namespace kernel
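A note on the ADD_KERNEL entries that populate f_list_ above: the macro token-pastes each type name onto kNumberType (kNumberType##T1 becomes kNumberTypeFloat32, for example) to build the KernelAttr chain, and pairs it with the matching LaunchKernel<T6, T7> instantiation. The toy program below reproduces that X-macro registration pattern with stand-in types so it compiles on its own; apart from the macro's overall shape, everything in it (FakeTypeId, FakeAttr, the free LaunchKernel) is hypothetical and not MindSpore code.

// Illustrative only: reproduces the ADD_KERNEL token-pasting pattern with fake types.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

enum FakeTypeId { kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat32 };

struct FakeAttr {   // stand-in for KernelAttr's builder-style interface
  std::vector<FakeTypeId> in, out;
  FakeAttr &AddInputAttr(FakeTypeId t) {
    in.push_back(t);
    return *this;
  }
  FakeAttr &AddOutputAttr(FakeTypeId t) {
    out.push_back(t);
    return *this;
  }
};

template <typename T, typename Index>
void LaunchKernel() {
  std::cout << sizeof(T) << "-byte data, " << sizeof(Index) << "-byte indices\n";
}

// kNumberType##T1 pastes "Float32" onto "kNumberType", yielding the enum name.
#define ADD_KERNEL(T1, T2, T3, T4, T5, T6, T7) \
  {FakeAttr() \
     .AddInputAttr(kNumberType##T1) \
     .AddInputAttr(kNumberType##T2) \
     .AddInputAttr(kNumberType##T3) \
     .AddInputAttr(kNumberType##T4) \
     .AddOutputAttr(kNumberType##T5), \
   &LaunchKernel<T6, T7>}

static const std::vector<std::pair<FakeAttr, void (*)()>> f_list = {
    ADD_KERNEL(Float32, Int32, Int32, Int32, Float32, float, int32_t),
    ADD_KERNEL(Float32, Int64, Int64, Int64, Float32, float, int64_t),
};

int main() {
  f_list[1].second();  // runs the <float, int64_t> instantiation: "4-byte data, 8-byte indices"
  return 0;
}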
@@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SUM_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SUM_WITH_NUM_SGEMENTS_CPU_KERNEL_H_

#include <map>
#include <utility>
#include <functional>
#include <numeric>
#include <algorithm>
@@ -29,12 +31,16 @@

namespace mindspore {
namespace kernel {
class SparseSegmentSumWithNumSegmentsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SparseSegmentSumWithNumSegmentsCpuKernelMod : public NativeCpuKernelMod {
 public:
  SparseSegmentSumWithNumSegmentsCpuKernelMod() = default;
  ~SparseSegmentSumWithNumSegmentsCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
@@ -42,11 +48,20 @@ class SparseSegmentSumWithNumSegmentsCpuKernelMod : public DeprecatedNativeCpuKe
  template <typename T1, typename T2>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  std::vector<KernelAttr> GetOpSupport() override;
  std::vector<KernelAttr> GetOpSupport() override {
    std::vector<KernelAttr> kernel_attr_list;
    (void)std::transform(f_list_.begin(), f_list_.end(), std::back_inserter(kernel_attr_list),
                         [](const std::pair<KernelAttr, LaunchKernelFunc> &pair) { return pair.first; });
    return kernel_attr_list;
  }

  using LaunchKernelFunc =
    std::function<void(SparseSegmentSumWithNumSegmentsCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &)>;

 private:
  void CheckParam(const CNodePtr &kernel_node);
  static std::vector<std::pair<KernelAttr, LaunchKernelFunc>> f_list_;
  LaunchKernelFunc kernel_func_;
  ShapeVector x_shape_;
  ShapeVector segment_ids_shape_;
  ShapeVector y_shape_;