Fix eliminate get ref parameter

This commit is contained in:
fary86 2020-06-18 12:09:59 +08:00
parent 625f2421b5
commit 15b3fba0ef
6 changed files with 57 additions and 0 deletions

View File

@ -82,6 +82,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});

View File

@ -58,6 +58,7 @@ class OptimizeIRPassLib {
// Ref eliminate
SubstitutionPtr make_ref_eliminate_;
SubstitutionPtr get_ref_param_eliminate_;
SubstitutionPtr get_make_ref_eliminate_;
SubstitutionPtr replace_refkey_by_param_;
SubstitutionPtr replace_old_param_;

View File

@ -46,6 +46,26 @@ class MakeRefEliminater : public AnfVisitor {
AnfNodePtr y_{nullptr};
};
// {prim::kPrimGetRefValue, Parameter} -> Parameter
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
class GetRefParamEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
x_ = nullptr;
AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node);
if (x_ != nullptr) {
return x_;
}
AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node);
return x_;
}
void Visit(const AnfNodePtr &node) override { x_ = node; }
private:
AnfNodePtr x_{nullptr};
};
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z

View File

@ -29,6 +29,7 @@
#include "debug/draw.h"
#include "debug/anf_ir_dump.h"
#include "debug/anf_ir_utils.h"
#include "debug/trace.h"
#include "optimizer/opt.h"
#include "pipeline/resource.h"
@ -175,6 +176,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
func_graph->DumpFuncGraph(fg_name);
DumpIR(fg_name + ".ir", func_graph);
ExportIR(fg_name + ".dat", "", func_graph);
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
}
}

View File

@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,
irpass.get_ref_param_eliminate_,
});
OptPassGroupMap map({
{"b_1", b_1},

View File

@ -0,0 +1,31 @@
import numpy as np
from mindspore import context, nn, Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class Net(nn.Cell):
def __init__(self, data):
super(Net, self).__init__()
self.start = Tensor(0, dtype=mstype.int32)
self.end = Tensor(2, dtype=mstype.int32)
self.max_output = Parameter(data, "output_x")
self.upd = P.ScatterNdUpdate()
self.zero = Tensor(np.ones([1], dtype=np.int32))
def construct(self, inputs):
idx = self.start
end = self.end
while idx < end:
xi = inputs[idx, :, :]
self.upd(self.max_output, idx + self.zero, xi)
idx = idx + 1
return self.max_output + 0
def test_x():
x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
net = Net(x)
net(x)