From 15b3fba0ef4933872c0dd0f0f26d03c902447f94 Mon Sep 17 00:00:00 2001
From: fary86
Date: Thu, 18 Jun 2020 12:09:59 +0800
Subject: [PATCH] Fix eliminate get ref parameter

---
 mindspore/ccsrc/optimizer/irpass.cc          |  2 ++
 mindspore/ccsrc/optimizer/irpass.h           |  1 +
 .../ccsrc/optimizer/irpass/ref_eliminate.h   | 20 ++++++++++++
 mindspore/ccsrc/optimizer/optimizer.h        |  2 ++
 mindspore/ccsrc/pipeline/pass.cc             |  1 +
 .../optimizer/test_while_ScatterNdUpdate.py  | 31 +++++++++++++++++++
 6 files changed, 57 insertions(+)
 create mode 100644 tests/ut/python/optimizer/test_while_ScatterNdUpdate.py

diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc
index 13c68416044..a75aa418ed7 100644
--- a/mindspore/ccsrc/optimizer/irpass.cc
+++ b/mindspore/ccsrc/optimizer/irpass.cc
@@ -82,6 +82,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
 
   // Ref eliminate
   make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
+  get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
+                                              {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
   get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
                                              {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h
index e834d69b699..529776e124b 100644
--- a/mindspore/ccsrc/optimizer/irpass.h
+++ b/mindspore/ccsrc/optimizer/irpass.h
@@ -58,6 +58,7 @@ class OptimizeIRPassLib {
 
   // Ref eliminate
   SubstitutionPtr make_ref_eliminate_;
+  SubstitutionPtr get_ref_param_eliminate_;
   SubstitutionPtr get_make_ref_eliminate_;
   SubstitutionPtr replace_refkey_by_param_;
   SubstitutionPtr replace_old_param_;
diff --git a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
index 201992ef13d..ab4f9bc32ea 100644
--- a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
+++ b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
@@ -46,6 +46,26 @@ class MakeRefEliminater : public AnfVisitor {
   AnfNodePtr y_{nullptr};
 };
 
+// {prim::kPrimGetRefValue, Parameter} -> Parameter
+// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
+class GetRefParamEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    x_ = nullptr;
+    AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node);
+    if (x_ != nullptr) {
+      return x_;
+    }
+    AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node);
+    return x_;
+  }
+
+  void Visit(const AnfNodePtr &node) override { x_ = node; }
+
+ private:
+  AnfNodePtr x_{nullptr};
+};
+
 // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
 // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
 // {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
diff --git a/mindspore/ccsrc/optimizer/optimizer.h b/mindspore/ccsrc/optimizer/optimizer.h
index d5808b48188..805543f45b6 100644
--- a/mindspore/ccsrc/optimizer/optimizer.h
+++ b/mindspore/ccsrc/optimizer/optimizer.h
@@ -29,6 +29,7 @@
 #include "debug/draw.h"
 #include "debug/anf_ir_dump.h"
+#include "debug/anf_ir_utils.h"
 #include "debug/trace.h"
 #include "optimizer/opt.h"
 #include "pipeline/resource.h"
@@ -175,6 +176,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
             "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
           func_graph->DumpFuncGraph(fg_name);
           DumpIR(fg_name + ".ir", func_graph);
+          ExportIR(fg_name + ".dat", "", func_graph);
           MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
         }
       }
diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc
index 7ee8a4ecb0e..ac2a51a33ba 100644
--- a/mindspore/ccsrc/pipeline/pass.cc
+++ b/mindspore/ccsrc/pipeline/pass.cc
@@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig b_2 = opt::OptPassConfig({
     irpass.replace_refkey_by_param_,
     irpass.make_ref_eliminate_,
+    irpass.get_ref_param_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},
diff --git a/tests/ut/python/optimizer/test_while_ScatterNdUpdate.py b/tests/ut/python/optimizer/test_while_ScatterNdUpdate.py
new file mode 100644
index 00000000000..a21955b2b6d
--- /dev/null
+++ b/tests/ut/python/optimizer/test_while_ScatterNdUpdate.py
@@ -0,0 +1,31 @@
+import numpy as np
+from mindspore import context, nn, Tensor, Parameter
+from mindspore.common import dtype as mstype
+from mindspore.ops import operations as P
+
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
+
+class Net(nn.Cell):
+    def __init__(self, data):
+        super(Net, self).__init__()
+        self.start = Tensor(0, dtype=mstype.int32)
+        self.end = Tensor(2, dtype=mstype.int32)
+        self.max_output = Parameter(data, "output_x")
+        self.upd = P.ScatterNdUpdate()
+        self.zero = Tensor(np.ones([1], dtype=np.int32))
+
+    def construct(self, inputs):
+        idx = self.start
+        end = self.end
+        while idx < end:
+            xi = inputs[idx, :, :]
+            self.upd(self.max_output, idx + self.zero, xi)
+            idx = idx + 1
+        return self.max_output + 0
+
+
+def test_x():
+    x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
+    net = Net(x)
+    net(x)