forked from mindspore-Ecosystem/mindspore
!48159 增加SparseGatherV2数据类型
Merge pull request !48159 from zong_shuai/null_tensor
This commit is contained in:
commit
d57a167eee
|
@ -31,7 +31,7 @@ bool GatherV2FwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
|
|||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_map_[kernel_name_][index].second;
|
||||
kernel_func_ = func_list_[index].second;
|
||||
|
||||
input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype);
|
||||
indices_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype);
|
||||
|
@ -75,13 +75,8 @@ int GatherV2FwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
}
|
||||
|
||||
std::vector<KernelAttr> GatherV2FwdGpuKernelMod::GetOpSupport() {
|
||||
auto iter = func_map_.find(kernel_type_);
|
||||
if (iter == func_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' gpu does not support " << kernel_type_;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, GatherV2Func> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
@ -121,41 +116,22 @@ bool GatherV2FwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
|
|||
GATHER_GPU_REG(MS_T, kNumberTypeInt64, kNumberTypeInt32, T, int64_t, int32_t) \
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<std::pair<KernelAttr, GatherV2FwdGpuKernelMod::GatherV2Func>>>
|
||||
GatherV2FwdGpuKernelMod::func_map_ = {
|
||||
{kGather,
|
||||
{
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeComplex64, mindspore::utils::Complex<float>),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeComplex128, mindspore::utils::Complex<double>),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeFloat64, double),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeFloat32, float),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeFloat16, half),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt64, int64_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt32, int32_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt16, int16_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt8, int8_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt64, uint64_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt32, uint32_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt16, uint16_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt8, uint8_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeBool, bool),
|
||||
}},
|
||||
{kSparseGatherV2,
|
||||
{
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int32_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int32_t, int64_t>},
|
||||
}},
|
||||
};
|
||||
std::vector<std::pair<KernelAttr, GatherV2FwdGpuKernelMod::GatherV2Func>> GatherV2FwdGpuKernelMod::func_list_ = {{
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeComplex64, mindspore::utils::Complex<float>),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeComplex128, mindspore::utils::Complex<double>),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeFloat64, double),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeFloat32, float),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeFloat16, half),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt64, int64_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt32, int32_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt16, int16_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeInt8, int8_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt64, uint64_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt32, uint32_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt16, uint16_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeUInt8, uint8_t),
|
||||
GATHER_GPU_INDEX_REG(kNumberTypeBool, bool),
|
||||
}};
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Gather,
|
||||
[]() { return std::make_shared<GatherV2FwdGpuKernelMod>(kGather); });
|
||||
|
|
|
@ -102,7 +102,7 @@ class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod {
|
|||
private:
|
||||
using GatherV2Func = std::function<bool(GatherV2FwdGpuKernelMod *, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, GatherV2Func>>> func_map_;
|
||||
static std::vector<std::pair<KernelAttr, GatherV2Func>> func_list_;
|
||||
GatherV2Func kernel_func_;
|
||||
|
||||
std::vector<int64_t> input_shapes_;
|
||||
|
|
Loading…
Reference in New Issue