!41880 修补EmbeddingLookup算子上的bug,恢复误删的支持的数据类型

Merge pull request !41880 from jin_jiaqi/EmbeddingLookUp
This commit is contained in:
i-robot 2022-09-16 08:46:18 +00:00 committed by Gitee
commit 62b2a39156
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 82 additions and 50 deletions

View File

@ -31,7 +31,7 @@ using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;
.AddInputAttr(kNumberType##input_params_dtype) \
.AddInputAttr(kNumberType##input_indices_dtype) \
.AddOutputAttr(kNumberType##output_dtype), \
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type> \
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int64_t> \
}
#define ADD_KERNEL_DYNAMIC(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, \
@ -42,7 +42,18 @@ using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;
.AddInputAttr(kNumberType##input_indices_dtype) \
.AddInputAttr(kNumberTypeInt64) \
.AddOutputAttr(kNumberType##output_dtype), \
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type> \
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int64_t> \
}
#define ADD_KERNEL_DYNAMIC_INT32(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, \
input_indices_type) \
{ \
KernelAttr() \
.AddInputAttr(kNumberType##input_params_dtype) \
.AddInputAttr(kNumberType##input_indices_dtype) \
.AddInputAttr(kNumberTypeInt32) \
.AddOutputAttr(kNumberType##output_dtype), \
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int32_t> \
}
template <typename T, typename S>
@ -121,7 +132,11 @@ const std::vector<std::pair<KernelAttr, KernelRunFunc>> &EmbeddingLookUpCpuKerne
ADD_KERNEL_DYNAMIC(UInt64, Int64, UInt64, uint64_t, int64_t),
ADD_KERNEL_DYNAMIC(Float16, Int64, Float16, float16, int64_t),
ADD_KERNEL_DYNAMIC(Float32, Int64, Float32, float, int64_t),
ADD_KERNEL_DYNAMIC(Float64, Int64, Float64, double, int64_t)};
ADD_KERNEL_DYNAMIC(Float64, Int64, Float64, double, int64_t),
ADD_KERNEL_DYNAMIC_INT32(Int32, Int32, Int32, int32_t, int32_t),
ADD_KERNEL_DYNAMIC_INT32(Float32, Int32, Float32, float, int32_t)};
return func_list;
}
@ -171,15 +186,15 @@ int EmbeddingLookUpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
return KRET_OK;
}
template <typename T, typename S>
template <typename T, typename S, typename G>
bool EmbeddingLookUpCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
T *input_params_addr = reinterpret_cast<T *>(inputs[0]->addr);
S *input_indices_addr = reinterpret_cast<S *>(inputs[1]->addr);
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
if (inputs.size() == kEmbeddingLookupDynamicShapeInputsNum) {
int64_t *input_offset_addr = reinterpret_cast<int64_t *>(inputs[2]->addr);
memcpy(&offset_, input_offset_addr, sizeof(int64_t));
G *input_offset_addr = reinterpret_cast<G *>(inputs[2]->addr);
memcpy(&offset_, input_offset_addr, sizeof(G));
}
auto task = [&](size_t start, size_t end) {
size_t task_proc_lens = end - start;

View File

@ -54,7 +54,7 @@ class EmbeddingLookUpCpuKernelMod : public NativeCpuKernelMod, public MatchKerne
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
protected:
template <typename T, typename S>
template <typename T, typename S, typename G>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs);

View File

@ -24,10 +24,10 @@ constexpr size_t kEmbeddingLookupOutputsNum = 1;
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
template <typename T, typename S, typename G>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateEmbeddingLookupKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::EmbeddingLookupHelperGpuKernel<T, S>>(kernel_name, device_id);
return std::make_unique<cukernel::EmbeddingLookupHelperGpuKernel<T, S, G>>(kernel_name, device_id);
}
using EmbeddingLookupPtrCreatorFunc =
@ -35,133 +35,145 @@ using EmbeddingLookupPtrCreatorFunc =
const std::vector<std::pair<KernelAttr, EmbeddingLookupPtrCreatorFunc>> kernel_attr = {
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreateEmbeddingLookupKernelPtr<double, int>},
CreateEmbeddingLookupKernelPtr<double, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CreateEmbeddingLookupKernelPtr<double, int64_t>},
CreateEmbeddingLookupKernelPtr<double, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateEmbeddingLookupKernelPtr<float, int>},
CreateEmbeddingLookupKernelPtr<float, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CreateEmbeddingLookupKernelPtr<float, int64_t>},
CreateEmbeddingLookupKernelPtr<float, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CreateEmbeddingLookupKernelPtr<half, int>},
CreateEmbeddingLookupKernelPtr<half, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CreateEmbeddingLookupKernelPtr<half, int64_t>},
CreateEmbeddingLookupKernelPtr<half, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
CreateEmbeddingLookupKernelPtr<bool, int>},
CreateEmbeddingLookupKernelPtr<bool, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
CreateEmbeddingLookupKernelPtr<bool, int64_t>},
CreateEmbeddingLookupKernelPtr<bool, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreateEmbeddingLookupKernelPtr<int, int>},
CreateEmbeddingLookupKernelPtr<int, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreateEmbeddingLookupKernelPtr<int, int64_t>},
CreateEmbeddingLookupKernelPtr<int, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
CreateEmbeddingLookupKernelPtr<int16_t, int>},
CreateEmbeddingLookupKernelPtr<int16_t, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
CreateEmbeddingLookupKernelPtr<int16_t, int64_t>},
CreateEmbeddingLookupKernelPtr<int16_t, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
CreateEmbeddingLookupKernelPtr<int8_t, int>},
CreateEmbeddingLookupKernelPtr<int8_t, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
CreateEmbeddingLookupKernelPtr<int8_t, int64_t>},
CreateEmbeddingLookupKernelPtr<int8_t, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreateEmbeddingLookupKernelPtr<uint8_t, int>},
CreateEmbeddingLookupKernelPtr<uint8_t, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t>},
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
CreateEmbeddingLookupKernelPtr<double, int>},
CreateEmbeddingLookupKernelPtr<double, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
CreateEmbeddingLookupKernelPtr<double, int64_t>},
CreateEmbeddingLookupKernelPtr<double, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CreateEmbeddingLookupKernelPtr<float, int>},
CreateEmbeddingLookupKernelPtr<float, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CreateEmbeddingLookupKernelPtr<float, int64_t>},
CreateEmbeddingLookupKernelPtr<float, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CreateEmbeddingLookupKernelPtr<float, int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateEmbeddingLookupKernelPtr<int32_t, int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
CreateEmbeddingLookupKernelPtr<half, int>},
CreateEmbeddingLookupKernelPtr<half, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
CreateEmbeddingLookupKernelPtr<half, int64_t>},
CreateEmbeddingLookupKernelPtr<half, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
CreateEmbeddingLookupKernelPtr<bool, int>},
CreateEmbeddingLookupKernelPtr<bool, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
CreateEmbeddingLookupKernelPtr<bool, int64_t>},
CreateEmbeddingLookupKernelPtr<bool, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
CreateEmbeddingLookupKernelPtr<int, int>},
CreateEmbeddingLookupKernelPtr<int, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
CreateEmbeddingLookupKernelPtr<int, int64_t>},
CreateEmbeddingLookupKernelPtr<int, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
CreateEmbeddingLookupKernelPtr<int16_t, int>},
CreateEmbeddingLookupKernelPtr<int16_t, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
CreateEmbeddingLookupKernelPtr<int16_t, int64_t>},
CreateEmbeddingLookupKernelPtr<int16_t, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
CreateEmbeddingLookupKernelPtr<int8_t, int>},
CreateEmbeddingLookupKernelPtr<int8_t, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
CreateEmbeddingLookupKernelPtr<int8_t, int64_t>},
CreateEmbeddingLookupKernelPtr<int8_t, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
CreateEmbeddingLookupKernelPtr<uint8_t, int>},
CreateEmbeddingLookupKernelPtr<uint8_t, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t>},
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t, int64_t>},
};
} // namespace

View File

@ -50,7 +50,7 @@ size_t GetSize(const std::vector<int64_t> &shape) {
return result;
}
template <typename T, typename S>
template <typename T, typename S, typename G>
class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit EmbeddingLookupHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
@ -88,7 +88,7 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
input_size_list_.push_back(GetSize<S>(input_indices_shape_));
if (input_shapes.size() == kEmbeddingLookupDynamicShapeInputsNum) {
is_dynamic_shape_ = true;
input_size_list_.push_back(sizeof(int64_t));
input_size_list_.push_back(sizeof(G));
}
output_size_list_.push_back(GetSize<T>(output_shape_));
@ -118,15 +118,14 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
return flag;
}
if (is_dynamic_shape_) {
int64_t *input_offset_addr = nullptr;
flag = GetDeviceAddress<int64_t>(input_ptrs, kIndex2, kernel_name_, &input_offset_addr);
G *input_offset_addr = nullptr;
flag = GetDeviceAddress<G>(input_ptrs, kIndex2, kernel_name_, &input_offset_addr);
if (flag != 0) {
return flag;
}
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(&offset_, input_offset_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream)),
"cudaMemcpyAsync offset_ failed");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&offset_, input_offset_addr, sizeof(G), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream)),
"cudaMemcpyAsync offset_ failed");
}
CalEmbeddingLookup(input_params_addr, input_indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1],
dims_[kIndex2], input_dim1_, static_cast<int64_t>(offset_),
@ -134,6 +133,11 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<EmbeddingLookupAttr>(kernel_attr);
offset_ = attr_ptr_->offset;
}
void ResetResource() override {
is_null_input_ = false;
input_params_shape_.clear();
@ -155,8 +159,9 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
int64_t input_dim1_;
bool is_null_input_;
size_t dims_[kIndex3] = {};
int64_t offset_ = 0;
G offset_ = 0;
bool is_dynamic_shape_;
std::shared_ptr<EmbeddingLookupAttr> attr_ptr_;
};
} // namespace cukernel
} // namespace mindspore