!2048 fix bug of multioutput byway fusion pass

Merge pull request !2048 from Etone.Chan/June
This commit is contained in:
mindspore-ci-bot 2020-06-12 14:37:55 +08:00 committed by Gitee
commit dae35e0a71
2 changed files with 2 additions and 2 deletions

View File

@ -68,7 +68,6 @@
#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h"
#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h"
@ -349,7 +348,6 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm");
ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<StridedReadConvStridedWriteFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator));

View File

@ -36,6 +36,8 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);