From 48268a390661c2c284c8ebd5ce78cbee714474a7 Mon Sep 17 00:00:00 2001 From: Zichun Ye Date: Fri, 31 Dec 2021 11:56:37 +0800 Subject: [PATCH] roll back random normal op --- .../gpu/math/random_op_gpu_kernel.h | 29 ++----------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h index ef4687ffdfe..aaedf3f1e8b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h @@ -80,33 +80,8 @@ class RandomOpGpuKernel : public GpuKernel { switch (random_op_type_) { case RANDOM_OP_NORMAL: { - // To speed up the sampling, we use cudnn for sampling. - // Meanwhile, to keep the same seed logic, we still reset the seed every time in pynative mode. - float *mask_f = GetDeviceAddress(outputs, 0); - if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode || !states_init_) { - int RNG_seed = 0; - std::random_device rd; - if (seed2_ != 0) { - RNG_seed = seed2_; - } else if (seed_ != 0) { - RNG_seed = seed_; - } else { - RNG_seed = static_cast(rd()); - } - CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10), - "Failed to create generator"); - CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, RNG_seed), - "Failed to SetPseudoRandomGeneratorSeed"); - MS_EXCEPTION_IF_NULL(mask_generator_); - states_init_ = true; - } - CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast(stream_ptr)), - "Failed to set stream for generator"); - // curandGen only support float or double for mask. - CHECK_CURAND_RET_WITH_EXCEPT( - curandGenerateNormal(mask_generator_, mask_f, outputs[0]->size / sizeof(float), 0.0, 1.0), - "Failed to generate normal"); - + StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), + reinterpret_cast(stream_ptr)); break; } case RANDOM_OP_UNIFORM_INT: {