!244 fix ref pass visit graph bug

Merge pull request !244 from dinghao/master
This commit is contained in:
mindspore-ci-bot 2020-04-13 14:38:49 +08:00 committed by Gitee
commit 734f8a7fdb
5 changed files with 18 additions and 2 deletions

View File

@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());

View File

@ -22,6 +22,7 @@
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
}
} // namespace
const BaseRef DealRefTransAndCast::DefinePattern() const {
VarPtr V = std::make_shared<CondVar>(UnVisited);
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs});
}
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {

View File

@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass {
public:
explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {}
~DealRefTransAndCast() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt

View File

@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
bool change = (new_node != nullptr);
if (new_node != nullptr && new_node != node) {
(void)manager->Replace(node, new_node);
(void)seen_node.erase(node);
} else if (new_node == nullptr) {
new_node = node;
}

View File

@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register
"dtype": [
"bool",
"float","float","float","float","float","float","float","float","float","float",
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
],
"format": [
"DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ"
],
"name": "src",
@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register
"dtype": [
"bool",
"float","float","float","float","float","float","float","float","float","float",
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
],
"format": [
"NC1HWC0",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN"
],
"name": "dst",