diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h index 9485f293161..876710c5578 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h @@ -70,8 +70,9 @@ class RandomOpGpuKernel : public GpuKernel { } curandState *devStates = nullptr; - // Operator CudnnUniformReal does not need workspace memory. - if (random_op_type_ != RANDOM_OP_CUDNN_UNIFORM_REAL) { + // Operator StandardNormal and CudnnUniformReal use curand + // so they do not need workspace memory. + if (random_op_type_ >= RANDOM_OP_UNIFORM_INT && random_op_type_ <= RANDOM_OP_UNIFORM_REAL) { void *workspace_addr = GetDeviceAddress(workspace, 0); devStates = reinterpret_cast(workspace_addr); } @@ -79,8 +80,30 @@ class RandomOpGpuKernel : public GpuKernel { switch (random_op_type_) { case RANDOM_OP_NORMAL: { - StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), - reinterpret_cast(stream_ptr)); + float *mask_f = GetDeviceAddress(outputs, 0); + if (!states_init_) { + int RNG_seed = 0; + std::random_device rd; + if (seed2_ != 0) { + RNG_seed = seed2_; + } else if (seed_ != 0) { + RNG_seed = seed_; + } else { + RNG_seed = static_cast(rd()); + } + CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10), + "Failed to create generator"); + CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, RNG_seed), + "Failed to SetPseudoRandomGeneratorSeed"); + MS_EXCEPTION_IF_NULL(mask_generator_); + states_init_ = true; + } + CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast(stream_ptr)), + "Failed to set stream for generator"); + // curandGen only support float or double for mask. + CHECK_CURAND_RET_WITH_EXCEPT( + curandGenerateNormal(mask_generator_, mask_f, outputs[0]->size / sizeof(float), 0.0, 1.0), + "Failed to generate uniform"); break; } case RANDOM_OP_UNIFORM_INT: { @@ -103,7 +126,7 @@ class RandomOpGpuKernel : public GpuKernel { case RANDOM_OP_CUDNN_UNIFORM_REAL: { float *mask_f = GetDeviceAddress(outputs, 0); if (!states_init_) { - CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT), + CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10), "Failed to create generator"); CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, seed_), "Failed to SetPseudoRandomGeneratorSeed");