forked from mindspore-Ecosystem/mindspore
!22797 Fix wide&deep allreduce_fusion_index bug in PyNative mode
Merge pull request !22797 from caifubi/master-pynative-allreduce-fusion-index
This commit is contained in:
commit
d397fa22ec
|
@ -2428,8 +2428,10 @@ void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split
|
|||
// obtain graph output tensor num
|
||||
auto grads_count = GetBpropGraphGradsCount(graph);
|
||||
if (split_index_num >= grads_count) {
|
||||
MS_LOG(EXCEPTION) << "Invalid all_reduce_fusion_config:" << *split_index
|
||||
<< ". fusion index should be smaller than:" << grads_count;
|
||||
MS_LOG(WARNING) << "Invalid all_reduce_fusion_config:" << *split_index << " total grads count:" << grads_count
|
||||
<< ". All AllReduce operators will be fused into one.";
|
||||
split_index->clear();
|
||||
split_index->push_back(grads_count - 1);
|
||||
} else if (split_index_num < grads_count - 1) {
|
||||
split_index->push_back(grads_count - 1);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue