!14855 fix a allreduce bug in pynative mode

From: @lvchangquan
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @jjfeing
This commit is contained in:
mindspore-ci-bot 2021-04-12 09:26:20 +08:00 committed by Gitee
commit 109c29c546
1 changed files with 10 additions and 0 deletions

View File

@ -2294,6 +2294,7 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const
}
void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
MS_EXCEPTION_IF_NULL(split_index);
if (split_index->empty()) {
return;
}
@ -2313,6 +2314,15 @@ void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (!pynative_mode || parallel_mode != parallel::DATA_PARALLEL) {
return;
}
SetGraphBpropAttr(graph);
if (!graph->is_bprop()) {