forked from mindspore-Ecosystem/mindspore
!244 fix ref pass visit graph bug
Merge pull request !244 from dinghao/master
This commit is contained in:
commit
734f8a7fdb
|
@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel
|
|||
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||
|
@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
|||
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "kernel/oplib/oplib.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
|
|||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef DealRefTransAndCast::DefinePattern() const {
|
||||
VarPtr V = std::make_shared<CondVar>(UnVisited);
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({V, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
|
||||
|
|
|
@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass {
|
|||
public:
|
||||
explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {}
|
||||
~DealRefTransAndCast() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
|
|
|
@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
|||
bool change = (new_node != nullptr);
|
||||
if (new_node != nullptr && new_node != node) {
|
||||
(void)manager->Replace(node, new_node);
|
||||
(void)seen_node.erase(node);
|
||||
} else if (new_node == nullptr) {
|
||||
new_node = node;
|
||||
}
|
||||
|
|
|
@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register
|
|||
"dtype": [
|
||||
"bool",
|
||||
"float","float","float","float","float","float","float","float","float","float",
|
||||
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
|
||||
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
|
||||
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ"
|
||||
],
|
||||
"name": "src",
|
||||
|
@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register
|
|||
"dtype": [
|
||||
"bool",
|
||||
"float","float","float","float","float","float","float","float","float","float",
|
||||
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
|
||||
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
|
||||
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0",
|
||||
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
|
||||
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
|
||||
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN"
|
||||
],
|
||||
"name": "dst",
|
||||
|
|
Loading…
Reference in New Issue