!49175 add gather batch_dims

Merge pull request !49175 from 范吉斌/gather_batch_dims
i-robot 2023-03-06 02:58:11 +00:00 committed by Gitee
commit a9de0b252b
16 changed files with 327 additions and 142 deletions

View File

@ -22,6 +22,7 @@
#include "nnacl/gather_parameter.h"
#include "nnacl/base/gather_base.h"
#include "include/common/thread_pool.h"
#include "mindspore/core/ops/gather.h"
namespace mindspore {
namespace kernel {
@ -34,7 +35,12 @@ constexpr size_t kGatherInputParamsMaxDim = 7;
} // namespace
bool GatherCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::Gather>(base_operator);
MS_EXCEPTION_IF_NULL(kernel_ptr);
batch_dims_ = kernel_ptr->get_batch_dims();
size_t input_num = inputs.size();
if (input_num != kGatherInputsNum) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherCPUKernel needs 2.";
@ -62,6 +68,9 @@ int GatherCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
if (IsDynamic(input_shape_) || IsDynamic(indices_shape_) || IsDynamic(output_shape_)) {
return KRET_UNKNOWN_SHAPE;
}
if (batch_dims_ < 0) {
batch_dims_ += SizeToLong(input_shape_.size());
}
is_null_input_ = input_shape_.empty() || indices_shape_.empty() || output_shape_.empty();
if (is_null_input_) {
InitSizeLists();
@ -94,31 +103,45 @@ bool GatherCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
axis_ = axis_ + dims;
}
size_t batch_size = 1;
size_t outer_size = 1;
size_t indices_element_size = 1;
size_t inner_size = 1;
auto axis = LongToSize(axis_);
auto batch_dims = LongToSize(batch_dims_);
for (size_t i = 0; i < batch_dims; i++) {
  batch_size *= LongToSize(input_shape_.at(i));
}
for (size_t i = batch_dims; i < axis; ++i) {
  outer_size *= LongToSize(input_shape_.at(i));
}
for (size_t i = axis + 1; i < input_shape_.size(); ++i) {
  inner_size *= LongToSize(input_shape_.at(i));
}
for (size_t i = batch_dims; i < indices_shape_.size(); i++) {
  indices_element_size *= LongToSize(indices_shape_.at(i));
}
auto limit = LongToSize(input_shape_.at(axis));
size_t byte_inner_size = inner_size * sizeof(T);
size_t byte_out_stride = indices_element_size * byte_inner_size;
for (size_t i = 0; i < batch_size; i++) {
  auto output_ptr = output_addr + i * outer_size * byte_out_stride;
  auto input_ptr = input_tensor + i * outer_size * byte_inner_size * limit;
  auto indice_ptr = indices_data + i * indices_element_size;
  auto task = [&](size_t start, size_t end) {
    int count = SizeToInt(end - start);
    const int8_t *in = input_ptr + start * limit * byte_inner_size;
    int8_t *out = output_ptr + start * byte_out_stride;
    int ret = Gather(in, count, byte_inner_size, limit, indice_ptr, indices_element_size, out, byte_out_stride);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
    }
  };
  ParallelLaunchAutoSearch(task, outer_size, this, &parallel_search_info_);
}
return true;
}
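A minimal NumPy sketch of the batched gather this loop implements (an illustration, not the kernel; `gather_with_batch_dims` is a hypothetical helper, and shapes are assumed to satisfy params.shape[:batch_dims] == indices.shape[:batch_dims]): each batch slice gathers along `axis` with its own index slice, and the output shape is params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis+1:].

import numpy as np

def gather_with_batch_dims(params, indices, axis, batch_dims):
    # Fold all batch dims into one leading dim, then gather per batch.
    out_shape = params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]
    p = params.reshape((-1,) + params.shape[batch_dims:])
    i = indices.reshape((-1,) + indices.shape[batch_dims:])
    rows = [np.take(p[b], i[b], axis=axis - batch_dims) for b in range(p.shape[0])]
    return np.stack(rows).reshape(out_shape)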

View File

@ -79,6 +79,7 @@ class GatherCpuKernelMod : public NativeCpuKernelMod {
ShapeVector indices_shape_;
ShapeVector output_shape_;
int64_t axis_{0};
int64_t batch_dims_{0};
size_t input_type_size_ = 0;
size_t indices_type_size_ = 0;
size_t axis_type_size_ = 0;

View File

@ -17,6 +17,7 @@
#include "plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.h"
#include <memory>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "mindspore/core/ops/gather.h"
namespace mindspore {
namespace kernel {
@ -25,6 +26,12 @@ bool GatherV2FwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
if (kernel_name_ == ops::kNameGather) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::Gather>(base_operator);
MS_EXCEPTION_IF_NULL(kernel_ptr);
batch_dims_ = kernel_ptr->get_batch_dims();
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
@ -57,6 +64,9 @@ int GatherV2FwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
if (IsDynamic(input_shapes_) || IsDynamic(indices_shapes_) || IsDynamic(output_shapes_)) {
return KRET_UNKNOWN_SHAPE;
}
if (batch_dims_ < 0) {
batch_dims_ += SizeToLong(input_shapes_.size());
}
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
CHECK_SHAPE_NULL(indices_shapes_, kernel_name_, "indices") ||
CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");
@ -100,7 +110,7 @@ bool GatherV2FwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
MS_EXCEPTION_IF_NULL(input_addr);
MS_EXCEPTION_IF_NULL(indices_addr);
GatherV2(input_addr, indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2], dims_[kIndex3],
LongToSize(input_dim1), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

View File

@ -81,21 +81,27 @@ class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod {
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_shapes_.size());
}
size_t batch_size = 1;
size_t batch_dims = LongToSize(batch_dims_);
for (size_t i = 0; i < batch_dims; i++) {
  batch_size *= LongToSize(input_shapes_[i]);
}
size_t dim_before_axis = 1;
for (size_t i = batch_dims; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) {
  dim_before_axis *= LongToSize(output_shapes_[i]);
}
size_t dim_of_indices = 1;
for (size_t i = batch_dims; i < indices_shapes_.size(); i++) {
  dim_of_indices *= LongToSize(indices_shapes_[i]);
}
size_t dim_after_indices = 1;
for (size_t i = IntToSize(axis_) + 1; i < input_shapes_.size(); i++) {
  dim_after_indices *= LongToSize(input_shapes_[i]);
}
dims_[kIndex0] = batch_size;
dims_[kIndex1] = dim_before_axis;
dims_[kIndex2] = dim_of_indices;
dims_[kIndex3] = dim_after_indices;
return;
}
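A worked example of the four dims_ entries above, under hypothetical shapes input (2, 5, 8, 4), indices (2, 3), axis = 2, batch_dims = 1, where the output is (2, 5, 3, 4):

from math import prod
input_shape, indices_shape, output_shape = (2, 5, 8, 4), (2, 3), (2, 5, 3, 4)
axis, batch_dims = 2, 1
batch_size = prod(input_shape[:batch_dims])            # 2
dim_before_axis = prod(output_shape[batch_dims:axis])  # 5
dim_of_indices = prod(indices_shape[batch_dims:])      # 3
dim_after_indices = prod(input_shape[axis + 1:])       # 4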
@ -108,8 +114,9 @@ class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod {
std::vector<int64_t> input_shapes_;
std::vector<int64_t> indices_shapes_;
std::vector<int64_t> output_shapes_;
size_t dims_[kIndex4] = {};
int64_t axis_ = 0;
int64_t batch_dims_{0};
bool is_null_input_ = false;
size_t input_type_size_ = 0;
size_t indices_type_size_ = 0;

View File

@ -44,101 +44,129 @@ __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di
return;
}
template <typename T, typename S>
__global__ void GatherV2WithBatchDimsKernel(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0,
                                            size_t output_dim1, size_t output_dim2, size_t input_dim1) {
  size_t num = batch_size * output_dim0 * output_dim1 * output_dim2;
  size_t i, j, k, n;
  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
       write_index += blockDim.x * gridDim.x) {
    i = write_index / (output_dim0 * output_dim1 * output_dim2) % batch_size;
    j = i * output_dim1 + write_index / output_dim2 % output_dim1;
    n = write_index / (output_dim1 * output_dim2) % output_dim0;
    k = write_index % output_dim2;
    if ((indices[j] >= 0) && (indices[j] < input_dim1)) {
      size_t read_index =
        i * output_dim0 * input_dim1 * output_dim2 + n * input_dim1 * output_dim2 + indices[j] * output_dim2 + k;
      output[write_index] = input[read_index];
    } else {
      output[write_index] = 0;
    }
  }
  return;
}
template <typename T, typename S>
void GatherV2(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0, size_t output_dim1,
              size_t output_dim2, size_t input_dim1, cudaStream_t stream) {
  size_t size = batch_size * output_dim0 * output_dim1 * output_dim2;
  if (batch_size > 1) {
    GatherV2WithBatchDimsKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(
      input, indices, output, batch_size, output_dim0, output_dim1, output_dim2, input_dim1);
  } else {
    GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
                                                                 output_dim2, input_dim1);
  }
  return;
}
template CUDA_LIB_EXPORT void GatherV2<Complex<float>, int>(Complex<float> *input, int *indices, Complex<float> *output,
                                                            size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                            size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<Complex<float>, int64_t>(Complex<float> *input, int64_t *indices,
                                                                Complex<float> *output, size_t batch_size,
                                                                size_t output_dim0, size_t output_dim1,
                                                                size_t output_dim2, size_t input_dim1,
                                                                cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<Complex<double>, int>(Complex<double> *input, int *indices,
                                                             Complex<double> *output, size_t batch_size,
                                                             size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                             size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<Complex<double>, int64_t>(Complex<double> *input, int64_t *indices,
                                                                 Complex<double> *output, size_t batch_size,
                                                                 size_t output_dim0, size_t output_dim1,
                                                                 size_t output_dim2, size_t input_dim1,
                                                                 cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<float, int>(float *input, int *indices, float *output, size_t batch_size,
                                                   size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                   size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<float, int64_t>(float *input, int64_t *indices, float *output, size_t batch_size,
                                                       size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                       size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<half, int>(half *input, int *indices, half *output, size_t batch_size,
                                                  size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                  size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<half, int64_t>(half *input, int64_t *indices, half *output, size_t batch_size,
                                                      size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                      size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<double, int>(double *input, int *indices, double *output, size_t batch_size,
                                                    size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                    size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<double, int64_t>(double *input, int64_t *indices, double *output,
                                                        size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                        size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int64_t, int>(int64_t *input, int *indices, int64_t *output, size_t batch_size,
                                                     size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                     size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int64_t, int64_t>(int64_t *input, int64_t *indices, int64_t *output,
                                                         size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                         size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int, int>(int *input, int *indices, int *output, size_t batch_size,
                                                 size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                 size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int, int64_t>(int *input, int64_t *indices, int *output, size_t batch_size,
                                                     size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                     size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int16_t, int>(int16_t *input, int *indices, int16_t *output, size_t batch_size,
                                                     size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                     size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int16_t, int64_t>(int16_t *input, int64_t *indices, int16_t *output,
                                                         size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                         size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int8_t, int>(int8_t *input, int *indices, int8_t *output, size_t batch_size,
                                                    size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                    size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output,
                                                        size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                        size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint64_t, int>(uint64_t *input, int *indices, uint64_t *output,
                                                      size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                      size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint64_t, int64_t>(uint64_t *input, int64_t *indices, uint64_t *output,
                                                          size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                          size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint32_t, int>(uint32_t *input, int *indices, uint32_t *output,
                                                      size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                      size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint32_t, int64_t>(uint32_t *input, int64_t *indices, uint32_t *output,
                                                          size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                          size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint16_t, int>(uint16_t *input, int *indices, uint16_t *output,
                                                      size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                      size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint16_t, int64_t>(uint16_t *input, int64_t *indices, uint16_t *output,
                                                          size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                          size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t batch_size,
                                                     size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                     size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output,
                                                         size_t batch_size, size_t output_dim0, size_t output_dim1,
                                                         size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<bool, int>(bool *input, int *indices, bool *output, size_t batch_size,
                                                  size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                  size_t input_dim1, cudaStream_t stream);
template CUDA_LIB_EXPORT void GatherV2<bool, int64_t>(bool *input, int64_t *indices, bool *output, size_t batch_size,
                                                      size_t output_dim0, size_t output_dim1, size_t output_dim2,
                                                      size_t input_dim1, cudaStream_t stream);
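The flat-index arithmetic in GatherV2WithBatchDimsKernel can be sanity-checked on the host; a Python sketch (not the kernel itself; `read_index_of` is a hypothetical helper) of the write-to-read mapping, with indices already flattened to batch_size * output_dim1 entries:

def read_index_of(write_index, batch_size, d0, d1, d2, input_dim1, indices):
    i = write_index // (d0 * d1 * d2) % batch_size  # batch index
    j = i * d1 + write_index // d2 % d1             # row in the flattened indices
    n = write_index // (d1 * d2) % d0               # position before the axis
    k = write_index % d2                            # position after the axis
    return ((i * d0 + n) * input_dim1 + indices[j]) * d2 + k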

View File

@ -18,8 +18,8 @@
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GATHERV2_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void GatherV2(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0,
                              size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template <typename T, typename S>
__global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,

View File

@ -45,6 +45,56 @@
namespace mindspore {
namespace ops {
constexpr auto kBatchDims = "batch_dims";
void Gather::set_batch_dims(int64_t batch_dims) { (void)this->AddAttr(kBatchDims, api::MakeValue(batch_dims)); }
int64_t Gather::get_batch_dims() const { return GetValue<int64_t>(GetAttr(kBatchDims)); }
void CheckBatchDims(int64_t batch_dims, int64_t axis_val, const ShapeVector &params_shp, const ShapeVector &indices_shp,
const std::string &op_name) {
int64_t params_rank = static_cast<int64_t>(params_shp.size());
int64_t indices_rank = static_cast<int64_t>(indices_shp.size());
if (batch_dims < -indices_rank || batch_dims > indices_rank) {
MS_LOG(EXCEPTION) << "For '" << op_name << "', batch_dims must be in [" << -indices_rank << ", " << indices_rank
<< "], but got batch_dims: " << batch_dims;
}
if (batch_dims < 0) {
batch_dims += indices_rank;
}
if (batch_dims > params_rank) {
MS_LOG(EXCEPTION) << "For '" << op_name
<< "', batch_dims must be less than params's rank, but got batch_dims: " << batch_dims
<< ", oarams's rank: " << params_rank;
}
if (axis_val < batch_dims) {
MS_LOG(EXCEPTION) << "For '" << op_name
<< "', batch_dims must be less than or equal to axis, but got batch_dims: " << batch_dims
<< ", axis: " << axis_val;
}
for (size_t i = 0; i < LongToSize(batch_dims); i++) {
if (params_shp[i] != indices_shp[i]) {
MS_LOG(EXCEPTION) << "For '" << op_name << "', params.shape[" << i << "], should be equal to indices.shape[" << i
<< "], but got param.shape: " << params_shp << ", indices.shape: " << indices_shp;
}
}
}
ShapeVector CalcuateGatherWithBatchDimsOutputShape(int64_t batch_dims, int64_t axis_val, const ShapeVector &ind_vec,
const ShapeVector &params_vec) {
ShapeVector out_vec;
for (size_t i = 0; i < LongToSize(axis_val); i++) {
out_vec.push_back(params_vec[i]);
}
for (size_t i = LongToSize(batch_dims); i < ind_vec.size(); i++) {
out_vec.push_back(ind_vec[i]);
}
for (size_t i = LongToSize(axis_val) + 1; i < params_vec.size(); i++) {
out_vec.push_back(params_vec[i]);
}
return out_vec;
}
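Equivalently, the output-shape rule is a one-liner; a hedged Python restatement (with `gather_out_shape` a hypothetical helper), checked against the CPU test case at the end of this commit:

def gather_out_shape(params_shape, indices_shape, axis, batch_dims):
    return params_shape[:axis] + indices_shape[batch_dims:] + params_shape[axis + 1:]

assert gather_out_shape((3, 3, 3), (3, 2), axis=1, batch_dims=1) == (3, 2, 3)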
abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
@ -108,6 +158,13 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
if (axis_val < 0) {
axis_val += params_rank;
}
if (op_name == kNameGather) {
int64_t batch_dims = GetValue<int64_t>(primitive->GetAttr(kBatchDims));
CheckBatchDims(batch_dims, axis_val, params_shp, indices_shp, op_name);
out_shape = CalcuateGatherWithBatchDimsOutputShape(batch_dims, axis_val, indices_shp, params_shp);
return std::make_shared<abstract::Shape>(out_shape);
}
auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVector {
ShapeVector out_vec;
(void)std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
@ -142,9 +199,11 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
if (!primitive->HasAttr(kBatchDims)) {
  (void)primitive->AddAttr("batch_dims", MakeValue(static_cast<int64_t>(0)));  // Add temporarily for ascend.
}
auto infer_type = GatherInferType(primitive, input_args);
auto infer_shape = GatherInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}

View File

@ -35,6 +35,8 @@ class MIND_API Gather : public BaseOperator {
Gather() : BaseOperator(kNameGather) { InitIOName({"param", "indices", "axis"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Gather for the inputs.
void Init() const {}
void set_batch_dims(int64_t batch_dims);
int64_t get_batch_dims() const;
};
MIND_API abstract::AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -3606,12 +3606,12 @@ def gather_nd(input_x, indices):
return F.gather_nd(input_x, indices)
def gather(input_x, input_indices, axis, batch_dims=0):
r"""
Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
Refer to :func:`mindspore.ops.gather` for more detail.
"""
return F.gather(input_x, input_indices, axis, batch_dims)
def split(x, split_size_or_sections, axis=0):

View File

@ -2821,13 +2821,14 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
validator.check_value_type('indices', indices, (Tensor, Tensor_,), 'Tensor.gather_nd')
return tensor_operator_registry.get('gather_nd')(self, indices)
def gather(self, input_indices, axis, batch_dims=0):
r"""
For details, please refer to :func:`mindspore.ops.gather`.
"""
self._init_check()
validator.check_is_int(axis, 'axis')
validator.check_is_int(batch_dims, "batch_dims")
return tensor_operator_registry.get('gather')(self, input_indices, axis, batch_dims)
def var(self, axis=None, ddof=0, keepdims=False):
"""

View File

@ -413,12 +413,12 @@ def get_bprop_slice(self):
@_primexpr
def _generate_inverse_index(x_shape, axis, batch_dims=0):
x_rank = len(x_shape)
index = tuple(range(x_rank))
if axis < 0:
axis += x_rank
perm = index[:batch_dims] + index[batch_dims + 1:1 + axis] + (index[batch_dims],) + index[1 + axis:]
return perm
@ -440,33 +440,57 @@ def _dyn_regenerate_output_shape(x_shp, ind_shp, axis):
return out_shape
def _dyn_generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
"""Get tranpose order"""
out_rank = F.reshape(dyn_shape_op(out_shape), ())
ind_rank = F.reshape(dyn_shape_op(indices_shape), ())
if axis < 0:
axis += out_rank - ind_rank + 1
perm_part1 = P.Range()(F.cast(0, mstype.int32), F.cast(20, mstype.int32), F.cast(1, mstype.int32))
ind_end = axis + ind_rank - batch_dims
perm_part1 = perm_part1[axis: ind_end]
index = P.Range()(F.cast(0, mstype.int32), F.cast(out_rank, mstype.int32), F.cast(1, mstype.int32))
perm = F.hstack((index[:batch_dims], perm_part1, index[batch_dims:axis], index[ind_end:]))
return perm
def _dyn_generate_inverse_index(x_shp, axis, batch_dims=0):
"""Get tranpose order"""
x_rank = F.reshape(dyn_shape_op(x_shp), ())
index = P.Range()(F.cast(0, mstype.int32), F.cast(x_rank, mstype.int32), F.cast(1, mstype.int32))
if axis < 0:
axis += x_rank
perm = F.hstack((index[:batch_dims], index[batch_dims + 1:1 + axis], index[batch_dims], index[1 + axis:]))
return perm
def calculate_batch_gather(values, indices, x_shape, axis, batch_dims):
"""Calculate gather grad with batch_dims"""
values_shape = dyn_shape_op(values)
batch_size = F.prod(x_shape[:batch_dims])
batch_size = F.cast(batch_size, mstype.int32)
axis_dim = F.cast(x_shape[axis], mstype.int32)
# Fold all leading batch dimensions into a single leading dimension
values = values.reshape((-1,) + values.shape[batch_dims:])
indices = indices.reshape((-1,) + indices.shape[batch_dims:])
offset = P.Range()(F.cast(0, mstype.int32), batch_size * axis_dim, axis_dim)
offset_shape = F.hstack([batch_size] + [Tensor(1, dtype=mstype.int32) for _ in range(len(indices.shape) - 1)])
offset = reshape(offset, offset_shape)
indices = indices + offset
num_segments = batch_size * axis_dim
params_grad = unsorted_segment_sum(values, indices, num_segments)
grad_shape = dyn_shape_op(params_grad)
ret_shape = F.hstack([values_shape[:batch_dims], F.cast(axis_dim, mstype.int64), grad_shape[1:]])
params_grad = reshape(params_grad, ret_shape)
return params_grad
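The offset trick above, restated with NumPy (a sketch under assumed shapes values (B, N, inner) and indices (B, N)): each batch's indices are shifted into a disjoint segment, so a single unsorted segment sum accumulates all batches at once.

import numpy as np
B, N, axis_dim, inner = 2, 3, 4, 5
indices = np.random.randint(0, axis_dim, size=(B, N))
values = np.random.rand(B, N, inner)
offset = np.arange(0, B * axis_dim, axis_dim).reshape(B, 1)  # per-batch segment base
seg_ids = (indices + offset).reshape(-1)
seg_sum = np.zeros((B * axis_dim, inner))
np.add.at(seg_sum, seg_ids, values.reshape(-1, inner))       # unsorted segment sum
params_grad = seg_sum.reshape(B, axis_dim, inner)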
@bprop_getters.register(P.Gather)
@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
"""Generate bprop for GatherV2"""
batch_dims = self.batch_dims
def _dyn_bprop_gather_v2(x, indices, axis, dout):
"""dyn shape bprop for GatherV2"""
@ -483,10 +507,13 @@ def get_bprop_gather_v2(self):
dout = reshape(dout, out_shp)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _dyn_generate_shape_index(out_shp, ind_shp, axis, batch_dims)
values_transpose = transpose(dout, perm_1)
if batch_dims > 0:
    params_grad = calculate_batch_gather(values_transpose, indices, x_shp, axis, batch_dims)
else:
    params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
perm_2 = _dyn_generate_inverse_index(x_shp, axis, batch_dims)
params_grad = transpose(params_grad, perm_2)
return params_grad, zeros_like(orig_indices), zeros_like(axis)
@ -509,14 +536,15 @@ def get_bprop_gather_v2(self):
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = generate_shape_index(out_shp, ind_shp, axis, batch_dims)
values_transpose = transpose(dout, perm_1)
dyn_x_shape = dyn_shape_op(x)
if batch_dims > 0:
    params_grad = calculate_batch_gather(values_transpose, indices, dyn_x_shape, axis, batch_dims)
else:
    params_grad = unsorted_segment_sum(values_transpose, indices, dyn_x_shape[axis])
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
perm_2 = _generate_inverse_index(x_shp, axis, batch_dims)
params_grad = transpose(params_grad, perm_2)
return params_grad, zeros_like(orig_indices), zeros_like(axis)

View File

@ -127,12 +127,12 @@ def get_1d_shape(in_shape):
@_primexpr
def generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
out_rank = len(out_shape)
ind_rank = len(indices_shape)
if axis < 0:
axis += out_rank - ind_rank + 1
perm_part1 = tuple(range(axis, axis + ind_rank - batch_dims))
index = tuple(range(out_rank))
perm = index[:batch_dims] + perm_part1 + index[batch_dims:axis] + index[axis + ind_rank - batch_dims:]
return perm
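Two hand-checked cases (treating the @_primexpr-decorated function as plain Python) for out_shape (3, 2, 3), indices_shape (3, 2), axis 1:

assert generate_shape_index((3, 2, 3), (3, 2), 1, batch_dims=0) == (1, 2, 0)  # indices dims move to the front
assert generate_shape_index((3, 2, 3), (3, 2), 1, batch_dims=1) == (0, 1, 2)  # the batch dim stays pinned in place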

View File

@ -3159,7 +3159,7 @@ def argsort(input_x, axis=-1, descending=False):
return arg_sort
def gather(input_params, input_indices, axis, batch_dims=0):
r"""
Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
@ -3231,7 +3231,8 @@ def gather(input_params, input_indices, axis):
[5. 7.]
[9. 11.]]
"""
_gather = _get_cache_prim(P.Gather)(batch_dims)
return _gather(input_params, input_indices, axis)
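A small usage sketch of the new parameter (hypothetical values; with batch_dims=1 each batch row gathers with its own index row):

import mindspore as ms
from mindspore import ops
x = ms.Tensor([[1, 2, 3], [4, 5, 6]], ms.float32)
idx = ms.Tensor([[0, 2], [1, 1]], ms.int32)
out = ops.gather(x, idx, axis=1, batch_dims=1)
# out: [[1. 3.]
#       [5. 5.]]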
def gather_d(x, dim, index):

View File

@ -990,9 +990,10 @@ class Gather(Primitive):
"""
@prim_attr_register
def __init__(self, batch_dims=0):
"""Initialize Gather"""
validator.check_value_type("batch_dims", batch_dims, [int], self.name)
self.add_prim_attr("batch_dims", batch_dims)
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
@ -1006,6 +1007,7 @@ class GatherV2(PrimitiveWithCheck):
@prim_attr_register
def __init__(self):
"""Initialize GatherV2"""
self.add_prim_attr("batch_dims", 0)
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __check__(self, params, indices, axis):
@ -5999,9 +6001,12 @@ class Range(PrimitiveWithCheck):
self.add_prim_attr('maxlen', maxlen)
def check_shape(self, start_shape, limit_shape, delta_shape):
validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)
if not is_shape_unknown(start_shape):
validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
if not is_shape_unknown(limit_shape):
validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
if not is_shape_unknown(delta_shape):
validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)
def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
valid_dtypes = [mstype.int32, mstype.float32, mstype.int64, mstype.float64]

View File

@ -315,3 +315,23 @@ def test_gather_tensor(data_type):
assert out.shape == y_expect.shape
assert np.allclose(out.asnumpy(), y_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_batch_dims():
"""
Feature: Gather
Description: test cases for Gather with batch_dims
Expectation: the result match to numpy
"""
x = np.arange(27).reshape(3, 3, 3).astype(np.int32)
indices = np.array([[0, 0], [1, 1], [1, 1]]).astype(np.int32)
axis = 1
batch_dims = 1
out = P.Gather(batch_dims)(Tensor(x), Tensor(indices), axis)
expect = np.array([[[0, 1, 2], [0, 1, 2]],
[[12, 13, 14], [12, 13, 14]],
[[21, 22, 23], [21, 22, 23]]]).astype(np.int32)
assert np.allclose(out.asnumpy(), expect)
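For batch_dims=1 with axis=1, the same result can be cross-checked against NumPy's take_along_axis (an equivalence noted here for illustration, not part of the test):

expect_np = np.take_along_axis(x, indices[:, :, None], axis=1)
assert np.array_equal(expect_np, expect)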