forked from mindspore-Ecosystem/mindspore
!49175 add gather batch_dims
Merge pull request !49175 from 范吉斌/gather_batch_dims
This commit is contained in: commit a9de0b252b

@@ -22,6 +22,7 @@
 #include "nnacl/gather_parameter.h"
 #include "nnacl/base/gather_base.h"
 #include "include/common/thread_pool.h"
+#include "mindspore/core/ops/gather.h"

 namespace mindspore {
 namespace kernel {

@@ -34,7 +35,12 @@ constexpr size_t kGatherInputParamsMaxDim = 7;
 }  // namespace
 bool GatherCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
   kernel_name_ = base_operator->name();
+  auto kernel_ptr = std::dynamic_pointer_cast<ops::Gather>(base_operator);
+  MS_EXCEPTION_IF_NULL(kernel_ptr);
+  batch_dims_ = kernel_ptr->get_batch_dims();
+
   size_t input_num = inputs.size();
   if (input_num != kGatherInputsNum) {
     MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherCPUKernel needs 2.";

@@ -62,6 +68,9 @@ int GatherCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
   if (IsDynamic(input_shape_) || IsDynamic(indices_shape_) || IsDynamic(output_shape_)) {
     return KRET_UNKNOWN_SHAPE;
   }
+  if (batch_dims_ < 0) {
+    batch_dims_ += SizeToLong(input_shape_.size());
+  }
   is_null_input_ = input_shape_.empty() || indices_shape_.empty() || output_shape_.empty();
   if (is_null_input_) {
     InitSizeLists();

@@ -94,31 +103,45 @@ bool GatherCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
     axis_ = axis_ + dims;
   }

-  size_t outer_size = 1, inner_size = 1;
-  auto axis = static_cast<size_t>(axis_);
-  for (size_t i = 0; i < axis; ++i) {
+  size_t batch_size = 1;
+  size_t outer_size = 1;
+  size_t indices_element_size = 1;
+  size_t inner_size = 1;
+  auto axis = LongToSize(axis_);
+  auto batch_dims = LongToSize(batch_dims_);
+  for (size_t i = 0; i < batch_dims; i++) {
+    batch_size *= LongToSize(input_shape_.at(i));
+  }
+  for (size_t i = batch_dims; i < axis; ++i) {
     outer_size *= LongToSize(input_shape_.at(i));
   }
   for (size_t i = axis + 1; i < input_shape_.size(); ++i) {
     inner_size *= LongToSize(input_shape_.at(i));
   }
-  size_t indices_element_size = 1;
-  for (size_t i = 0; i < indices_shape_.size(); i++) {
+  for (size_t i = batch_dims; i < indices_shape_.size(); i++) {
     indices_element_size *= LongToSize(indices_shape_.at(i));
   }

   auto limit = LongToSize(input_shape_.at(axis));
   size_t byte_inner_size = inner_size * sizeof(T);
   size_t byte_out_stride = indices_element_size * byte_inner_size;
-  auto task = [&](size_t start, size_t end) {
-    int count = SizeToInt(end - start);
-    const int8_t *in = input_tensor + start * limit * byte_inner_size;
-    int8_t *out = output_addr + start * byte_out_stride;
-    int ret = Gather(in, count, byte_inner_size, limit, indices_data, indices_element_size, out, byte_out_stride);
-    if (ret != 0) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
-    }
-  };
-  ParallelLaunchAutoSearch(task, outer_size, this, &parallel_search_info_);
+  for (size_t i = 0; i < batch_size; i++) {
+    auto output_ptr = output_addr + i * outer_size * byte_out_stride;
+    auto input_ptr = input_tensor + i * outer_size * byte_inner_size * limit;
+    auto indice_ptr = indices_data + i * indices_element_size;
+
+    auto task = [&](size_t start, size_t end) {
+      int count = SizeToInt(end - start);
+      const int8_t *in = input_ptr + start * limit * byte_inner_size;
+      int8_t *out = output_ptr + start * byte_out_stride;
+      int ret = Gather(in, count, byte_inner_size, limit, indice_ptr, indices_element_size, out, byte_out_stride);
+      if (ret != 0) {
+        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
+      }
+    };
+    ParallelLaunchAutoSearch(task, outer_size, this, &parallel_search_info_);
+  }

   return true;
 }
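
The per-batch loop above amounts to slicing off the shared leading batch axes and running the old gather once per batch slice. A minimal NumPy sketch of that semantics (illustrative only; gather_with_batch_dims is a hypothetical name, not code from this PR):

import numpy as np

def gather_with_batch_dims(params, indices, axis, batch_dims):
    # Collapse the leading batch axes, gather each batch slice on the shifted
    # axis, then restore the batch shape -- the same decomposition as the
    # batch_size / outer_size / inner_size loop in the kernel.
    batch_shape = params.shape[:batch_dims]
    p = params.reshape((-1,) + params.shape[batch_dims:])
    ind = indices.reshape((-1,) + indices.shape[batch_dims:])
    out = np.stack([np.take(p[b], ind[b], axis=axis - batch_dims) for b in range(p.shape[0])])
    return out.reshape(batch_shape + out.shape[1:])
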
@@ -79,6 +79,7 @@ class GatherCpuKernelMod : public NativeCpuKernelMod {
   ShapeVector indices_shape_;
   ShapeVector output_shape_;
   int64_t axis_{0};
+  int64_t batch_dims_{0};
   size_t input_type_size_ = 0;
   size_t indices_type_size_ = 0;
   size_t axis_type_size_ = 0;

@@ -17,6 +17,7 @@
 #include "plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.h"
 #include <memory>
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
+#include "mindspore/core/ops/gather.h"

 namespace mindspore {
 namespace kernel {

@@ -25,6 +26,12 @@ bool GatherV2FwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
                                    const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
   kernel_name_ = base_operator->name();
+  if (kernel_name_ == ops::kNameGather) {
+    auto kernel_ptr = std::dynamic_pointer_cast<ops::Gather>(base_operator);
+    MS_EXCEPTION_IF_NULL(kernel_ptr);
+    batch_dims_ = kernel_ptr->get_batch_dims();
+  }

   auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
   auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
   if (!is_match) {

@@ -57,6 +64,9 @@ int GatherV2FwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
   if (IsDynamic(input_shapes_) || IsDynamic(indices_shapes_) || IsDynamic(output_shapes_)) {
     return KRET_UNKNOWN_SHAPE;
   }
+  if (batch_dims_ < 0) {
+    batch_dims_ += SizeToLong(input_shapes_.size());
+  }
   is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
                    CHECK_SHAPE_NULL(indices_shapes_, kernel_name_, "indices") ||
                    CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");

@@ -100,7 +110,7 @@ bool GatherV2FwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs

   MS_EXCEPTION_IF_NULL(input_addr);
   MS_EXCEPTION_IF_NULL(indices_addr);
-  GatherV2(input_addr, indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2],
+  GatherV2(input_addr, indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2], dims_[kIndex3],
            LongToSize(input_dim1), reinterpret_cast<cudaStream_t>(stream_ptr));
   return true;
 }

@@ -81,21 +81,27 @@ class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod {
     if (axis_ < 0) {
       axis_ = axis_ + SizeToInt(input_shapes_.size());
     }

-    int64_t dim_before_axis = 1;
-    for (size_t i = 0; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) {
-      dim_before_axis *= output_shapes_[i];
+    size_t batch_size = 1;
+    size_t batch_dims = LongToSize(batch_dims_);
+    for (size_t i = 0; i < batch_dims; i++) {
+      batch_size *= LongToSize(input_shapes_[i]);
     }
-    int64_t dim_of_indices = 1;
-    for (size_t i = 0; i < indices_shapes_.size(); i++) {
-      dim_of_indices *= indices_shapes_[i];
+    size_t dim_before_axis = 1;
+    for (size_t i = batch_dims; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) {
+      dim_before_axis *= LongToSize(output_shapes_[i]);
     }
-    int64_t dim_after_indices = 1;
-    for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
-      dim_after_indices *= output_shapes_[i];
+    size_t dim_of_indices = 1;
+    for (size_t i = batch_dims; i < indices_shapes_.size(); i++) {
+      dim_of_indices *= LongToSize(indices_shapes_[i]);
     }
-    dims_[kIndex0] = dim_before_axis;
-    dims_[kIndex1] = dim_of_indices;
-    dims_[kIndex2] = dim_after_indices;
+    size_t dim_after_indices = 1;
+    for (size_t i = IntToSize(axis_) + 1; i < input_shapes_.size(); i++) {
+      dim_after_indices *= LongToSize(input_shapes_[i]);
+    }
+    dims_[kIndex0] = batch_size;
+    dims_[kIndex1] = dim_before_axis;
+    dims_[kIndex2] = dim_of_indices;
+    dims_[kIndex3] = dim_after_indices;
     return;
   }
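
Reading the loops above, dims_ factorizes the gather into (batch, before-axis, indices-per-batch, after-axis) extents. A hedged Python restatement of the same computation (gather_dims is an assumed helper name, not in the PR):

import math

def gather_dims(input_shape, indices_shape, output_shape, axis, batch_dims):
    batch_size = math.prod(input_shape[:batch_dims])            # dims_[kIndex0]
    dim_before_axis = math.prod(output_shape[batch_dims:axis])  # dims_[kIndex1]
    dim_of_indices = math.prod(indices_shape[batch_dims:])      # dims_[kIndex2]
    dim_after_indices = math.prod(input_shape[axis + 1:])       # dims_[kIndex3]
    return batch_size, dim_before_axis, dim_of_indices, dim_after_indices
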
@@ -108,8 +114,9 @@ class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod {
   std::vector<int64_t> input_shapes_;
   std::vector<int64_t> indices_shapes_;
   std::vector<int64_t> output_shapes_;
-  int64_t dims_[kIndex3] = {};
+  size_t dims_[kIndex4] = {};
   int64_t axis_ = 0;
+  int64_t batch_dims_{0};
   bool is_null_input_ = false;
   size_t input_type_size_ = 0;
   size_t indices_type_size_ = 0;

@@ -44,101 +44,129 @@ __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di
   return;
 }

 template <typename T, typename S>
-void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
-              size_t input_dim1, cudaStream_t stream) {
-  size_t size = output_dim0 * output_dim1 * output_dim2;
-  GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
-                                                               output_dim2, input_dim1);
+__global__ void GatherV2WithBatchDimsKernel(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0,
+                                            size_t output_dim1, size_t output_dim2, size_t input_dim1) {
+  size_t num = batch_size * output_dim0 * output_dim1 * output_dim2;
+  size_t i, j, k, n;
+  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
+       write_index += blockDim.x * gridDim.x) {
+    i = write_index / (output_dim0 * output_dim1 * output_dim2) % batch_size;
+    j = i * output_dim1 + write_index / output_dim2 % output_dim1;
+    n = write_index / (output_dim1 * output_dim2) % output_dim0;
+    k = write_index % output_dim2;
+
+    if ((indices[j] >= 0) && (indices[j] < input_dim1)) {
+      size_t read_index =
+        i * output_dim0 * input_dim1 * output_dim2 + n * input_dim1 * output_dim2 + indices[j] * output_dim2 + k;
+      output[write_index] = input[read_index];
+    } else {
+      output[write_index] = 0;
+    }
+  }
   return;
 }

-template CUDA_LIB_EXPORT void GatherV2<Complex<float>, int>(Complex<float> *input, int *indices,
-  Complex<float> *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
-  cudaStream_t stream);
+template <typename T, typename S>
+void GatherV2(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0, size_t output_dim1,
+              size_t output_dim2, size_t input_dim1, cudaStream_t stream) {
+  size_t size = batch_size * output_dim0 * output_dim1 * output_dim2;
+  if (batch_size > 1) {
+    GatherV2WithBatchDimsKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(
+      input, indices, output, batch_size, output_dim0, output_dim1, output_dim2, input_dim1);
+  } else {
+    GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
+                                                                 output_dim2, input_dim1);
+  }
+  return;
+}
+
+template CUDA_LIB_EXPORT void GatherV2<Complex<float>, int>(Complex<float> *input, int *indices,
+  Complex<float> *output, size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+  size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<Complex<float>, int64_t>(Complex<float> *input, int64_t *indices,
-  Complex<float> *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
-  cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<Complex<float>, int64_t>(Complex<float> *input, int64_t *indices,
+  Complex<float> *output, size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+  size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<Complex<double>, int>(Complex<double> *input, int *indices,
-  Complex<double> *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
-  cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<Complex<double>, int>(Complex<double> *input, int *indices,
+  Complex<double> *output, size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+  size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<Complex<double>, int64_t>(Complex<double> *input, int64_t *indices,
-  Complex<double> *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
-  cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<Complex<double>, int64_t>(Complex<double> *input, int64_t *indices,
-  Complex<double> *output, size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+  size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<float, int>(float *input, int *indices, float *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<float, int>(float *input, int *indices, float *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<float, int64_t>(float *input, int64_t *indices, float *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<float, int64_t>(float *input, int64_t *indices, float *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<half, int>(half *input, int *indices, half *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<half, int>(half *input, int *indices, half *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<half, int64_t>(half *input, int64_t *indices, half *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<half, int64_t>(half *input, int64_t *indices, half *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<double, int>(double *input, int *indices, double *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<double, int>(double *input, int *indices, double *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<double, int64_t>(double *input, int64_t *indices, double *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<double, int64_t>(double *input, int64_t *indices, double *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int64_t, int>(int64_t *input, int *indices, int64_t *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int64_t, int>(int64_t *input, int *indices, int64_t *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int64_t, int64_t>(int64_t *input, int64_t *indices, int64_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int64_t, int64_t>(int64_t *input, int64_t *indices, int64_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int, int>(int *input, int *indices, int *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int, int>(int *input, int *indices, int *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int, int64_t>(int *input, int64_t *indices, int *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int, int64_t>(int *input, int64_t *indices, int *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int16_t, int>(int16_t *input, int *indices, int16_t *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int16_t, int>(int16_t *input, int *indices, int16_t *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int16_t, int64_t>(int16_t *input, int64_t *indices, int16_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int16_t, int64_t>(int16_t *input, int64_t *indices, int16_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int8_t, int>(int8_t *input, int *indices, int8_t *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int8_t, int>(int8_t *input, int *indices, int8_t *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<uint64_t, int>(uint64_t *input, int *indices, uint64_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint64_t, int>(uint64_t *input, int *indices, uint64_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<uint64_t, int64_t>(uint64_t *input, int64_t *indices, uint64_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint64_t, int64_t>(uint64_t *input, int64_t *indices, uint64_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint32_t, int>(uint32_t *input, int *indices, uint32_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<uint32_t, int64_t>(uint32_t *input, int64_t *indices, uint32_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint32_t, int64_t>(uint32_t *input, int64_t *indices, uint32_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<uint16_t, int>(uint16_t *input, int *indices, uint16_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint16_t, int>(uint16_t *input, int *indices, uint16_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<uint16_t, int64_t>(uint16_t *input, int64_t *indices, uint16_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint16_t, int64_t>(uint16_t *input, int64_t *indices, uint16_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output,
-  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output,
+  size_t batch_size, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1,
+  cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<bool, int>(bool *input, int *indices, bool *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<bool, int>(bool *input, int *indices, bool *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template CUDA_LIB_EXPORT void GatherV2<bool, int64_t>(bool *input, int64_t *indices, bool *output, size_t output_dim0,
-  size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template CUDA_LIB_EXPORT void GatherV2<bool, int64_t>(bool *input, int64_t *indices, bool *output, size_t batch_size,
+  size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
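
The index arithmetic in GatherV2WithBatchDimsKernel decomposes a flat output offset into (batch i, outer n, indices row j, inner k) and maps it back into the input. A host-side Python check of the same formula (a sketch, assuming indices has been flattened batch-major):

def read_index(write_index, batch_size, output_dim0, output_dim1, output_dim2, input_dim1, indices):
    i = write_index // (output_dim0 * output_dim1 * output_dim2) % batch_size  # batch
    n = write_index // (output_dim1 * output_dim2) % output_dim0               # outer block
    j = i * output_dim1 + write_index // output_dim2 % output_dim1             # row into flattened indices
    k = write_index % output_dim2                                              # inner offset
    return i * output_dim0 * input_dim1 * output_dim2 + n * input_dim1 * output_dim2 + indices[j] * output_dim2 + k
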
@@ -18,8 +18,8 @@
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GATHERV2_CUH_
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 template <typename T, typename S>
-CUDA_LIB_EXPORT void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
-                              size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+CUDA_LIB_EXPORT void GatherV2(T *input, S *indices, T *output, size_t batch_size, size_t output_dim0,
+                              size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);

 template <typename T, typename S>
 __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,

@@ -45,6 +45,56 @@

 namespace mindspore {
 namespace ops {
+constexpr auto kBatchDims = "batch_dims";
+
+void Gather::set_batch_dims(int64_t batch_dims) { (void)this->AddAttr(kBatchDims, api::MakeValue(batch_dims)); }
+
+int64_t Gather::get_batch_dims() const { return GetValue<int64_t>(GetAttr(kBatchDims)); }
+
+void CheckBatchDims(int64_t batch_dims, int64_t axis_val, const ShapeVector &params_shp, const ShapeVector &indices_shp,
+                    const std::string &op_name) {
+  int64_t params_rank = static_cast<int64_t>(params_shp.size());
+  int64_t indices_rank = static_cast<int64_t>(indices_shp.size());
+  if (batch_dims < -indices_rank || batch_dims > indices_rank) {
+    MS_LOG(EXCEPTION) << "For '" << op_name << "', batch_dims must be in [" << -indices_rank << ", " << indices_rank
+                      << "], but got batch_dims: " << batch_dims;
+  }
+  if (batch_dims < 0) {
+    batch_dims += indices_rank;
+  }
+  if (batch_dims > params_rank) {
+    MS_LOG(EXCEPTION) << "For '" << op_name
+                      << "', batch_dims must be less than params's rank, but got batch_dims: " << batch_dims
+                      << ", params's rank: " << params_rank;
+  }
+  if (axis_val < batch_dims) {
+    MS_LOG(EXCEPTION) << "For '" << op_name
+                      << "', batch_dims must be less than or equal to axis, but got batch_dims: " << batch_dims
+                      << ", axis: " << axis_val;
+  }
+  for (size_t i = 0; i < LongToSize(batch_dims); i++) {
+    if (params_shp[i] != indices_shp[i]) {
+      MS_LOG(EXCEPTION) << "For '" << op_name << "', params.shape[" << i << "] should be equal to indices.shape[" << i
+                        << "], but got params.shape: " << params_shp << ", indices.shape: " << indices_shp;
+    }
+  }
+}
+
+ShapeVector CalcuateGatherWithBatchDimsOutputShape(int64_t batch_dims, int64_t axis_val, const ShapeVector &ind_vec,
+                                                   const ShapeVector &params_vec) {
+  ShapeVector out_vec;
+  for (size_t i = 0; i < LongToSize(axis_val); i++) {
+    out_vec.push_back(params_vec[i]);
+  }
+  for (size_t i = LongToSize(batch_dims); i < ind_vec.size(); i++) {
+    out_vec.push_back(ind_vec[i]);
+  }
+  for (size_t i = LongToSize(axis_val) + 1; i < params_vec.size(); i++) {
+    out_vec.push_back(params_vec[i]);
+  }
+  return out_vec;
+}
+
 abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   const std::string &op_name = primitive->name();
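
Put together, CheckBatchDims plus CalcuateGatherWithBatchDimsOutputShape implement the shape rule output = params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]. In Python terms (illustrative sketch only, not PR code):

def gather_output_shape(params_shape, indices_shape, axis, batch_dims):
    return params_shape[:axis] + indices_shape[batch_dims:] + params_shape[axis + 1:]

# e.g. params (3, 3, 3), indices (3, 2), axis=1, batch_dims=1 -> (3, 2, 3)
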
@@ -108,6 +158,13 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
   if (axis_val < 0) {
     axis_val += params_rank;
   }
+  if (op_name == kNameGather) {
+    int64_t batch_dims = GetValue<int64_t>(primitive->GetAttr(kBatchDims));
+    CheckBatchDims(batch_dims, axis_val, params_shp, indices_shp, op_name);
+    out_shape = CalcuateGatherWithBatchDimsOutputShape(batch_dims, axis_val, indices_shp, params_shp);
+    return std::make_shared<abstract::Shape>(out_shape);
+  }
+
   auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVector {
     ShapeVector out_vec;
     (void)std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));

@@ -142,9 +199,11 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
   MS_EXCEPTION_IF_NULL(primitive);
   const int64_t kInputsNum = 3;
   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
+  if (!primitive->HasAttr(kBatchDims)) {
+    (void)primitive->AddAttr("batch_dims", MakeValue(static_cast<int64_t>(0)));  // Add temporarily for ascend.
+  }
   auto infer_type = GatherInferType(primitive, input_args);
   auto infer_shape = GatherInferShape(primitive, input_args);
-  (void)primitive->AddAttr("batch_dims", MakeValue(static_cast<int64_t>(0)));  // Add temporarily for gatherv2 on ascend
   return abstract::MakeAbstract(infer_shape, infer_type);
 }

@@ -35,6 +35,8 @@ class MIND_API Gather : public BaseOperator {
   Gather() : BaseOperator(kNameGather) { InitIOName({"param", "indices", "axis"}, {"output"}); }
   /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Gather for the inputs.
   void Init() const {}
+  void set_batch_dims(int64_t batch_dims);
+  int64_t get_batch_dims() const;
 };

 MIND_API abstract::AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -3606,12 +3606,12 @@ def gather_nd(input_x, indices):
     return F.gather_nd(input_x, indices)


-def gather(input_x, input_indices, axis):
+def gather(input_x, input_indices, axis, batch_dims=0):
     r"""
     Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.

     Refer to :func:`mindspore.ops.gather` for more detail.
     """
-    return F.gather(input_x, input_indices, axis)
+    return F.gather(input_x, input_indices, axis, batch_dims)


 def split(x, split_size_or_sections, axis=0):

@@ -2821,13 +2821,14 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         validator.check_value_type('indices', indices, (Tensor, Tensor_,), 'Tensor.gather_nd')
         return tensor_operator_registry.get('gather_nd')(self, indices)

-    def gather(self, input_indices, axis):
+    def gather(self, input_indices, axis, batch_dims=0):
         r"""
         For details, please refer to :func:`mindspore.ops.gather`.
         """
         self._init_check()
         validator.check_is_int(axis, 'axis')
-        return tensor_operator_registry.get('gather')(self, input_indices, axis)
+        validator.check_is_int(batch_dims, "batch_dims")
+        return tensor_operator_registry.get('gather')(self, input_indices, axis, batch_dims)

     def var(self, axis=None, ddof=0, keepdims=False):
         """

@@ -413,12 +413,12 @@ def get_bprop_slice(self):


 @_primexpr
-def _generate_inverse_index(x_shape, axis):
+def _generate_inverse_index(x_shape, axis, batch_dims=0):
     x_rank = len(x_shape)
     index = tuple(range(x_rank))
     if axis < 0:
         axis += x_rank
-    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
+    perm = index[:batch_dims] + index[batch_dims + 1:1 + axis] + (index[batch_dims],) + index[1 + axis:]
     return perm
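
A quick worked example of the new permutation (illustrative, not part of the diff): with x_rank=4, axis=2, batch_dims=1 the batch axis stays in front, where the old code always rotated axis 0:

x_rank, axis, batch_dims = 4, 2, 1
index = tuple(range(x_rank))
perm = index[:batch_dims] + index[batch_dims + 1:1 + axis] + (index[batch_dims],) + index[1 + axis:]
assert perm == (0, 2, 1, 3)
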
@@ -440,33 +440,57 @@ def _dyn_regenerate_output_shape(x_shp, ind_shp, axis):
     return out_shape


-def _dyn_generate_shape_index(out_shape, indices_shape, axis):
+def _dyn_generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
     """Get transpose order"""
     out_rank = F.reshape(dyn_shape_op(out_shape), ())
     ind_rank = F.reshape(dyn_shape_op(indices_shape), ())
     if axis < 0:
         axis += out_rank - ind_rank + 1
     perm_part1 = P.Range()(F.cast(0, mstype.int32), F.cast(20, mstype.int32), F.cast(1, mstype.int32))
-    perm_part1 = perm_part1[axis: axis + ind_rank]
+    ind_end = axis + ind_rank - batch_dims
+    perm_part1 = perm_part1[axis: ind_end]
     index = P.Range()(F.cast(0, mstype.int32), F.cast(out_rank, mstype.int32), F.cast(1, mstype.int32))
-    perm = P.Concat(0)((perm_part1, index[:axis], index[axis + ind_rank:]))
+    perm = F.hstack((index[:batch_dims], perm_part1, index[batch_dims:axis], index[ind_end:]))
     return perm


-def _dyn_generate_inverse_index(x_shp, axis):
+def _dyn_generate_inverse_index(x_shp, axis, batch_dims=0):
     """Get transpose order"""
     x_rank = F.reshape(dyn_shape_op(x_shp), ())
     index = P.Range()(F.cast(0, mstype.int32), F.cast(x_rank, mstype.int32), F.cast(1, mstype.int32))
     if axis < 0:
         axis += x_rank
-    perm = P.Concat(0)((index[1: 1 + axis], Tensor([0], dtype=mstype.int32), index[1 + axis:]))
+    perm = F.hstack((index[:batch_dims], index[batch_dims + 1:1 + axis], index[batch_dims], index[1 + axis:]))
    return perm


+def calculate_batch_gather(values, indices, x_shape, axis, batch_dims):
+    """Calculate gather grad with batch_dims"""
+    values_shape = dyn_shape_op(values)
+    batch_size = F.prod(x_shape[:batch_dims])
+    batch_size = F.cast(batch_size, mstype.int32)
+    axis_dim = F.cast(x_shape[axis], mstype.int32)
+
+    # Move batch dimension to first non-batch dimension
+    values = values.reshape((-1,) + values.shape[batch_dims:])
+    indices = indices.reshape((-1,) + indices.shape[batch_dims:])
+    offset = P.Range()(F.cast(0, mstype.int32), batch_size * axis_dim, axis_dim)
+    offset_shape = F.hstack([batch_size] + [Tensor(1, dtype=mstype.int32) for _ in range(len(indices.shape) - 1)])
+    offset = reshape(offset, offset_shape)
+    indices = indices + offset
+    num_segments = batch_size * axis_dim
+    params_grad = unsorted_segment_sum(values, indices, num_segments)
+    grad_shape = dyn_shape_op(params_grad)
+    ret_shape = F.hstack([values_shape[:batch_dims], F.cast(axis_dim, mstype.int64), grad_shape[1:]])
+    params_grad = reshape(params_grad, ret_shape)
+    return params_grad
+
+
 @bprop_getters.register(P.Gather)
 @bprop_getters.register(P.GatherV2)
 def get_bprop_gather_v2(self):
     """Generate bprop for GatherV2"""
+    batch_dims = self.batch_dims

     def _dyn_bprop_gather_v2(x, indices, axis, dout):
         """dyn shape bprop for GatherV2"""
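
The core of calculate_batch_gather is an offset trick: each batch's indices are shifted by b * axis_dim so that a single unsorted_segment_sum over batch_size * axis_dim segments accumulates every batch at once. A plain-NumPy check of the shift (assumed small shapes; not PR code):

import numpy as np

batch_size, axis_dim = 2, 3
indices = np.array([[0, 1], [2, 2]])                    # (batch, indices per batch)
offset = np.arange(0, batch_size * axis_dim, axis_dim)  # [0, 3], one shift per batch
shifted = indices + offset.reshape(batch_size, 1)       # rows land in disjoint segment ranges
assert shifted.tolist() == [[0, 1], [5, 5]]
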
@@ -483,10 +507,13 @@ def get_bprop_gather_v2(self):
         dout = reshape(dout, out_shp)

         # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
-        perm_1 = _dyn_generate_shape_index(out_shp, ind_shp, axis)
+        perm_1 = _dyn_generate_shape_index(out_shp, ind_shp, axis, batch_dims)
         values_transpose = transpose(dout, perm_1)
-        params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
-        perm_2 = _dyn_generate_inverse_index(x_shp, axis)
+        if batch_dims > 0:
+            params_grad = calculate_batch_gather(values_transpose, indices, x_shp, axis, batch_dims)
+        else:
+            params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
+        perm_2 = _dyn_generate_inverse_index(x_shp, axis, batch_dims)
         params_grad = transpose(params_grad, perm_2)
         return params_grad, zeros_like(orig_indices), zeros_like(axis)

@@ -509,14 +536,15 @@ def get_bprop_gather_v2(self):
         out_shp = shape_op(dout)
         ind_shp = shape_op(indices)
         # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
-        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
+        perm_1 = generate_shape_index(out_shp, ind_shp, axis, batch_dims)
         values_transpose = transpose(dout, perm_1)
         if F.is_sequence_value_unknown(shape_op(x)):
-            params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
-            params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
+            dyn_x_shape = dyn_shape_op(x)
+            if batch_dims > 0:
+                params_grad = calculate_batch_gather(values_transpose, indices, dyn_x_shape, axis, batch_dims)
+            else:
+                params_grad = unsorted_segment_sum(values_transpose, indices, dyn_x_shape[axis])
         # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
-        perm_2 = _generate_inverse_index(x_shp, axis)
+        perm_2 = _generate_inverse_index(x_shp, axis, batch_dims)
         params_grad = transpose(params_grad, perm_2)
         return params_grad, zeros_like(orig_indices), zeros_like(axis)

@@ -127,12 +127,12 @@ def get_1d_shape(in_shape):


 @_primexpr
-def generate_shape_index(out_shape, indices_shape, axis):
+def generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
     out_rank = len(out_shape)
     ind_rank = len(indices_shape)
     if axis < 0:
         axis += out_rank - ind_rank + 1
-    perm_part1 = tuple(range(axis, axis + ind_rank))
+    perm_part1 = tuple(range(axis, axis + ind_rank - batch_dims))
     index = tuple(range(out_rank))
-    perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
+    perm = index[:batch_dims] + perm_part1 + index[batch_dims:axis] + index[axis + ind_rank - batch_dims:]
     return perm
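
Worked example (illustrative, not from the diff): for out_rank=4, ind_rank=2, axis=2, batch_dims=1, the indices axis is rotated to just after the batch axis instead of to the very front:

out_rank, ind_rank, axis, batch_dims = 4, 2, 2, 1
perm_part1 = tuple(range(axis, axis + ind_rank - batch_dims))  # (2,)
index = tuple(range(out_rank))
perm = index[:batch_dims] + perm_part1 + index[batch_dims:axis] + index[axis + ind_rank - batch_dims:]
assert perm == (0, 2, 1, 3)
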

Binary file not shown.

@@ -3159,7 +3159,7 @@ def argsort(input_x, axis=-1, descending=False):
     return arg_sort


-def gather(input_params, input_indices, axis):
+def gather(input_params, input_indices, axis, batch_dims=0):
     r"""
     Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.

@@ -3231,7 +3231,8 @@ def gather(input_params, input_indices, axis):
         [5. 7.]
         [9. 11.]]
    """
-    return gather_(input_params, input_indices, axis)
+    _gather = _get_cache_prim(P.Gather)(batch_dims)
+    return _gather(input_params, input_indices, axis)


 def gather_d(x, dim, index):
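
With the functional signature extended, a call might look like this (a usage sketch consistent with the new test below, not an excerpt from the PR):

import numpy as np
import mindspore as ms
from mindspore import ops

params = ms.Tensor(np.arange(27).reshape(3, 3, 3), ms.int32)
indices = ms.Tensor(np.array([[0, 0], [1, 1], [1, 1]]), ms.int32)
out = ops.gather(params, indices, axis=1, batch_dims=1)  # shape (3, 2, 3)
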
@@ -990,9 +990,10 @@ class Gather(Primitive):
     """

     @prim_attr_register
-    def __init__(self):
+    def __init__(self, batch_dims=0):
         """Initialize Gather"""
-        self.add_prim_attr("batch_dims", 0)
+        validator.check_value_type("batch_dims", batch_dims, [int], self.name)
+        self.add_prim_attr("batch_dims", batch_dims)
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

@@ -1006,6 +1007,7 @@ class GatherV2(PrimitiveWithCheck):
     @prim_attr_register
     def __init__(self):
         """Initialize GatherV2"""
+        self.add_prim_attr("batch_dims", 0)
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

     def __check__(self, params, indices, axis):

@@ -5999,9 +6001,12 @@ class Range(PrimitiveWithCheck):
         self.add_prim_attr('maxlen', maxlen)

     def check_shape(self, start_shape, limit_shape, delta_shape):
-        validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
-        validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
-        validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)
+        if not is_shape_unknown(start_shape):
+            validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
+        if not is_shape_unknown(limit_shape):
+            validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
+        if not is_shape_unknown(delta_shape):
+            validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)

     def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
         valid_dtypes = [mstype.int32, mstype.float32, mstype.int64, mstype.float64]

@@ -315,3 +315,23 @@ def test_gather_tensor(data_type):

     assert out.shape == y_expect.shape
     np.allclose(out.asnumpy(), y_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_gather_batch_dims():
+    """
+    Feature: Gather
+    Description: test cases for Gather with batch_dims
+    Expectation: the result match to numpy
+    """
+    x = np.arange(27).reshape(3, 3, 3).astype(np.int32)
+    indices = np.array([[0, 0], [1, 1], [1, 1]]).astype(np.int32)
+    axis = 1
+    batch_dims = 1
+    out = P.Gather(batch_dims)(Tensor(x), Tensor(indices), axis)
+    expect = np.array([[[0, 1, 2], [0, 1, 2]],
+                       [[12, 13, 14], [12, 13, 14]],
+                       [[21, 22, 23], [21, 22, 23]]]).astype(np.int32)
+    assert np.allclose(out.asnumpy(), expect)
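
The expected tensor can be cross-checked against a per-batch NumPy gather (a sketch, not part of the test file):

import numpy as np

x = np.arange(27).reshape(3, 3, 3)
indices = np.array([[0, 0], [1, 1], [1, 1]])
expect = np.stack([x[b][indices[b]] for b in range(3)])  # gather along axis 1, batch by batch
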