!22797 Fix wide&deep allreduce_fusion_index bug in PyNative mode

Merge pull request !22797 from caifubi/master-pynative-allreduce-fusion-index
This commit is contained in:
i-robot 2021-09-03 07:18:05 +00:00 committed by Gitee
commit d397fa22ec
1 changed files with 4 additions and 2 deletions

View File

@ -2428,8 +2428,10 @@ void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split
// obtain graph output tensor num
auto grads_count = GetBpropGraphGradsCount(graph);
if (split_index_num >= grads_count) {
MS_LOG(EXCEPTION) << "Invalid all_reduce_fusion_config:" << *split_index
<< ". fusion index should be smaller than:" << grads_count;
MS_LOG(WARNING) << "Invalid all_reduce_fusion_config:" << *split_index << " total grads count:" << grads_count
<< ". All AllReduce operators will be fused into one.";
split_index->clear();
split_index->push_back(grads_count - 1);
} else if (split_index_num < grads_count - 1) {
split_index->push_back(grads_count - 1);
}