!6601 Change the convention that random seed is generated at GPU back-end

Merge pull request !6601 from peixu_ren/custom_gpu2
This commit is contained in:
mindspore-ci-bot 2020-09-21 09:21:43 +08:00 committed by Gitee
commit 3621bb2348
2 changed files with 7 additions and 3 deletions

View File

@ -46,12 +46,13 @@ __global__ void UniformRealKernel(int seed, curandState *globalState, T *output,
template <typename T>
void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) {
int RNG_seed = 0;
std::random_device rd;
if (seed2 != 0) {
RNG_seed = seed2;
} else if (seed != 0) {
RNG_seed = seed;
} else {
RNG_seed = time(NULL);
RNG_seed = static_cast<int>(rd());
}
NormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);
return;
@ -61,12 +62,13 @@ template <typename T>
void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1,
T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) {
int RNG_seed = 0;
std::random_device rd;
if (seed2 != 0) {
RNG_seed = seed2;
} else if (seed != 0) {
RNG_seed = seed;
} else {
RNG_seed = time(NULL);
RNG_seed = static_cast<int>(rd());
}
UniformIntKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
(RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
@ -76,12 +78,13 @@ void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t
template <typename T>
void UniformReal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) {
int RNG_seed = 0;
std::random_device rd;
if (seed2 != 0) {
RNG_seed = seed2;
} else if (seed != 0) {
RNG_seed = seed;
} else {
RNG_seed = time(NULL);
RNG_seed = static_cast<int>(rd());
}
UniformRealKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);
return;

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_
#include <curand_kernel.h>
#include <random>
#include "runtime/device/gpu/cuda_common.h"
template <typename T>