!20038 fix ReduceEltwiseFusionPass

Merge pull request !20038 from yuchaojie/ub_fusion
This commit is contained in:
i-robot 2021-07-13 02:41:25 +00:00 committed by Gitee
commit 63672de8e6
2 changed files with 5 additions and 3 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@
#include "base/core_ops.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
@ -48,7 +49,8 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
return;
}
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE) {
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE &&
GetNodeOutputTotalUsedNum(kernel_graph, eltwise_input) == 1) {
(void)record.insert(eltwise_input);
auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(previous_input_cnode);

View File

@ -1006,7 +1006,7 @@ std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_gra
int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
return std::accumulate(output_used_num.begin(), output_used_num.end(), 0);
return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
}
} // namespace opt
} // namespace mindspore