forked from mindspore-Ecosystem/mindspore
fix nccl broadcast
This commit is contained in:
parent
1a4d3e351e
commit
572a7c4741
|
@ -86,6 +86,9 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
|||
num_directions_ = 2;
|
||||
}
|
||||
const int gate_size = 4 * hidden_size_;
|
||||
if (num_layers_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
|
||||
}
|
||||
for (int i = 0; i < num_layers_; ++i) {
|
||||
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
|
||||
weight_h_size_ += gate_size * hidden_size_;
|
||||
|
@ -95,9 +98,6 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
|||
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
|
||||
MS_LOG(EXCEPTION) << "error iteration shape!";
|
||||
}
|
||||
if (num_layers_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
|
||||
}
|
||||
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "lstm only support 3-D input!";
|
||||
}
|
||||
|
|
|
@ -104,6 +104,9 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
|||
num_directions_ = 2;
|
||||
}
|
||||
const int gate_size = 4 * hidden_size_;
|
||||
if (num_layers_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
|
||||
}
|
||||
for (int i = 0; i < num_layers_; ++i) {
|
||||
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
|
||||
weight_h_size_ += gate_size * hidden_size_;
|
||||
|
@ -113,9 +116,6 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
|||
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
|
||||
MS_LOG(EXCEPTION) << "error iteration shape!";
|
||||
}
|
||||
if (num_layers_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
|
||||
}
|
||||
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "lstm only support 3-D input!";
|
||||
}
|
||||
|
|
|
@ -109,9 +109,13 @@ class NcclGpuKernel : public GpuKernel {
|
|||
auto broadcast_funcptr =
|
||||
reinterpret_cast<Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast"));
|
||||
MS_EXCEPTION_IF_NULL(broadcast_funcptr);
|
||||
CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
|
||||
nccl_data_type_, root_, stream, group_name_),
|
||||
"ncclBroadcast failed");
|
||||
for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
|
||||
input_addr = GetDeviceAddress<T>(inputs, i);
|
||||
output_addr = GetDeviceAddress<T>(outputs, i);
|
||||
CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
|
||||
nccl_data_type_, root_, stream, group_name_),
|
||||
"ncclBroadcast failed");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
|
|
@ -28,7 +28,7 @@ class Categorical(Distribution):
|
|||
probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities.
|
||||
logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds.
|
||||
seed (int): seed to use in sampling. Default: 0.
|
||||
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
|
||||
dtype (mstype.int32): type of the distribution. Default: mstype.int32.
|
||||
name (str): name of the distribution. Default: Categorical.
|
||||
|
||||
Note:
|
||||
|
@ -49,7 +49,7 @@ class Categorical(Distribution):
|
|||
>>>
|
||||
>>> # Similar calls can be made to logits
|
||||
>>> ans = self.ca.probs
|
||||
>>> # value should be Tensor
|
||||
>>> # value should be Tensor(mstype.float32, bool, mstype.int32)
|
||||
>>> ans = self.ca.log_prob(value)
|
||||
>>>
|
||||
>>> # Usage of enumerate_support
|
||||
|
@ -210,6 +210,8 @@ class Categorical(Distribution):
|
|||
def enumerate_support(self, expand=True):
|
||||
r"""
|
||||
Enumerate categories.
|
||||
Args:
|
||||
expand (Bool): whether to expand.
|
||||
"""
|
||||
num_events = self._num_events
|
||||
values = nn.Range(0., num_events, 1)()
|
||||
|
|
|
@ -439,7 +439,7 @@ class Multinomial(PrimitiveWithInfer):
|
|||
Args:
|
||||
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||
Must be non-negative. Default: 0.
|
||||
replacement(bool) - whether to draw with replacement or not.
|
||||
replacement(bool): Whether to draw with replacement or not.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims.
|
||||
|
|
Loading…
Reference in New Issue