forked from mindspore-Ecosystem/mindspore
!28468 Change the algo for random normal sampling back
Merge pull request !28468 from zichun_ye/random_op_roll_back
This commit is contained in:
commit
38dbfae44e
|
@ -80,33 +80,8 @@ class RandomOpGpuKernel : public GpuKernel {
|
|||
|
||||
switch (random_op_type_) {
|
||||
case RANDOM_OP_NORMAL: {
|
||||
// To speed up the sampling, we use cudnn for sampling.
|
||||
// Meanwhile, to keep the same seed logic, we still reset the seed every time in pynative mode.
|
||||
float *mask_f = GetDeviceAddress<float>(outputs, 0);
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || !states_init_) {
|
||||
int RNG_seed = 0;
|
||||
std::random_device rd;
|
||||
if (seed2_ != 0) {
|
||||
RNG_seed = seed2_;
|
||||
} else if (seed_ != 0) {
|
||||
RNG_seed = seed_;
|
||||
} else {
|
||||
RNG_seed = static_cast<int>(rd());
|
||||
}
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10),
|
||||
"Failed to create generator");
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, RNG_seed),
|
||||
"Failed to SetPseudoRandomGeneratorSeed");
|
||||
MS_EXCEPTION_IF_NULL(mask_generator_);
|
||||
states_init_ = true;
|
||||
}
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to set stream for generator");
|
||||
// curandGen only support float or double for mask.
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(
|
||||
curandGenerateNormal(mask_generator_, mask_f, outputs[0]->size / sizeof(float), 0.0, 1.0),
|
||||
"Failed to generate normal");
|
||||
|
||||
StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case RANDOM_OP_UNIFORM_INT: {
|
||||
|
|
Loading…
Reference in New Issue