!244 fix ref pass visit graph bug

Merge pull request !244 from dinghao/master
This commit is contained in:
mindspore-ci-bot 2020-04-13 14:38:49 +08:00 committed by Gitee
commit 734f8a7fdb
5 changed files with 18 additions and 2 deletions

View File

@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());

View File

@ -22,6 +22,7 @@
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h" #include "session/kernel_graph.h"
#include "pre_activate/common/helper.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
} }
} // namespace } // namespace
const BaseRef DealRefTransAndCast::DefinePattern() const {
VarPtr V = std::make_shared<CondVar>(UnVisited);
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs});
}
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const { const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {
return nullptr; return nullptr;
} }
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::IsRealCNodeKernel(cnode)) { if (!AnfAlgo::IsRealCNodeKernel(cnode)) {

View File

@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass {
public: public:
explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {}
~DealRefTransAndCast() override = default; ~DealRefTransAndCast() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
}; };
} // namespace opt } // namespace opt

View File

@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
bool change = (new_node != nullptr); bool change = (new_node != nullptr);
if (new_node != nullptr && new_node != node) { if (new_node != nullptr && new_node != node) {
(void)manager->Replace(node, new_node); (void)manager->Replace(node, new_node);
(void)seen_node.erase(node);
} else if (new_node == nullptr) { } else if (new_node == nullptr) {
new_node = node; new_node = node;
} }

View File

@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register
"dtype": [ "dtype": [
"bool", "bool",
"float","float","float","float","float","float","float","float","float","float", "float","float","float","float","float","float","float","float","float","float",
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
], ],
"format": [ "format": [
"DefaultFormat", "DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ" "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ"
], ],
"name": "src", "name": "src",
@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register
"dtype": [ "dtype": [
"bool", "bool",
"float","float","float","float","float","float","float","float","float","float", "float","float","float","float","float","float","float","float","float","float",
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
], ],
"format": [ "format": [
"NC1HWC0", "NC1HWC0",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN" "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN"
], ],
"name": "dst", "name": "dst",