!9251 Fix an error: uniform_a must be strictly less than uniform_b
From: @peixu_ren Reviewed-by: @sunnybeike,@zichun_ye Signed-off-by: @sunnybeike
This commit is contained in:
commit
89e81abc01
|
@ -24,13 +24,20 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size
|
|||
return;
|
||||
}
|
||||
|
||||
// Device-side status flag written by UniformIntKernel: it stores false when
// input1[0] < input2[0] does not hold, and true after the output has been filled.
// The host wrapper copies it back with cudaMemcpyFromSymbol to report the result.
__device__ bool dev_error_res = false;
|
||||
|
||||
// Fills output[0, count) with values computed as
//   (T)(curand_uniform(state) * (input2[0] - input1[0])) + input1[0].
//
// Parameters:
//   seed          - seed passed to curand_init for every element.
//   globalState   - per-element generator states; indexed by the output index, so it
//                   must hold at least `count` entries.
//   input1/input2 - only element 0 of each is read (lower / upper bound).
//   input_size_1/input_size_2 - not used by the kernel.
//   output, count - destination buffer and number of elements to generate.
//
// Status reporting: when input1[0] < input2[0] does not hold, nothing is written to
// `output` and the device flag `dev_error_res` is set to false; otherwise the flag is
// set to true after this thread has finished its share of the elements.
//
// The loop starts at the thread's global index and advances by the total number of
// launched threads, so any 1D launch configuration covers all `count` elements.
//
// NOTE(review): curand_uniform is documented to return values in (0, 1], so a result
// equal to input2[0] is possible - confirm whether the upper bound is meant to be
// exclusive.
// NOTE(review): the bounds are read once up front, which assumes `output` does not
// alias `input1` / `input2` - confirm against callers.
template <typename T>
__global__ void UniformIntKernel(int seed, curandState *globalState, T *input1, size_t input_size_1,
                                 T *input2, size_t input_size_2, T *output, size_t count) {
  // The bounds are loop-invariant: load them from global memory a single time.
  const T low = input1[0];
  const T high = input2[0];
  if (!(low < high)) {
    dev_error_res = false;
    return;
  }
  const T span = high - low;
  // Widen before multiplying so the index arithmetic cannot wrap in 32 bits.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (size_t i = first; i < count; i += stride) {
    curand_init(seed, i, 0, &globalState[i]);
    output[i] = (T)(curand_uniform(&globalState[i]) * span) + low;
  }
  dev_error_res = true;
}
|
||||
|
||||
|
@ -59,7 +66,7 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1,
|
||||
bool UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1,
|
||||
T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) {
|
||||
int RNG_seed = 0;
|
||||
std::random_device rd;
|
||||
|
@ -70,9 +77,11 @@ void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t
|
|||
} else {
|
||||
RNG_seed = static_cast<int>(rd());
|
||||
}
|
||||
bool host_error_res = false;
|
||||
UniformIntKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
|
||||
(RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
|
||||
return;
|
||||
cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool));
|
||||
return host_error_res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -94,10 +103,10 @@ template void StandardNormal<float>(int seed, int seed2, curandState *globalStat
|
|||
float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void StandardNormal<int>(int seed, int seed2, curandState *globalState,
|
||||
int *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void UniformInt<float>(int seed, int seed2, curandState *globalState, float *input1, size_t input_size_1,
|
||||
template bool UniformInt<float>(int seed, int seed2, curandState *globalState, float *input1, size_t input_size_1,
|
||||
float *input2, size_t input_size_2, float *output, size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void UniformInt<int>(int seed, int seed2, curandState *globalState, int *input1, size_t input_size_1,
|
||||
template bool UniformInt<int>(int seed, int seed2, curandState *globalState, int *input1, size_t input_size_1,
|
||||
int *input2, size_t input_size_2, int *output, size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void UniformReal<float>(int seed, int seed2, curandState *globalState,
|
||||
|
|
|
@ -25,7 +25,7 @@ template <typename T>
|
|||
void StandardNormal(int seed, int seed2, curandState *globalState,
|
||||
T *output, size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void UniformInt(int seed, int seed2, curandState *globalState,
|
||||
bool UniformInt(int seed, int seed2, curandState *globalState,
|
||||
T *input1, size_t input_size_1, T *input2, size_t input_size_2,
|
||||
T *output, size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
|
|
|
@ -75,9 +75,13 @@ class RandomOpGpuKernel : public GpuKernel {
|
|||
case RANDOM_OP_UNIFORM_INT: {
|
||||
T *input_addr_1 = GetDeviceAddress<T>(inputs, 1);
|
||||
T *input_addr_2 = GetDeviceAddress<T>(inputs, 2);
|
||||
UniformInt(seed_, seed2_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
|
||||
inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
bool ret = UniformInt(seed_, seed2_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
|
||||
inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "For UniformInt op, `minval` should be strictly less than `maxval`";
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RANDOM_OP_UNIFORM_REAL: {
|
||||
|
|
Loading…
Reference in New Issue