From d7709a0c13120da2ca77b77cce7187e82604b713 Mon Sep 17 00:00:00 2001 From: liu-yongqi-63 Date: Sat, 18 Feb 2023 11:02:43 +0800 Subject: [PATCH] Fix the problem that the result of Bernoulli operator is 0 under gpu again --- .../device/gpu/kernel/cuda_impl/cuda_ops/bernoulli_impl.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bernoulli_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bernoulli_impl.cu index 119bfc88271..7e0853673af 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bernoulli_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bernoulli_impl.cu @@ -75,7 +75,8 @@ template void BroadcastBernoulliForward(const std::vector &x_dims, const std::vector &p_dims, const T *input, S *output, uint64_t seed, const size_t num_count, const uint32_t &device_id, cudaStream_t cuda_stream) { - BroadcastBernoulliForwardKernel<<>>( + int block_num = 256 > num_count ? num_count : 256; + BroadcastBernoulliForwardKernel<<>>( x_dims[0], x_dims[1], x_dims[2], x_dims[3], x_dims[4], x_dims[5], x_dims[6], p_dims[0], p_dims[1], p_dims[2], p_dims[3], p_dims[4], p_dims[5], p_dims[6], input, output, seed, num_count); }