forked from mindspore-Ecosystem/mindspore
!244 fix ref pass visit graph bug
Merge pull request !244 from dinghao/master
This commit is contained in:
commit
734f8a7fdb
|
@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel
|
||||||
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
|
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
|
||||||
|
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
|
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||||
|
@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
||||||
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
|
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
|
||||||
|
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
|
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "kernel/oplib/oplib.h"
|
#include "kernel/oplib/oplib.h"
|
||||||
#include "session/anf_runtime_algorithm.h"
|
#include "session/anf_runtime_algorithm.h"
|
||||||
#include "session/kernel_graph.h"
|
#include "session/kernel_graph.h"
|
||||||
|
#include "pre_activate/common/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
const BaseRef DealRefTransAndCast::DefinePattern() const {
|
||||||
|
VarPtr V = std::make_shared<CondVar>(UnVisited);
|
||||||
|
VarPtr Xs = std::make_shared<SeqVar>();
|
||||||
|
return VectorRef({V, Xs});
|
||||||
|
}
|
||||||
|
|
||||||
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||||
const EquivPtr &) const {
|
const EquivPtr &) const {
|
||||||
if (node == nullptr || !node->isa<CNode>()) {
|
if (node == nullptr || !node->isa<CNode>()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
|
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
|
||||||
|
|
|
@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {}
|
explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {}
|
||||||
~DealRefTransAndCast() override = default;
|
~DealRefTransAndCast() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
};
|
};
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
||||||
bool change = (new_node != nullptr);
|
bool change = (new_node != nullptr);
|
||||||
if (new_node != nullptr && new_node != node) {
|
if (new_node != nullptr && new_node != node) {
|
||||||
(void)manager->Replace(node, new_node);
|
(void)manager->Replace(node, new_node);
|
||||||
|
(void)seen_node.erase(node);
|
||||||
} else if (new_node == nullptr) {
|
} else if (new_node == nullptr) {
|
||||||
new_node = node;
|
new_node = node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register
|
||||||
"dtype": [
|
"dtype": [
|
||||||
"bool",
|
"bool",
|
||||||
"float","float","float","float","float","float","float","float","float","float",
|
"float","float","float","float","float","float","float","float","float","float",
|
||||||
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
|
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
|
||||||
|
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
|
||||||
],
|
],
|
||||||
"format": [
|
"format": [
|
||||||
"DefaultFormat",
|
"DefaultFormat",
|
||||||
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
|
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
|
||||||
|
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
|
||||||
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ"
|
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ"
|
||||||
],
|
],
|
||||||
"name": "src",
|
"name": "src",
|
||||||
|
@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register
|
||||||
"dtype": [
|
"dtype": [
|
||||||
"bool",
|
"bool",
|
||||||
"float","float","float","float","float","float","float","float","float","float",
|
"float","float","float","float","float","float","float","float","float","float",
|
||||||
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
|
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
|
||||||
|
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
|
||||||
],
|
],
|
||||||
"format": [
|
"format": [
|
||||||
"NC1HWC0",
|
"NC1HWC0",
|
||||||
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
|
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
|
||||||
|
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
|
||||||
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN"
|
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN"
|
||||||
],
|
],
|
||||||
"name": "dst",
|
"name": "dst",
|
||||||
|
|
Loading…
Reference in New Issue