forked from mindspore-Ecosystem/mindspore
!10802 fix potential size problem in gpu argmaxwithvalue
From: @TFbunny Reviewed-by: @robingrosman Signed-off-by:
This commit is contained in:
commit
a76e2b2649
|
@ -46,8 +46,8 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
|
|||
bool Init(const CNodePtr &kernel_node) override {
|
||||
std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
|
||||
int dims = shape.size();
|
||||
int axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
|
||||
int64_t dims = shape.size();
|
||||
int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
|
||||
if (axis < 0) {
|
||||
axis += dims;
|
||||
}
|
||||
|
@ -59,14 +59,16 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
|
|||
for (auto x : output_shape) {
|
||||
output_size_ *= x;
|
||||
}
|
||||
bound_ = shape[axis];
|
||||
bound_ = static_cast<S>(shape[axis]);
|
||||
if (shape[axis] != static_cast<size_t>(bound_)) {
|
||||
MS_LOG(EXCEPTION) << "bound's shape is larger than index type and overflows when casting.";
|
||||
}
|
||||
outerSize_ = 1;
|
||||
for (int i = axis - 1; i >= 0; i--) {
|
||||
for (int64_t i = axis - 1; i >= 0; i--) {
|
||||
outerSize_ *= shape[i];
|
||||
}
|
||||
|
||||
innerSize_ = 1;
|
||||
for (int i = axis + 1; i < dims; i++) {
|
||||
for (int64_t i = axis + 1; i < dims; i++) {
|
||||
innerSize_ *= shape[i];
|
||||
}
|
||||
InitSizeLists();
|
||||
|
@ -86,7 +88,7 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
|
|||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
size_t bound_;
|
||||
S bound_;
|
||||
size_t outerSize_;
|
||||
size_t innerSize_;
|
||||
};
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include "runtime/device/gpu/cuda_common.h"
|
||||
#include "include/cuda_fp16.h"
|
||||
template <typename T, typename S>
|
||||
__global__ void ArgmaxWithValue(const T *input, const size_t bound, size_t outerSize,
|
||||
__global__ void ArgmaxWithValue(const T *input, const S bound, size_t outerSize,
|
||||
size_t innerSize, S *index, T *output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize;
|
||||
pos += gridDim.x * blockDim.x) {
|
||||
|
@ -27,7 +27,7 @@ __global__ void ArgmaxWithValue(const T *input, const size_t bound, size_t outer
|
|||
S idx = 0;
|
||||
size_t InputOffset = x * bound * innerSize + 0 * innerSize + y;
|
||||
T maxData = input[InputOffset];
|
||||
for (size_t i = 0; i < bound; i++) {
|
||||
for (S i = 0; i < bound; i++) {
|
||||
InputOffset = x * bound * innerSize + i * innerSize + y;
|
||||
auto inputData = input[InputOffset];
|
||||
idx = inputData > maxData ? i : idx;
|
||||
|
@ -40,16 +40,16 @@ __global__ void ArgmaxWithValue(const T *input, const size_t bound, size_t outer
|
|||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_,
|
||||
void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_,
|
||||
S *index, T *output, cudaStream_t cuda_stream) {
|
||||
ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
|
||||
index, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalArgmaxWithValue<float, int>(const float *input, const size_t bound_, const size_t outerSize_,
|
||||
template void CalArgmaxWithValue<float, int>(const float *input, const int bound_, const size_t outerSize_,
|
||||
const size_t innerSize_, int *index, float *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalArgmaxWithValue<half, int>(const half *input, const size_t bound_, const size_t outerSize_,
|
||||
template void CalArgmaxWithValue<half, int>(const half *input, const int bound_, const size_t outerSize_,
|
||||
const size_t innerSize_, int *index, half *output,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -17,6 +17,6 @@
|
|||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
|
||||
template <typename T, typename S>
|
||||
void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, S *index,
|
||||
void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_, S *index,
|
||||
T *output, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
|
||||
|
|
Loading…
Reference in New Issue