diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.cc index 2b0720de312..3134bca5b46 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.cc @@ -22,6 +22,7 @@ #include "nnacl/gather_parameter.h" #include "nnacl/base/gather_base.h" #include "include/common/thread_pool.h" +#include "mindspore/core/ops/gather.h" namespace mindspore { namespace kernel { @@ -34,7 +35,12 @@ constexpr size_t kGatherInputParamsMaxDim = 7; } // namespace bool GatherCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); kernel_name_ = base_operator->name(); + auto kernel_ptr = std::dynamic_pointer_cast(base_operator); + MS_EXCEPTION_IF_NULL(kernel_ptr); + batch_dims_ = kernel_ptr->get_batch_dims(); + size_t input_num = inputs.size(); if (input_num != kGatherInputsNum) { MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherCPUKernel needs 2."; @@ -62,6 +68,9 @@ int GatherCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:: if (IsDynamic(input_shape_) || IsDynamic(indices_shape_) || IsDynamic(output_shape_)) { return KRET_UNKNOWN_SHAPE; } + if (batch_dims_ < 0) { + batch_dims_ += SizeToLong(input_shapes_.size()); + } is_null_input_ = input_shape_.empty() || indices_shape_.empty() || output_shape_.empty(); if (is_null_input_) { InitSizeLists(); @@ -94,31 +103,45 @@ bool GatherCpuKernelMod::LaunchKernel(const std::vector &inp axis_ = axis_ + dims; } - size_t outer_size = 1, inner_size = 1; - auto axis = static_cast(axis_); - for (size_t i = 0; i < axis; ++i) { + size_t batch_size = 1; + size_t outer_size = 1; + size_t indices_element_size = 1; + size_t inner_size = 1; + auto axis = LongToSize(axis_); + auto batch_dims = LongToSize(batch_dims_); + for (size_t i = 0; i < batch_dims; i++) { + batch_size *= LongToSize(input_shape_.at(i)); + } + for (size_t i = batch_dims; i < axis; ++i) { outer_size *= LongToSize(input_shape_.at(i)); } for (size_t i = axis + 1; i < input_shape_.size(); ++i) { inner_size *= LongToSize(input_shape_.at(i)); } - size_t indices_element_size = 1; - for (size_t i = 0; i < indices_shape_.size(); i++) { + for (size_t i = batch_dims; i < indices_shape_.size(); i++) { indices_element_size *= LongToSize(indices_shape_.at(i)); } + auto limit = LongToSize(input_shape_.at(axis)); size_t byte_inner_size = inner_size * sizeof(T); size_t byte_out_stride = indices_element_size * byte_inner_size; - auto task = [&](size_t start, size_t end) { - int count = SizeToInt(end - start); - const int8_t *in = input_tensor + start * limit * byte_inner_size; - int8_t *out = output_addr + start * byte_out_stride; - int ret = Gather(in, count, byte_inner_size, limit, indices_data, indices_element_size, out, byte_out_stride); - if (ret != 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]"; - } - }; - ParallelLaunchAutoSearch(task, outer_size, this, ¶llel_search_info_); + for (size_t i = 0; i < batch_size; i++) { + auto output_ptr = output_addr + i * outer_size * byte_out_stride; + auto input_ptr = input_tensor + i * outer_size * byte_inner_size * limit; + auto indice_ptr = indices_data + i * indices_element_size; + + auto task = [&](size_t start, size_t end) { + int count = SizeToInt(end - start); + const int8_t *in = input_ptr + start * limit * byte_inner_size; + int8_t *out = output_ptr + start * byte_out_stride; + int ret = Gather(in, count, byte_inner_size, limit, indice_ptr, indices_element_size, out, byte_out_stride); + if (ret != 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]"; + } + }; + ParallelLaunchAutoSearch(task, outer_size, this, ¶llel_search_info_); + } + return true; } diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.h index f39b3ba4ee2..77ce415e532 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/gather_cpu_kernel.h @@ -79,6 +79,7 @@ class GatherCpuKernelMod : public NativeCpuKernelMod { ShapeVector indices_shape_; ShapeVector output_shape_; int64_t axis_{0}; + int64_t batch_dims_{0}; size_t input_type_size_ = 0; size_t indices_type_size_ = 0; size_t axis_type_size_ = 0; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc index ab0024d08b1..77c893afc4d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc @@ -17,6 +17,7 @@ #include "plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.h" #include #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "mindspore/core/ops/gather.h" namespace mindspore { namespace kernel { @@ -25,6 +26,12 @@ bool GatherV2FwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s const std::vector &outputs) { MS_EXCEPTION_IF_NULL(base_operator); kernel_name_ = base_operator->name(); + if (kernel_name_ == ops::kNameGather) { + auto kernel_ptr = std::dynamic_pointer_cast(base_operator); + MS_EXCEPTION_IF_NULL(kernel_ptr); + batch_dims_ = kernel_ptr->get_batch_dims(); + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); if (!is_match) { @@ -57,6 +64,9 @@ int GatherV2FwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const if (IsDynamic(input_shapes_) || IsDynamic(indices_shapes_) || IsDynamic(output_shapes_)) { return KRET_UNKNOWN_SHAPE; } + if (batch_dims_ < 0) { + batch_dims_ += SizeToLong(input_shapes_.size()); + } is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") || CHECK_SHAPE_NULL(indices_shapes_, kernel_name_, "indices") || CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output"); @@ -100,7 +110,7 @@ bool GatherV2FwdGpuKernelMod::LaunchKernel(const std::vector &inputs MS_EXCEPTION_IF_NULL(input_addr); MS_EXCEPTION_IF_NULL(indices_addr); - GatherV2(input_addr, indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2], + GatherV2(input_addr, indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2], dims_[kIndex3], LongToSize(input_dim1), reinterpret_cast(stream_ptr)); return true; } diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.h index ea7196c3ca6..52f761ce970 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.h @@ -81,21 +81,27 @@ class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod { if (axis_ < 0) { axis_ = axis_ + SizeToInt(input_shapes_.size()); } - int64_t dim_before_axis = 1; - for (size_t i = 0; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) { - dim_before_axis *= output_shapes_[i]; + size_t batch_size = 1; + size_t batch_dims = LongToSize(batch_dims_); + for (size_t i = 0; i < batch_dims; i++) { + batch_size *= LongToSize(input_shapes_[i]); } - int64_t dim_of_indices = 1; - for (size_t i = 0; i < indices_shapes_.size(); i++) { - dim_of_indices *= indices_shapes_[i]; + size_t dim_before_axis = 1; + for (size_t i = batch_dims; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) { + dim_before_axis *= LongToSize(output_shapes_[i]); } - int64_t dim_after_indices = 1; - for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) { - dim_after_indices *= output_shapes_[i]; + size_t dim_of_indices = 1; + for (size_t i = batch_dims; i < indices_shapes_.size(); i++) { + dim_of_indices *= LongToSize(indices_shapes_[i]); } - dims_[kIndex0] = dim_before_axis; - dims_[kIndex1] = dim_of_indices; - dims_[kIndex2] = dim_after_indices; + size_t dim_after_indices = 1; + for (size_t i = IntToSize(axis_) + 1; i < input_shapes_.size(); i++) { + dim_after_indices *= LongToSize(input_shapes_[i]); + } + dims_[kIndex0] = batch_size; + dims_[kIndex1] = dim_before_axis; + dims_[kIndex2] = dim_of_indices; + dims_[kIndex3] = dim_after_indices; return; } @@ -108,8 +114,9 @@ class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod { std::vector input_shapes_; std::vector indices_shapes_; std::vector output_shapes_; - int64_t dims_[kIndex3] = {}; + size_t dims_[kIndex4] = {}; int64_t axis_ = 0; + int64_t batch_dims_{0}; bool is_null_input_ = false; size_t input_type_size_ = 0; size_t indices_type_size_ = 0; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cu index 39fe6026cae..7609fb7e9c0 100755 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cu @@ -44,101 +44,129 @@ __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di return; } template -void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream) { - size_t size = output_dim0 * output_dim1 * output_dim2; - GatherV2Kernel<<>>(input, indices, output, output_dim0, output_dim1, - output_dim2, input_dim1); +__global__ void GatherV2WithBatchDimsKernel(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1) { + size_t num = batch_size * output_dim0 * output_dim1 * output_dim2; + size_t i, j, k, n; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; + write_index += blockDim.x * gridDim.x) { + i = write_index / (output_dim0 * output_dim1 * output_dim2) % batch_size; + j = i * output_dim1 + write_index / output_dim2 % output_dim1; + n = write_index / (output_dim1 * output_dim2) % output_dim0; + k = write_index % output_dim2; + + if ((indices[j] >= 0) && (indices[j] < input_dim1)) { + size_t read_index = + i * output_dim0 * input_dim1 * output_dim2 + n * input_dim1 * output_dim2 + indices[j] * output_dim2 + k; + output[write_index] = input[read_index]; + } else { + output[write_index] = 0; + } + } + return; } -template CUDA_LIB_EXPORT void GatherV2, int>(Complex *input, int *indices, - Complex *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); +template +void GatherV2(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream) { + size_t size = batch_size * output_dim0 * output_dim1 * output_dim2; + if (batch_size > 1) { + GatherV2WithBatchDimsKernel<<>>( + input, indices, output, batch_size, output_dim0, output_dim1, output_dim2, input_dim1); + } else { + GatherV2Kernel<<>>(input, indices, output, output_dim0, output_dim1, + output_dim2, input_dim1); + } + return; +} + +template CUDA_LIB_EXPORT void GatherV2, int>(Complex *input, int *indices, Complex *output, + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2, int64_t>(Complex *input, int64_t *indices, - Complex *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); + Complex *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, + cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2, int>(Complex *input, int *indices, - Complex *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); + Complex *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2, int64_t>(Complex *input, int64_t *indices, - Complex *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(float *input, int *indices, float *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(float *input, int64_t *indices, float *output, + Complex *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, + cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(float *input, int *indices, float *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(float *input, int64_t *indices, float *output, size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(half *input, int *indices, half *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(half *input, int64_t *indices, half *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(double *input, int *indices, double *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(half *input, int *indices, half *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(half *input, int64_t *indices, half *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(double *input, int *indices, double *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2(double *input, int64_t *indices, double *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(int64_t *input, int *indices, int64_t *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(int64_t *input, int *indices, int64_t *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2(int64_t *input, int64_t *indices, int64_t *output, - size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(int *input, int *indices, int *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(int *input, int64_t *indices, int *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(int16_t *input, int *indices, int16_t *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(int *input, int *indices, int *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(int *input, int64_t *indices, int *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(int16_t *input, int *indices, int16_t *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2(int16_t *input, int64_t *indices, int16_t *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(int8_t *input, int *indices, int8_t *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(int8_t *input, int *indices, int8_t *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2(int8_t *input, int64_t *indices, int8_t *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2(uint64_t *input, int *indices, uint64_t *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2(uint64_t *input, int64_t *indices, uint64_t *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); template CUDA_LIB_EXPORT void GatherV2(uint32_t *input, int *indices, uint32_t *output, + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(uint32_t *input, int64_t *indices, uint32_t *output, + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(uint16_t *input, int *indices, uint16_t *output, + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(uint16_t *input, int64_t *indices, uint16_t *output, + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(uint8_t *input, int *indices, uint8_t *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(uint8_t *input, int64_t *indices, uint8_t *output, + size_t batch_size, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(bool *input, int *indices, bool *output, size_t batch_size, + size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); +template CUDA_LIB_EXPORT void GatherV2(bool *input, int64_t *indices, bool *output, size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(uint32_t *input, int64_t *indices, uint32_t *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(uint16_t *input, int *indices, uint16_t *output, - size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(uint16_t *input, int64_t *indices, uint16_t *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(uint8_t *input, int64_t *indices, uint8_t *output, - size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(bool *input, int *indices, bool *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); -template CUDA_LIB_EXPORT void GatherV2(bool *input, int64_t *indices, bool *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cuh index aa82f2f74b2..7ff96f08311 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cuh @@ -18,8 +18,8 @@ #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GATHERV2_CUH_ #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" template -CUDA_LIB_EXPORT void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1, cudaStream_t stream); +CUDA_LIB_EXPORT void GatherV2(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); template __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc index 425a55bda79..33409dc9cd6 100644 --- a/mindspore/core/ops/gather.cc +++ b/mindspore/core/ops/gather.cc @@ -45,6 +45,56 @@ namespace mindspore { namespace ops { +constexpr auto kBatchDims = "batch_dims"; + +void Gather::set_batch_dims(int64_t batch_dims) { (void)this->AddAttr(kBatchDims, api::MakeValue(batch_dims)); } + +int64_t Gather::get_batch_dims() const { return GetValue(GetAttr(kBatchDims)); } + +void CheckBatchDims(int64_t batch_dims, int64_t axis_val, const ShapeVector ¶ms_shp, const ShapeVector &indices_shp, + const std::string &op_name) { + int64_t params_rank = static_cast(params_shp.size()); + int64_t indices_rank = static_cast(indices_shp.size()); + if (batch_dims < -indices_rank || batch_dims > indices_rank) { + MS_LOG(EXCEPTION) << "For '" << op_name << "', batch_dims must be in [" << -indices_rank << ", " << indices_rank + << "], but got batch_dims: " << batch_dims; + } + if (batch_dims < 0) { + batch_dims += indices_rank; + } + if (batch_dims > params_rank) { + MS_LOG(EXCEPTION) << "For '" << op_name + << "', batch_dims must be less than params's rank, but got batch_dims: " << batch_dims + << ", oarams's rank: " << params_rank; + } + if (axis_val < batch_dims) { + MS_LOG(EXCEPTION) << "For '" << op_name + << "', batch_dims must be less than or equal to axis, but got batch_dims: " << batch_dims + << ", axis: " << axis_val; + } + for (size_t i = 0; i < LongToSize(batch_dims); i++) { + if (params_shp[i] != indices_shp[i]) { + MS_LOG(EXCEPTION) << "For '" << op_name << "', params.shape[" << i << "], should be equal to indices.shape[" << i + << "], but got param.shape: " << params_shp << ", indices.shape: " << indices_shp; + } + } +} + +ShapeVector CalcuateGatherWithBatchDimsOutputShape(int64_t batch_dims, int64_t axis_val, const ShapeVector &ind_vec, + const ShapeVector ¶ms_vec) { + ShapeVector out_vec; + for (size_t i = 0; i < LongToSize(axis_val); i++) { + out_vec.push_back(params_vec[i]); + } + for (size_t i = LongToSize(batch_dims); i < ind_vec.size(); i++) { + out_vec.push_back(ind_vec[i]); + } + for (size_t i = LongToSize(axis_val) + 1; i < params_vec.size(); i++) { + out_vec.push_back(params_vec[i]); + } + return out_vec; +} + abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); const std::string &op_name = primitive->name(); @@ -108,6 +158,13 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve if (axis_val < 0) { axis_val += params_rank; } + if (op_name == kNameGather) { + int64_t batch_dims = GetValue(primitive->GetAttr(kBatchDims)); + CheckBatchDims(batch_dims, axis_val, params_shp, indices_shp, op_name); + out_shape = CalcuateGatherWithBatchDimsOutputShape(batch_dims, axis_val, indices_shp, params_shp); + return std::make_shared(out_shape); + } + auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector ¶ms_vec) -> ShapeVector { ShapeVector out_vec; (void)std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); @@ -142,9 +199,11 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive MS_EXCEPTION_IF_NULL(primitive); const int64_t kInputsNum = 3; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name()); + if (!primitive->HasAttr(kBatchDims)) { + (void)primitive->AddAttr("batch_dims", MakeValue(static_cast(0))); // Add temporarily for ascend. + } auto infer_type = GatherInferType(primitive, input_args); auto infer_shape = GatherInferShape(primitive, input_args); - (void)primitive->AddAttr("batch_dims", MakeValue(static_cast(0))); // Add temporarily for gatherv2 on ascend return abstract::MakeAbstract(infer_shape, infer_type); } diff --git a/mindspore/core/ops/gather.h b/mindspore/core/ops/gather.h index 9e04a7e2a39..c0147dbff46 100644 --- a/mindspore/core/ops/gather.h +++ b/mindspore/core/ops/gather.h @@ -35,6 +35,8 @@ class MIND_API Gather : public BaseOperator { Gather() : BaseOperator(kNameGather) { InitIOName({"param", "indices", "axis"}, {"output"}); } /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Gather for the inputs. void Init() const {} + void set_batch_dims(int64_t batch_dims); + int64_t get_batch_dims() const; }; MIND_API abstract::AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index e5d6151854f..b9ff2bd65bc 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -3606,12 +3606,12 @@ def gather_nd(input_x, indices): return F.gather_nd(input_x, indices) -def gather(input_x, input_indices, axis): +def gather(input_x, input_indices, axis, batch_dims=0): r""" Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`. Refer to :func:`mindspore.ops.gather` for more detail. """ - return F.gather(input_x, input_indices, axis) + return F.gather(input_x, input_indices, axis, batch_dims) def split(x, split_size_or_sections, axis=0): diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index bd2d5edab36..676ab3b5738 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -2821,13 +2821,14 @@ class Tensor(Tensor_, metaclass=_TensorMeta): validator.check_value_type('indices', indices, (Tensor, Tensor_,), 'Tensor.gather_nd') return tensor_operator_registry.get('gather_nd')(self, indices) - def gather(self, input_indices, axis): + def gather(self, input_indices, axis, batch_dims=0): r""" For details, please refer to :func:`mindspore.ops.gather`. """ self._init_check() validator.check_is_int(axis, 'axis') - return tensor_operator_registry.get('gather')(self, input_indices, axis) + validator.check_is_int(batch_dims, "batch_dims") + return tensor_operator_registry.get('gather')(self, input_indices, axis, batch_dims) def var(self, axis=None, ddof=0, keepdims=False): """ diff --git a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py index 1dfdaf90c80..7bd3a4a2ee0 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py @@ -413,12 +413,12 @@ def get_bprop_slice(self): @_primexpr -def _generate_inverse_index(x_shape, axis): +def _generate_inverse_index(x_shape, axis, batch_dims=0): x_rank = len(x_shape) index = tuple(range(x_rank)) if axis < 0: axis += x_rank - perm = index[1:1 + axis] + (0,) + index[1 + axis:] + perm = index[:batch_dims] + index[batch_dims + 1:1 + axis] + (index[batch_dims],) + index[1 + axis:] return perm @@ -440,33 +440,57 @@ def _dyn_regenerate_output_shape(x_shp, ind_shp, axis): return out_shape -def _dyn_generate_shape_index(out_shape, indices_shape, axis): +def _dyn_generate_shape_index(out_shape, indices_shape, axis, batch_dims=0): """Get tranpose order""" out_rank = F.reshape(dyn_shape_op(out_shape), ()) ind_rank = F.reshape(dyn_shape_op(indices_shape), ()) if axis < 0: axis += out_rank - ind_rank + 1 perm_part1 = P.Range()(F.cast(0, mstype.int32), F.cast(20, mstype.int32), F.cast(1, mstype.int32)) - perm_part1 = perm_part1[axis: axis + ind_rank] + ind_end = axis + ind_rank - batch_dims + perm_part1 = perm_part1[axis: ind_end] index = P.Range()(F.cast(0, mstype.int32), F.cast(out_rank, mstype.int32), F.cast(1, mstype.int32)) - perm = P.Concat(0)((perm_part1, index[:axis], index[axis + ind_rank:])) + perm = F.hstack((index[:batch_dims], perm_part1, index[batch_dims:axis], index[ind_end:])) return perm -def _dyn_generate_inverse_index(x_shp, axis): +def _dyn_generate_inverse_index(x_shp, axis, batch_dims=0): """Get tranpose order""" x_rank = F.reshape(dyn_shape_op(x_shp), ()) index = P.Range()(F.cast(0, mstype.int32), F.cast(x_rank, mstype.int32), F.cast(1, mstype.int32)) if axis < 0: axis += x_rank - perm = P.Concat(0)((index[1: 1 + axis], Tensor([0], dtype=mstype.int32), index[1 + axis:])) + perm = F.hstack((index[:batch_dims], index[batch_dims + 1:1 + axis], index[batch_dims], index[1 + axis:])) return perm +def calculate_batch_gather(values, indices, x_shape, axis, batch_dims): + """Calculate gather grad with batch_dims""" + values_shape = dyn_shape_op(values) + batch_size = F.prod(x_shape[:batch_dims]) + batch_size = F.cast(batch_size, mstype.int32) + axis_dim = F.cast(x_shape[axis], mstype.int32) + + # Move batch dimension to first non-batch dimension + values = values.reshape((-1,) + values.shape[batch_dims:]) + indices = indices.reshape((-1,) + indices.shape[batch_dims:]) + offset = P.Range()(F.cast(0, mstype.int32), batch_size * axis_dim, axis_dim) + offset_shape = F.hstack([batch_size] + [Tensor(1, dtype=mstype.int32) for _ in range(len(indices.shape) - 1)]) + offset = reshape(offset, offset_shape) + indices = indices + offset + num_segments = batch_size * axis_dim + params_grad = unsorted_segment_sum(values, indices, num_segments) + grad_shape = dyn_shape_op(params_grad) + ret_shape = F.hstack([values_shape[:batch_dims], F.cast(axis_dim, mstype.int64), grad_shape[1:]]) + params_grad = reshape(params_grad, ret_shape) + return params_grad + + @bprop_getters.register(P.Gather) @bprop_getters.register(P.GatherV2) def get_bprop_gather_v2(self): """Generate bprop for GatherV2""" + batch_dims = self.batch_dims def _dyn_bprop_gather_v2(x, indices, axis, dout): """dyn shape bprop for GatherV2""" @@ -483,10 +507,13 @@ def get_bprop_gather_v2(self): dout = reshape(dout, out_shp) # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) - perm_1 = _dyn_generate_shape_index(out_shp, ind_shp, axis) + perm_1 = _dyn_generate_shape_index(out_shp, ind_shp, axis, batch_dims) values_transpose = transpose(dout, perm_1) - params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis]) - perm_2 = _dyn_generate_inverse_index(x_shp, axis) + if batch_dims > 0: + params_grad = calculate_batch_gather(values_transpose, indices, x_shp, axis, batch_dims) + else: + params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis]) + perm_2 = _dyn_generate_inverse_index(x_shp, axis, batch_dims) params_grad = transpose(params_grad, perm_2) return params_grad, zeros_like(orig_indices), zeros_like(axis) @@ -509,14 +536,15 @@ def get_bprop_gather_v2(self): out_shp = shape_op(dout) ind_shp = shape_op(indices) # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) - perm_1 = generate_shape_index(out_shp, ind_shp, axis) + perm_1 = generate_shape_index(out_shp, ind_shp, axis, batch_dims) values_transpose = transpose(dout, perm_1) - if F.is_sequence_value_unknown(shape_op(x)): - params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) + dyn_x_sape = dyn_shape_op(x) + if batch_dims > 0: + params_grad = calculate_batch_gather(values_transpose, indices, dyn_x_sape, axis, batch_dims) else: - params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) + params_grad = unsorted_segment_sum(values_transpose, indices, dyn_x_sape[axis]) # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) - perm_2 = _generate_inverse_index(x_shp, axis) + perm_2 = _generate_inverse_index(x_shp, axis, batch_dims) params_grad = transpose(params_grad, perm_2) return params_grad, zeros_like(orig_indices), zeros_like(axis) diff --git a/mindspore/python/mindspore/ops/_utils/utils.py b/mindspore/python/mindspore/ops/_utils/utils.py index fd4658d1211..81d9b6f4e75 100644 --- a/mindspore/python/mindspore/ops/_utils/utils.py +++ b/mindspore/python/mindspore/ops/_utils/utils.py @@ -127,12 +127,12 @@ def get_1d_shape(in_shape): @_primexpr -def generate_shape_index(out_shape, indices_shape, axis): +def generate_shape_index(out_shape, indices_shape, axis, batch_dims=0): out_rank = len(out_shape) ind_rank = len(indices_shape) if axis < 0: axis += out_rank - ind_rank + 1 - perm_part1 = tuple(range(axis, axis + ind_rank)) + perm_part1 = tuple(range(axis, axis + ind_rank - batch_dims)) index = tuple(range(out_rank)) - perm = perm_part1 + index[:axis] + index[axis + ind_rank:] + perm = index[:batch_dims] + perm_part1 + index[batch_dims:axis] + index[axis + ind_rank - batch_dims:] return perm diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Gather_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Gather_bprop.mindir index fb185f766a7..0c7d7258b76 100644 Binary files a/mindspore/python/mindspore/ops/bprop_mindir/Gather_bprop.mindir and b/mindspore/python/mindspore/ops/bprop_mindir/Gather_bprop.mindir differ diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 71444e88e18..ac9f4b10b2f 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -3159,7 +3159,7 @@ def argsort(input_x, axis=-1, descending=False): return arg_sort -def gather(input_params, input_indices, axis): +def gather(input_params, input_indices, axis, batch_dims=0): r""" Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`. @@ -3231,7 +3231,8 @@ def gather(input_params, input_indices, axis): [5. 7.] [9. 11.]] """ - return gather_(input_params, input_indices, axis) + _gather = _get_cache_prim(P.Gather)(batch_dims) + return _gather(input_params, input_indices, axis) def gather_d(x, dim, index): diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 12ff38a0745..608d3691a0d 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -990,9 +990,10 @@ class Gather(Primitive): """ @prim_attr_register - def __init__(self): + def __init__(self, batch_dims=0): """Initialize Gather""" - self.add_prim_attr("batch_dims", 0) + validator.check_value_type("batch_dims", batch_dims, [int], self.name) + self.add_prim_attr("batch_dims", batch_dims) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) @@ -1006,6 +1007,7 @@ class GatherV2(PrimitiveWithCheck): @prim_attr_register def __init__(self): """Initialize GatherV2""" + self.add_prim_attr("batch_dims", 0) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) def __check__(self, params, indices, axis): @@ -5999,9 +6001,12 @@ class Range(PrimitiveWithCheck): self.add_prim_attr('maxlen', maxlen) def check_shape(self, start_shape, limit_shape, delta_shape): - validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name) - validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name) - validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name) + if not is_shape_unknown(start_shape): + validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name) + if not is_shape_unknown(limit_shape): + validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name) + if not is_shape_unknown(delta_shape): + validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name) def check_dtype(self, start_dtype, limit_dtype, delta_dtype): valid_dtypes = [mstype.int32, mstype.float32, mstype.int64, mstype.float64] diff --git a/tests/st/ops/cpu/test_gather_op.py b/tests/st/ops/cpu/test_gather_op.py index 445a068bb9d..7420f1d0764 100644 --- a/tests/st/ops/cpu/test_gather_op.py +++ b/tests/st/ops/cpu/test_gather_op.py @@ -315,3 +315,23 @@ def test_gather_tensor(data_type): assert out.shape == y_expect.shape np.allclose(out.asnumpy(), y_expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gather_batch_dims(): + """ + Feature: Gather + Description: test cases for Gather with batch_dims + Expectation: the result match to numpy + """ + x = np.arange(27).reshape(3, 3, 3).astype(np.int32) + indices = np.array([[0, 0], [1, 1], [1, 1]]).astype(np.int32) + axis = 1 + batch_dims = 1 + out = P.Gather(batch_dims)(Tensor(x), Tensor(indices), axis) + expect = np.array([[[0, 1, 2], [0, 1, 2]], + [[12, 13, 14], [12, 13, 14]], + [[21, 22, 23], [21, 22, 23]]]).astype(np.int32) + np.allclose(out.asnumpy(), expect)