From e540be5fd0b7b1a9c4ff8218c9c0ef428f2a3b2e Mon Sep 17 00:00:00 2001 From: jin_jiaqi <1228766517@qq.com> Date: Thu, 15 Sep 2022 19:26:55 +0800 Subject: [PATCH] fix bug in op EmbeddingLookup --- .../kernel/embedding_look_up_cpu_kernel.cc | 27 +++++-- .../cpu/kernel/embedding_look_up_cpu_kernel.h | 2 +- .../arrays/embedding_lookup_gpu_kernel.cc | 80 +++++++++++-------- .../cuda_class/embedding_lookup_helper.h | 23 +++--- 4 files changed, 82 insertions(+), 50 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.cc index 51119cb79d5..45797e2b6a1 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.cc @@ -31,7 +31,7 @@ using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc; .AddInputAttr(kNumberType##input_params_dtype) \ .AddInputAttr(kNumberType##input_indices_dtype) \ .AddOutputAttr(kNumberType##output_dtype), \ - &EmbeddingLookUpCpuKernelMod::LaunchKernel \ + &EmbeddingLookUpCpuKernelMod::LaunchKernel \ } #define ADD_KERNEL_DYNAMIC(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, \ @@ -42,7 +42,18 @@ using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc; .AddInputAttr(kNumberType##input_indices_dtype) \ .AddInputAttr(kNumberTypeInt64) \ .AddOutputAttr(kNumberType##output_dtype), \ - &EmbeddingLookUpCpuKernelMod::LaunchKernel \ + &EmbeddingLookUpCpuKernelMod::LaunchKernel \ + } + +#define ADD_KERNEL_DYNAMIC_INT32(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, \ + input_indices_type) \ + { \ + KernelAttr() \ + .AddInputAttr(kNumberType##input_params_dtype) \ + .AddInputAttr(kNumberType##input_indices_dtype) \ + .AddInputAttr(kNumberTypeInt32) \ + .AddOutputAttr(kNumberType##output_dtype), \ + &EmbeddingLookUpCpuKernelMod::LaunchKernel \ } template @@ -121,7 +132,11 @@ const std::vector> &EmbeddingLookUpCpuKerne ADD_KERNEL_DYNAMIC(UInt64, Int64, UInt64, uint64_t, int64_t), ADD_KERNEL_DYNAMIC(Float16, Int64, Float16, float16, int64_t), ADD_KERNEL_DYNAMIC(Float32, Int64, Float32, float, int64_t), - ADD_KERNEL_DYNAMIC(Float64, Int64, Float64, double, int64_t)}; + ADD_KERNEL_DYNAMIC(Float64, Int64, Float64, double, int64_t), + + ADD_KERNEL_DYNAMIC_INT32(Int32, Int32, Int32, int32_t, int32_t), + ADD_KERNEL_DYNAMIC_INT32(Float32, Int32, Float32, float, int32_t)}; + return func_list; } @@ -171,15 +186,15 @@ int EmbeddingLookUpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, return KRET_OK; } -template +template bool EmbeddingLookUpCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &, const std::vector &outputs) { T *input_params_addr = reinterpret_cast(inputs[0]->addr); S *input_indices_addr = reinterpret_cast(inputs[1]->addr); T *output_addr = reinterpret_cast(outputs[0]->addr); if (inputs.size() == kEmbeddingLookupDynamicShapeInputsNum) { - int64_t *input_offset_addr = reinterpret_cast(inputs[2]->addr); - memcpy(&offset_, input_offset_addr, sizeof(int64_t)); + G *input_offset_addr = reinterpret_cast(inputs[2]->addr); + memcpy(&offset_, input_offset_addr, sizeof(G)); } auto task = [&](size_t start, size_t end) { size_t task_proc_lens = end - start; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.h index 6819bc45ca3..470af8089b1 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.h @@ -54,7 +54,7 @@ class EmbeddingLookUpCpuKernelMod : public NativeCpuKernelMod, public MatchKerne std::vector GetOpSupport() override { return OpSupport(); } protected: - template + template bool LaunchKernel(const std::vector &inputs, const std::vector &, const std::vector &outputs); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/embedding_lookup_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/embedding_lookup_gpu_kernel.cc index 111682d97dc..09191a8ccaf 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/embedding_lookup_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/embedding_lookup_gpu_kernel.cc @@ -24,10 +24,10 @@ constexpr size_t kEmbeddingLookupOutputsNum = 1; namespace mindspore { namespace kernel { namespace { -template +template std::unique_ptr CreateEmbeddingLookupKernelPtr(const std::string &kernel_name, const uint32_t &device_id) { - return std::make_unique>(kernel_name, device_id); + return std::make_unique>(kernel_name, device_id); } using EmbeddingLookupPtrCreatorFunc = @@ -35,133 +35,145 @@ using EmbeddingLookupPtrCreatorFunc = const std::vector> kernel_attr = { {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeFloat64) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat64), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeFloat64) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat64), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + CreateEmbeddingLookupKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeBool) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeBool), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeBool) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeBool), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt32), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeInt16) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeInt16) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt16), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeInt8) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeInt8) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeUInt8) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeUInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, {KernelAttr() .AddInputAttr(kNumberTypeUInt8) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeUInt8), - CreateEmbeddingLookupKernelPtr}, + CreateEmbeddingLookupKernelPtr}, }; } // namespace diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/embedding_lookup_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/embedding_lookup_helper.h index 41590a1d27a..1ce45dba896 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/embedding_lookup_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/embedding_lookup_helper.h @@ -50,7 +50,7 @@ size_t GetSize(const std::vector &shape) { return result; } -template +template class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase { public: explicit EmbeddingLookupHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) @@ -88,7 +88,7 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase { input_size_list_.push_back(GetSize(input_indices_shape_)); if (input_shapes.size() == kEmbeddingLookupDynamicShapeInputsNum) { is_dynamic_shape_ = true; - input_size_list_.push_back(sizeof(int64_t)); + input_size_list_.push_back(sizeof(G)); } output_size_list_.push_back(GetSize(output_shape_)); @@ -118,15 +118,14 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase { return flag; } if (is_dynamic_shape_) { - int64_t *input_offset_addr = nullptr; - flag = GetDeviceAddress(input_ptrs, kIndex2, kernel_name_, &input_offset_addr); + G *input_offset_addr = nullptr; + flag = GetDeviceAddress(input_ptrs, kIndex2, kernel_name_, &input_offset_addr); if (flag != 0) { return flag; } - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(&offset_, input_offset_addr, sizeof(int64_t), cudaMemcpyDeviceToHost, - reinterpret_cast(cuda_stream)), - "cudaMemcpyAsync offset_ failed"); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&offset_, input_offset_addr, sizeof(G), cudaMemcpyDeviceToHost, + reinterpret_cast(cuda_stream)), + "cudaMemcpyAsync offset_ failed"); } CalEmbeddingLookup(input_params_addr, input_indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2], input_dim1_, static_cast(offset_), @@ -134,6 +133,11 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase { return 0; } + void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { + attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); + offset_ = attr_ptr_->offset; + } + void ResetResource() override { is_null_input_ = false; input_params_shape_.clear(); @@ -155,8 +159,9 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase { int64_t input_dim1_; bool is_null_input_; size_t dims_[kIndex3] = {}; - int64_t offset_ = 0; + G offset_ = 0; bool is_dynamic_shape_; + std::shared_ptr attr_ptr_; }; } // namespace cukernel } // namespace mindspore