From 572a7c474155e924ecbb095126ce6cb02f19c89a Mon Sep 17 00:00:00 2001 From: baihuawei Date: Tue, 1 Sep 2020 15:27:36 +0800 Subject: [PATCH] fix nccl broadcast --- .../kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc | 6 +++--- .../kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc | 6 +++--- .../backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h | 10 +++++++--- mindspore/nn/probability/distribution/categorical.py | 6 ++++-- mindspore/ops/operations/random_ops.py | 2 +- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc index 72ce1fd9c1b..585c5270399 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc @@ -86,6 +86,9 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) { num_directions_ = 2; } const int gate_size = 4 * hidden_size_; + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } for (int i = 0; i < num_layers_; ++i) { weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); weight_h_size_ += gate_size * hidden_size_; @@ -95,9 +98,6 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) { if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { MS_LOG(EXCEPTION) << "error iteration shape!"; } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { MS_LOG(EXCEPTION) << "lstm only support 3-D input!"; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc index 52642911078..e141a38ffe6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -104,6 +104,9 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { num_directions_ = 2; } const int gate_size = 4 * hidden_size_; + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } for (int i = 0; i < num_layers_; ++i) { weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); weight_h_size_ += gate_size * hidden_size_; @@ -113,9 +116,6 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { MS_LOG(EXCEPTION) << "error iteration shape!"; } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { MS_LOG(EXCEPTION) << "lstm only support 3-D input!"; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h index 529045c5436..2ffe5fa486d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h @@ -109,9 +109,13 @@ class NcclGpuKernel : public GpuKernel { auto broadcast_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "Broadcast")); MS_EXCEPTION_IF_NULL(broadcast_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, root_, stream, group_name_), - "ncclBroadcast failed"); + for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) { + input_addr = GetDeviceAddress(inputs, i); + output_addr = GetDeviceAddress(outputs, i); + CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), + nccl_data_type_, root_, stream, group_name_), + "ncclBroadcast failed"); + } break; } default: { diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 98901058c68..88541db0e3c 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -28,7 +28,7 @@ class Categorical(Distribution): probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities. logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds. seed (int): seed to use in sampling. Default: 0. - dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. + dtype (mstype.int32): type of the distribution. Default: mstype.int32. name (str): name of the distribution. Default: Categorical. Note: @@ -49,7 +49,7 @@ class Categorical(Distribution): >>> >>> # Similar calls can be made to logits >>> ans = self.ca.probs - >>> # value should be Tensor + >>> # value should be Tensor(mstype.float32, bool, mstype.int32) >>> ans = self.ca.log_prob(value) >>> >>> # Usage of enumerate_support @@ -210,6 +210,8 @@ class Categorical(Distribution): def enumerate_support(self, expand=True): r""" Enumerate categories. + Args: + expand (Bool): whether to expand. """ num_events = self._num_events values = nn.Range(0., num_events, 1)() diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 0b07c0e08b3..49566089621 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -439,7 +439,7 @@ class Multinomial(PrimitiveWithInfer): Args: seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. Must be non-negative. Default: 0. - replacement(bool) - whether to draw with replacement or not. + replacement(bool): Whether to draw with replacement or not. Inputs: - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims.