forked from mindspore-Ecosystem/mindspore

!8591 gatherv2 dynamic shape

From: @jonwe
Reviewed-by: @tom__chen, @robingrosman
Signed-off-by: @robingrosman
commit 95ce3eab08
@@ -34,5 +34,19 @@ MS_REG_GPU_KERNEL_TWO(
   SparseGatherV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   GatherV2GpuFwdKernel, half, int)
+MS_REG_GPU_KERNEL_TWO(GatherV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      GatherV2GpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(GatherV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      GatherV2GpuFwdKernel, half, int)
 }  // namespace kernel
 }  // namespace mindspore
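Note: the two GatherV2 registrations added above differ from the existing pattern by a third input of kNumberTypeInt64. As Init() in the next hunks shows, an input count of 3 switches the kernel into dynamic shape mode, with the gather axis supplied as a tensor input rather than a node attribute. The output shape then follows the usual gather rule: for input shape (a_0, ..., a_{r-1}), indices shape (b_0, ..., b_{q-1}) and axis k, the output shape is (a_0, ..., a_{k-1}, b_0, ..., b_{q-1}, a_{k+1}, ..., a_{r-1}), which is exactly the shape GatherV2DynamicShape writes into output_shape_wksp further down.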
@@ -18,6 +18,7 @@
 #define MINDSPORE_GATHER_GPU_KERNEL_H

 #include <vector>
+#include <algorithm>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
@@ -41,31 +42,78 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     S *indices_addr = GetDeviceAddress<S>(inputs, 1);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);

-    auto input_dim1 = input_shapes_[IntToSize(axis_)];
-    GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (is_dynamic_shape_) {
+      // in dynamic shape mode dims_ is not known on the host, so copy input_shapes_, indices_shapes_,
+      // and axis_ into the workspace and let the kernel derive dims_ on the device
+      size_t *input_shape_device_address = GetDeviceAddress<size_t>(workspace, 0);
+      size_t *indices_shape_device_address = GetDeviceAddress<size_t>(workspace, 1);
+      int64_t *axis_device_address = GetDeviceAddress<int64_t>(workspace, 2);
+
+      CHECK_CUDA_RET_WITH_EXCEPT(
+        cudaMemcpyAsync(input_shape_device_address, input_shapes_.data(), workspace_size_list_[0],
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync input_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(
+        cudaMemcpyAsync(indices_shape_device_address, indices_shapes_.data(), workspace_size_list_[1],
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync indices_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(axis_device_address, &axis_, workspace_size_list_[2],
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync axis_ failed");
+
+      // the kernel writes the computed output shape here so it can be copied back to the host
+      size_t *output_shape_device_address = GetDeviceAddress<size_t>(workspace, 3);
+      CalGatherV2DynamicShape(input_addr, indices_addr, output_addr, input_shape_device_address, input_shapes_.size(),
+                              indices_shape_device_address, indices_shapes_.size(), axis_device_address,
+                              output_shape_device_address, max_output_size_,
+                              reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      size_t output_rank = input_shapes_.size() - 1 + indices_shapes_.size();
+      real_output_shape_.resize(output_rank);
+      CHECK_CUDA_RET_WITH_ERROR(
+        cudaMemcpyAsync(&real_output_shape_[0], output_shape_device_address, output_rank * sizeof(size_t),
+                        cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "Failed to copy gpu memory.");
+    } else {
+      auto input_dim1 = input_shapes_[IntToSize(axis_)];
+      CalGatherV2StaticShape(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
     InitResource();
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
+    if (input_num == 3) {
+      is_dynamic_shape_ = true;
+    } else if (input_num != 2) {
       MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2GpuFwdKernel needs 2 or 3.";
     }

     input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
     indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
     output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);

-    axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
-    if (axis_ < 0) {
-      axis_ = axis_ + SizeToInt(input_shapes_.size());
+    if (is_dynamic_shape_) {
+      c_node_ptr_ = kernel_node;
+      size_t input_shape_min = *std::min_element(input_shapes_.begin(), input_shapes_.end());
+      max_output_size_ = (GetSize(input_shapes_) / input_shape_min) * GetSize(indices_shapes_);
+    } else {
+      axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
+      if (axis_ < 0) {
+        axis_ = axis_ + SizeToInt(input_shapes_.size());
+      }
+
+      Reshape();
     }
-
-    Reshape();
     InitSizeLists();
     return true;
   }
+  void ResetResource() noexcept override {
+    is_dynamic_shape_ = false;
+    max_output_size_ = -1;
+    input_shapes_.clear();
+    indices_shapes_.clear();
+    output_shapes_.clear();
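Why max_output_size_ is a safe upper bound (and why <algorithm> is now included, for std::min_element): the output element count is (input_size / input_shape[axis]) * indices_size, but in dynamic shape mode the axis is unknown when buffers are sized, so the patch divides by the smallest input dimension, which maximizes that quotient over every possible axis. The sketch below is a minimal host-side illustration, not part of the patch, of the axis-centered 3-D decomposition both kernels rely on; GatherDims and CollapseGather are hypothetical names.

#include <cstddef>
#include <vector>

struct GatherDims {
  std::size_t dim0;  // product of input dims before axis   -> dims_[0]
  std::size_t dim1;  // product of all indices dims         -> dims_[1]
  std::size_t dim2;  // product of input dims after axis    -> dims_[2]
};

GatherDims CollapseGather(const std::vector<std::size_t> &input_shape,
                          const std::vector<std::size_t> &indices_shape, std::size_t axis) {
  GatherDims d{1, 1, 1};
  for (std::size_t i = 0; i < axis; ++i) d.dim0 *= input_shape[i];
  for (std::size_t b : indices_shape) d.dim1 *= b;
  for (std::size_t i = axis + 1; i < input_shape.size(); ++i) d.dim2 *= input_shape[i];
  return d;
}

For example, input shape (2, 3, 4), indices shape (5), and axis 1 give dims {2, 5, 4}, i.e. an output of 40 elements.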
@@ -84,8 +132,29 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     size = GetSize(indices_shapes_);
     input_size_list_.push_back(size);

-    size = GetSize(output_shapes_);
-    output_size_list_.push_back(size);
+    if (is_dynamic_shape_) {
+      // the third input carries the axis
+      input_size_list_.push_back(sizeof(S));
+
+      // allocate the maximum size the output can need
+      output_size_list_.push_back(max_output_size_);
+
+      // workspace buffers for the input shape, indices shape, axis, and output shape respectively
+      size = GetSize(input_shapes_);
+      workspace_size_list_.push_back(size);
+
+      size = GetSize(indices_shapes_);
+      workspace_size_list_.push_back(size);
+
+      size = sizeof(int64_t);
+      workspace_size_list_.push_back(size);
+
+      size = GetSize(input_shapes_);
+      workspace_size_list_.push_back(size);
+    } else {
+      size = GetSize(output_shapes_);
+      output_size_list_.push_back(size);
+    }
   }

  private:
@@ -126,7 +195,11 @@ class GatherV2GpuFwdKernel : public GpuKernel {
   std::vector<size_t> output_shapes_;

   size_t dims_[3] = {};
-  int axis_;
+  int64_t axis_;
+  bool is_dynamic_shape_;
+  int max_output_size_;
+  std::vector<size_t> real_output_shape_;
+  CNodePtr c_node_ptr_;

   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
@@ -18,8 +18,8 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T, typename S>
-__global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
-                               size_t output_dim2, size_t input_dim1) {
+__device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
+                               size_t output_dim2, size_t input_dim1) {
   int num = output_dim0 * output_dim1 * output_dim2;
   int i, j, k;
   for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
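Note: GatherV2Kernel is demoted from __global__ to __device__ because its body is about to be shared. CUDA does not allow calling a __global__ kernel as an ordinary function from device code, while a __device__ function can be invoked from both of the new entry kernels in the next hunk, GatherV2StaticShapeWrapper and GatherV2DynamicShape.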
@@ -38,17 +38,90 @@ __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di

   return;
 }

 template <typename T, typename S>
-void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
-              size_t input_dim1, cudaStream_t stream) {
+__global__ void GatherV2StaticShapeWrapper(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
+                                           size_t output_dim2, size_t input_dim1) {
+  GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
+}
+
+template <typename T, typename S>
+__global__ void GatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
+                                     size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
+                                     size_t *output_shape_wksp, const int max_output_size) {
+  int gt_id = blockIdx.x * blockDim.x + threadIdx.x;
+  size_t axis = (size_t)(*axis_wksp);
+
+  // every thread derives the collapsed dims; only thread 0 writes the output shape
+  int output_shape_index = 0;
+  size_t output_dim0 = 1;
+  for (size_t i = 0; i < axis; i++) {
+    output_dim0 *= input_shape_wksp[i];
+
+    if (gt_id == 0) {
+      output_shape_wksp[output_shape_index] = input_shape_wksp[i];
+      output_shape_index++;
+    }
+  }
+
+  size_t output_dim1 = 1;
+  for (size_t i = 0; i < indices_rank; i++) {
+    output_dim1 *= indices_shape_wksp[i];
+
+    if (gt_id == 0) {
+      output_shape_wksp[output_shape_index] = indices_shape_wksp[i];
+      output_shape_index++;
+    }
+  }
+
+  size_t output_dim2 = 1;
+  for (size_t i = axis + 1; i < input_rank; i++) {
+    output_dim2 *= input_shape_wksp[i];
+
+    if (gt_id == 0) {
+      output_shape_wksp[output_shape_index] = input_shape_wksp[i];
+      output_shape_index++;
+    }
+  }
+
+  size_t input_dim1 = (size_t)(input_shape_wksp[axis]);
+
+  GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
+}
+
+// entry points from the gpu kernel's .h file
+template <typename T, typename S>
+void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+                            size_t input_dim1, cudaStream_t stream) {
   int size = output_dim0 * output_dim1 * output_dim2;
-  GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
-                                                               output_dim2, input_dim1);
+  GatherV2StaticShapeWrapper<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0,
+                                                                           output_dim1, output_dim2, input_dim1);
   return;
 }

-template void GatherV2<float, int>(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1,
-                                   size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template <typename T, typename S>
+void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
+                             size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
+                             size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream) {
+  GatherV2DynamicShape<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, stream>>>(
+    input, indices, output, input_shape_wksp, input_rank, indices_shape_wksp, indices_rank, axis_wksp,
+    output_shape_wksp, max_output_size);
+}

-template void GatherV2<half, int>(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1,
-                                  size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+// template instantiations
+template void CalGatherV2StaticShape<float, int>(float *input, int *indices, float *output, size_t output_dim0,
+                                                 size_t output_dim1, size_t output_dim2, size_t input_dim1,
+                                                 cudaStream_t stream);
+
+template void CalGatherV2StaticShape<half, int>(half *input, int *indices, half *output, size_t output_dim0,
+                                                size_t output_dim1, size_t output_dim2, size_t input_dim1,
+                                                cudaStream_t stream);
+
+template void CalGatherV2DynamicShape<float, int>(float *input, int *indices, float *output, size_t *input_shape_wksp,
+                                                  size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
+                                                  int64_t *axis_wksp, size_t *output_shape_wksp,
+                                                  const int max_output_size, cudaStream_t stream);
+
+template void CalGatherV2DynamicShape<half, int>(half *input, int *indices, half *output, size_t *input_shape_wksp,
+                                                 size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
+                                                 int64_t *axis_wksp, size_t *output_shape_wksp,
+                                                 const int max_output_size, cudaStream_t stream);
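The body of GatherV2Kernel's for loop is truncated in the hunk above, so the CPU reference below is a sketch, not part of the patch; its index arithmetic follows the standard gather semantics implied by the signature, output[i, j, k] = input[i, indices[j], k] over the collapsed dims, which the CUDA kernel parallelizes over write_index.

#include <cstddef>

template <typename T, typename S>
void GatherV2Reference(const T *input, const S *indices, T *output, std::size_t output_dim0,
                       std::size_t output_dim1, std::size_t output_dim2, std::size_t input_dim1) {
  for (std::size_t i = 0; i < output_dim0; ++i) {      // dims before axis
    for (std::size_t j = 0; j < output_dim1; ++j) {    // flattened indices
      for (std::size_t k = 0; k < output_dim2; ++k) {  // dims after axis
        std::size_t write_index = (i * output_dim1 + j) * output_dim2 + k;
        std::size_t read_index =
          (i * input_dim1 + static_cast<std::size_t>(indices[j])) * output_dim2 + k;
        output[write_index] = input[read_index];
      }
    }
  }
}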
@@ -17,7 +17,11 @@
 #ifndef MINDSPORE_GATHER_GPU_CU_H
 #define MINDSPORE_GATHER_GPU_CU_H
 template <typename T, typename S>
-void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
-              size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+                            size_t input_dim1, cudaStream_t stream);
+
+template <typename T, typename S>
+void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
+                             size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
+                             size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream);
 #endif