!41880 修补EmbeddingLookup算子上的bug,恢复误删的支持的数据类型
Merge pull request !41880 from jin_jiaqi/EmbeddingLookUp
This commit is contained in:
commit
62b2a39156
|
@ -31,7 +31,7 @@ using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;
|
|||
.AddInputAttr(kNumberType##input_params_dtype) \
|
||||
.AddInputAttr(kNumberType##input_indices_dtype) \
|
||||
.AddOutputAttr(kNumberType##output_dtype), \
|
||||
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type> \
|
||||
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int64_t> \
|
||||
}
|
||||
|
||||
#define ADD_KERNEL_DYNAMIC(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, \
|
||||
|
@ -42,7 +42,18 @@ using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;
|
|||
.AddInputAttr(kNumberType##input_indices_dtype) \
|
||||
.AddInputAttr(kNumberTypeInt64) \
|
||||
.AddOutputAttr(kNumberType##output_dtype), \
|
||||
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type> \
|
||||
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int64_t> \
|
||||
}
|
||||
|
||||
#define ADD_KERNEL_DYNAMIC_INT32(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, \
|
||||
input_indices_type) \
|
||||
{ \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(kNumberType##input_params_dtype) \
|
||||
.AddInputAttr(kNumberType##input_indices_dtype) \
|
||||
.AddInputAttr(kNumberTypeInt32) \
|
||||
.AddOutputAttr(kNumberType##output_dtype), \
|
||||
&EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int32_t> \
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
|
@ -121,7 +132,11 @@ const std::vector<std::pair<KernelAttr, KernelRunFunc>> &EmbeddingLookUpCpuKerne
|
|||
ADD_KERNEL_DYNAMIC(UInt64, Int64, UInt64, uint64_t, int64_t),
|
||||
ADD_KERNEL_DYNAMIC(Float16, Int64, Float16, float16, int64_t),
|
||||
ADD_KERNEL_DYNAMIC(Float32, Int64, Float32, float, int64_t),
|
||||
ADD_KERNEL_DYNAMIC(Float64, Int64, Float64, double, int64_t)};
|
||||
ADD_KERNEL_DYNAMIC(Float64, Int64, Float64, double, int64_t),
|
||||
|
||||
ADD_KERNEL_DYNAMIC_INT32(Int32, Int32, Int32, int32_t, int32_t),
|
||||
ADD_KERNEL_DYNAMIC_INT32(Float32, Int32, Float32, float, int32_t)};
|
||||
|
||||
return func_list;
|
||||
}
|
||||
|
||||
|
@ -171,15 +186,15 @@ int EmbeddingLookUpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
bool EmbeddingLookUpCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *input_params_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
S *input_indices_addr = reinterpret_cast<S *>(inputs[1]->addr);
|
||||
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
if (inputs.size() == kEmbeddingLookupDynamicShapeInputsNum) {
|
||||
int64_t *input_offset_addr = reinterpret_cast<int64_t *>(inputs[2]->addr);
|
||||
memcpy(&offset_, input_offset_addr, sizeof(int64_t));
|
||||
G *input_offset_addr = reinterpret_cast<G *>(inputs[2]->addr);
|
||||
memcpy(&offset_, input_offset_addr, sizeof(G));
|
||||
}
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
size_t task_proc_lens = end - start;
|
||||
|
|
|
@ -54,7 +54,7 @@ class EmbeddingLookUpCpuKernelMod : public NativeCpuKernelMod, public MatchKerne
|
|||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
protected:
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
|
|
|
@ -24,10 +24,10 @@ constexpr size_t kEmbeddingLookupOutputsNum = 1;
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateEmbeddingLookupKernelPtr(const std::string &kernel_name,
|
||||
const uint32_t &device_id) {
|
||||
return std::make_unique<cukernel::EmbeddingLookupHelperGpuKernel<T, S>>(kernel_name, device_id);
|
||||
return std::make_unique<cukernel::EmbeddingLookupHelperGpuKernel<T, S, G>>(kernel_name, device_id);
|
||||
}
|
||||
|
||||
using EmbeddingLookupPtrCreatorFunc =
|
||||
|
@ -35,133 +35,145 @@ using EmbeddingLookupPtrCreatorFunc =
|
|||
|
||||
const std::vector<std::pair<KernelAttr, EmbeddingLookupPtrCreatorFunc>> kernel_attr = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateEmbeddingLookupKernelPtr<double, int>},
|
||||
CreateEmbeddingLookupKernelPtr<double, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateEmbeddingLookupKernelPtr<double, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<double, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateEmbeddingLookupKernelPtr<float, int>},
|
||||
CreateEmbeddingLookupKernelPtr<float, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateEmbeddingLookupKernelPtr<float, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<float, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateEmbeddingLookupKernelPtr<half, int>},
|
||||
CreateEmbeddingLookupKernelPtr<half, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateEmbeddingLookupKernelPtr<half, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<half, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
CreateEmbeddingLookupKernelPtr<bool, int>},
|
||||
CreateEmbeddingLookupKernelPtr<bool, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
CreateEmbeddingLookupKernelPtr<bool, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<bool, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
CreateEmbeddingLookupKernelPtr<int, int>},
|
||||
CreateEmbeddingLookupKernelPtr<int, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
CreateEmbeddingLookupKernelPtr<int, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<int, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int>},
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int>},
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int>},
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateEmbeddingLookupKernelPtr<double, int>},
|
||||
CreateEmbeddingLookupKernelPtr<double, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateEmbeddingLookupKernelPtr<double, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<double, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateEmbeddingLookupKernelPtr<float, int>},
|
||||
CreateEmbeddingLookupKernelPtr<float, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateEmbeddingLookupKernelPtr<float, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<float, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateEmbeddingLookupKernelPtr<float, int32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CreateEmbeddingLookupKernelPtr<int32_t, int32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateEmbeddingLookupKernelPtr<half, int>},
|
||||
CreateEmbeddingLookupKernelPtr<half, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateEmbeddingLookupKernelPtr<half, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<half, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
CreateEmbeddingLookupKernelPtr<bool, int>},
|
||||
CreateEmbeddingLookupKernelPtr<bool, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
CreateEmbeddingLookupKernelPtr<bool, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<bool, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CreateEmbeddingLookupKernelPtr<int, int>},
|
||||
CreateEmbeddingLookupKernelPtr<int, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CreateEmbeddingLookupKernelPtr<int, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<int, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int>},
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<int16_t, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int>},
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<int8_t, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int>},
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t>},
|
||||
CreateEmbeddingLookupKernelPtr<uint8_t, int64_t, int64_t>},
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ size_t GetSize(const std::vector<int64_t> &shape) {
|
|||
return result;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
|
||||
public:
|
||||
explicit EmbeddingLookupHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
|
||||
|
@ -88,7 +88,7 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
|
|||
input_size_list_.push_back(GetSize<S>(input_indices_shape_));
|
||||
if (input_shapes.size() == kEmbeddingLookupDynamicShapeInputsNum) {
|
||||
is_dynamic_shape_ = true;
|
||||
input_size_list_.push_back(sizeof(int64_t));
|
||||
input_size_list_.push_back(sizeof(G));
|
||||
}
|
||||
|
||||
output_size_list_.push_back(GetSize<T>(output_shape_));
|
||||
|
@ -118,15 +118,14 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
|
|||
return flag;
|
||||
}
|
||||
if (is_dynamic_shape_) {
|
||||
int64_t *input_offset_addr = nullptr;
|
||||
flag = GetDeviceAddress<int64_t>(input_ptrs, kIndex2, kernel_name_, &input_offset_addr);
|
||||
G *input_offset_addr = nullptr;
|
||||
flag = GetDeviceAddress<G>(input_ptrs, kIndex2, kernel_name_, &input_offset_addr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
|
||||
cudaMemcpyAsync(&offset_, input_offset_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream)),
|
||||
"cudaMemcpyAsync offset_ failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&offset_, input_offset_addr, sizeof(G), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream)),
|
||||
"cudaMemcpyAsync offset_ failed");
|
||||
}
|
||||
CalEmbeddingLookup(input_params_addr, input_indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1],
|
||||
dims_[kIndex2], input_dim1_, static_cast<int64_t>(offset_),
|
||||
|
@ -134,6 +133,11 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
|
|||
return 0;
|
||||
}
|
||||
|
||||
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
|
||||
attr_ptr_ = std::dynamic_pointer_cast<EmbeddingLookupAttr>(kernel_attr);
|
||||
offset_ = attr_ptr_->offset;
|
||||
}
|
||||
|
||||
void ResetResource() override {
|
||||
is_null_input_ = false;
|
||||
input_params_shape_.clear();
|
||||
|
@ -155,8 +159,9 @@ class EmbeddingLookupHelperGpuKernel : public GpuKernelHelperBase {
|
|||
int64_t input_dim1_;
|
||||
bool is_null_input_;
|
||||
size_t dims_[kIndex3] = {};
|
||||
int64_t offset_ = 0;
|
||||
G offset_ = 0;
|
||||
bool is_dynamic_shape_;
|
||||
std::shared_ptr<EmbeddingLookupAttr> attr_ptr_;
|
||||
};
|
||||
} // namespace cukernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue