!26842 Speed up random normal sampling
Merge pull request !26842 from zichun_ye/random_normal_speed_up
This commit is contained in:
commit
bfd190482f
|
@ -70,8 +70,9 @@ class RandomOpGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
curandState *devStates = nullptr;
|
curandState *devStates = nullptr;
|
||||||
// Operator CudnnUniformReal does not need workspace memory.
|
// Operator StandardNormal and CudnnUniformReal use curand
|
||||||
if (random_op_type_ != RANDOM_OP_CUDNN_UNIFORM_REAL) {
|
// so they do not need workspace memory.
|
||||||
|
if (random_op_type_ >= RANDOM_OP_UNIFORM_INT && random_op_type_ <= RANDOM_OP_UNIFORM_REAL) {
|
||||||
void *workspace_addr = GetDeviceAddress<void *>(workspace, 0);
|
void *workspace_addr = GetDeviceAddress<void *>(workspace, 0);
|
||||||
devStates = reinterpret_cast<curandState *>(workspace_addr);
|
devStates = reinterpret_cast<curandState *>(workspace_addr);
|
||||||
}
|
}
|
||||||
|
@ -79,8 +80,30 @@ class RandomOpGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
switch (random_op_type_) {
|
switch (random_op_type_) {
|
||||||
case RANDOM_OP_NORMAL: {
|
case RANDOM_OP_NORMAL: {
|
||||||
StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T),
|
float *mask_f = GetDeviceAddress<float>(outputs, 0);
|
||||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
if (!states_init_) {
|
||||||
|
int RNG_seed = 0;
|
||||||
|
std::random_device rd;
|
||||||
|
if (seed2_ != 0) {
|
||||||
|
RNG_seed = seed2_;
|
||||||
|
} else if (seed_ != 0) {
|
||||||
|
RNG_seed = seed_;
|
||||||
|
} else {
|
||||||
|
RNG_seed = static_cast<int>(rd());
|
||||||
|
}
|
||||||
|
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10),
|
||||||
|
"Failed to create generator");
|
||||||
|
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, RNG_seed),
|
||||||
|
"Failed to SetPseudoRandomGeneratorSeed");
|
||||||
|
MS_EXCEPTION_IF_NULL(mask_generator_);
|
||||||
|
states_init_ = true;
|
||||||
|
}
|
||||||
|
CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"Failed to set stream for generator");
|
||||||
|
// curandGen only support float or double for mask.
|
||||||
|
CHECK_CURAND_RET_WITH_EXCEPT(
|
||||||
|
curandGenerateNormal(mask_generator_, mask_f, outputs[0]->size / sizeof(float), 0.0, 1.0),
|
||||||
|
"Failed to generate uniform");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case RANDOM_OP_UNIFORM_INT: {
|
case RANDOM_OP_UNIFORM_INT: {
|
||||||
|
@ -103,7 +126,7 @@ class RandomOpGpuKernel : public GpuKernel {
|
||||||
case RANDOM_OP_CUDNN_UNIFORM_REAL: {
|
case RANDOM_OP_CUDNN_UNIFORM_REAL: {
|
||||||
float *mask_f = GetDeviceAddress<float>(outputs, 0);
|
float *mask_f = GetDeviceAddress<float>(outputs, 0);
|
||||||
if (!states_init_) {
|
if (!states_init_) {
|
||||||
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT),
|
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10),
|
||||||
"Failed to create generator");
|
"Failed to create generator");
|
||||||
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, seed_),
|
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, seed_),
|
||||||
"Failed to SetPseudoRandomGeneratorSeed");
|
"Failed to SetPseudoRandomGeneratorSeed");
|
||||||
|
|
Loading…
Reference in New Issue