From 37ba21c271d25ec95f903b26953e61d1ba9e6ac7 Mon Sep 17 00:00:00 2001 From: dinghao Date: Sun, 12 Apr 2020 09:55:03 +0800 Subject: [PATCH] fix ref pass visit graph bug --- .../pre_activate/ascend/ascend_backend_optimization.cc | 2 ++ .../ascend/format_type/deal_ref_trans_and_cast.cc | 8 ++++++++ .../ascend/format_type/deal_ref_trans_and_cast.h | 1 + mindspore/ccsrc/pre_activate/common/node_pass.cc | 1 + mindspore/ops/_op_impl/tbe/trans_data.py | 8 ++++++-- 5 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 432d88e7a4f..023838c3a5e 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr &kernel mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); @@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index fd206114150..81e5c4b4860 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -22,6 +22,7 @@ #include "kernel/oplib/oplib.h" #include "session/anf_runtime_algorithm.h" #include "session/kernel_graph.h" +#include "pre_activate/common/helper.h" namespace mindspore { namespace opt { @@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn } } // namespace +const BaseRef DealRefTransAndCast::DefinePattern() const { + VarPtr V = std::make_shared(UnVisited); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { if (node == nullptr || !node->isa()) { return nullptr; } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!AnfAlgo::IsRealCNodeKernel(cnode)) { diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h index 9ed55d8b297..1b54a7b111d 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h @@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass { public: explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} ~DealRefTransAndCast() override = default; + const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; }; } // namespace opt diff --git a/mindspore/ccsrc/pre_activate/common/node_pass.cc b/mindspore/ccsrc/pre_activate/common/node_pass.cc index cd213f8263d..a6e93d2f074 100644 --- a/mindspore/ccsrc/pre_activate/common/node_pass.cc +++ b/mindspore/ccsrc/pre_activate/common/node_pass.cc @@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) { bool change = (new_node != nullptr); if (new_node != nullptr && new_node != node) { (void)manager->Replace(node, new_node); + (void)seen_node.erase(node); } else if (new_node == nullptr) { new_node = node; } diff --git a/mindspore/ops/_op_impl/tbe/trans_data.py b/mindspore/ops/_op_impl/tbe/trans_data.py index 1b7c8fa25df..c6628c76381 100644 --- a/mindspore/ops/_op_impl/tbe/trans_data.py +++ b/mindspore/ops/_op_impl/tbe/trans_data.py @@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register "dtype": [ "bool", "float","float","float","float","float","float","float","float","float","float", - "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" + "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16", + "uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16" ], "format": [ "DefaultFormat", "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", + "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ" ], "name": "src", @@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register "dtype": [ "bool", "float","float","float","float","float","float","float","float","float","float", - "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" + "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16", + "uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16" ], "format": [ "NC1HWC0", "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", + "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN" ], "name": "dst",