From 05f44ab8348316665dd9fb95e2eaab2891973cfb Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Mon, 30 Nov 2020 16:08:37 -0500 Subject: [PATCH] Fix an error that uniform_a should be less than uniform_b --- .../gpu/cuda_impl/random_op_impl.cu | 17 +++++++++++++---- .../gpu/cuda_impl/random_op_impl.cuh | 2 +- .../gpu/math/random_op_gpu_kernel.h | 10 +++++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu index e2376ecd1f0..c76743a1bd1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu @@ -24,13 +24,20 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size return; } +__device__ bool dev_error_res = false; + template __global__ void UniformIntKernel(int seed, curandState *globalState, T *input1, size_t input_size_1, T *input2, size_t input_size_2, T *output, size_t count) { + if (!(input1[0] < input2[0])) { + dev_error_res = false; + return; + } for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { curand_init(seed, i, 0, &globalState[i]); output[i] = (T)(curand_uniform(&globalState[i]) * (input2[0] - input1[0])) + input1[0]; } + dev_error_res = true; return; } @@ -59,7 +66,7 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si } template -void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1, +bool UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1, T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) { int RNG_seed = 0; std::random_device rd; @@ -70,9 +77,11 @@ void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t } else { RNG_seed = static_cast(rd()); } + bool host_error_res = false; UniformIntKernel<<>> (RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count); - return; + cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool)); + return host_error_res; } template @@ -94,10 +103,10 @@ template void StandardNormal(int seed, int seed2, curandState *globalStat float *output, size_t count, cudaStream_t cuda_stream); template void StandardNormal(int seed, int seed2, curandState *globalState, int *output, size_t count, cudaStream_t cuda_stream); -template void UniformInt(int seed, int seed2, curandState *globalState, float *input1, size_t input_size_1, +template bool UniformInt(int seed, int seed2, curandState *globalState, float *input1, size_t input_size_1, float *input2, size_t input_size_2, float *output, size_t count, cudaStream_t cuda_stream); -template void UniformInt(int seed, int seed2, curandState *globalState, int *input1, size_t input_size_1, +template bool UniformInt(int seed, int seed2, curandState *globalState, int *input1, size_t input_size_1, int *input2, size_t input_size_2, int *output, size_t count, cudaStream_t cuda_stream); template void UniformReal(int seed, int seed2, curandState *globalState, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh index cd778d9bf71..978b5147ebb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh @@ -25,7 +25,7 @@ template void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream); template -void UniformInt(int seed, int seed2, curandState *globalState, +bool UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1, T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream); template diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h index 47a9773715d..a8e5f6a4570 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h @@ -75,9 +75,13 @@ class RandomOpGpuKernel : public GpuKernel { case RANDOM_OP_UNIFORM_INT: { T *input_addr_1 = GetDeviceAddress(inputs, 1); T *input_addr_2 = GetDeviceAddress(inputs, 2); - UniformInt(seed_, seed2_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2, - inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T), - reinterpret_cast(stream_ptr)); + bool ret = UniformInt(seed_, seed2_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2, + inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T), + reinterpret_cast(stream_ptr)); + if (!ret) { + MS_LOG(ERROR) << "For UniformInt op, `minval` should be strictly less than `maxval`"; + return false; + } break; } case RANDOM_OP_UNIFORM_REAL: {