forked from mindspore-Ecosystem/mindspore
Fix eliminate get ref parameter
This commit is contained in:
parent
625f2421b5
commit
15b3fba0ef
|
@ -82,6 +82,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
|
|
||||||
// Ref eliminate
|
// Ref eliminate
|
||||||
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||||
|
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
|
||||||
|
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||||
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
|
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
|
||||||
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,7 @@ class OptimizeIRPassLib {
|
||||||
|
|
||||||
// Ref eliminate
|
// Ref eliminate
|
||||||
SubstitutionPtr make_ref_eliminate_;
|
SubstitutionPtr make_ref_eliminate_;
|
||||||
|
SubstitutionPtr get_ref_param_eliminate_;
|
||||||
SubstitutionPtr get_make_ref_eliminate_;
|
SubstitutionPtr get_make_ref_eliminate_;
|
||||||
SubstitutionPtr replace_refkey_by_param_;
|
SubstitutionPtr replace_refkey_by_param_;
|
||||||
SubstitutionPtr replace_old_param_;
|
SubstitutionPtr replace_old_param_;
|
||||||
|
|
|
@ -46,6 +46,26 @@ class MakeRefEliminater : public AnfVisitor {
|
||||||
AnfNodePtr y_{nullptr};
|
AnfNodePtr y_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// {prim::kPrimGetRefValue, Parameter} -> Parameter
|
||||||
|
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
|
||||||
|
class GetRefParamEliminater : public AnfVisitor {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
x_ = nullptr;
|
||||||
|
AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node);
|
||||||
|
if (x_ != nullptr) {
|
||||||
|
return x_;
|
||||||
|
}
|
||||||
|
AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node);
|
||||||
|
return x_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Visit(const AnfNodePtr &node) override { x_ = node; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
AnfNodePtr x_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||||
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
|
|
||||||
#include "debug/draw.h"
|
#include "debug/draw.h"
|
||||||
#include "debug/anf_ir_dump.h"
|
#include "debug/anf_ir_dump.h"
|
||||||
|
#include "debug/anf_ir_utils.h"
|
||||||
#include "debug/trace.h"
|
#include "debug/trace.h"
|
||||||
#include "optimizer/opt.h"
|
#include "optimizer/opt.h"
|
||||||
#include "pipeline/resource.h"
|
#include "pipeline/resource.h"
|
||||||
|
@ -175,6 +176,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
||||||
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
|
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
|
||||||
func_graph->DumpFuncGraph(fg_name);
|
func_graph->DumpFuncGraph(fg_name);
|
||||||
DumpIR(fg_name + ".ir", func_graph);
|
DumpIR(fg_name + ".ir", func_graph);
|
||||||
|
ExportIR(fg_name + ".dat", "", func_graph);
|
||||||
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
|
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
opt::OptPassConfig b_2 = opt::OptPassConfig({
|
opt::OptPassConfig b_2 = opt::OptPassConfig({
|
||||||
irpass.replace_refkey_by_param_,
|
irpass.replace_refkey_by_param_,
|
||||||
irpass.make_ref_eliminate_,
|
irpass.make_ref_eliminate_,
|
||||||
|
irpass.get_ref_param_eliminate_,
|
||||||
});
|
});
|
||||||
OptPassGroupMap map({
|
OptPassGroupMap map({
|
||||||
{"b_1", b_1},
|
{"b_1", b_1},
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import context, nn, Tensor, Parameter
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, data):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.start = Tensor(0, dtype=mstype.int32)
|
||||||
|
self.end = Tensor(2, dtype=mstype.int32)
|
||||||
|
self.max_output = Parameter(data, "output_x")
|
||||||
|
self.upd = P.ScatterNdUpdate()
|
||||||
|
self.zero = Tensor(np.ones([1], dtype=np.int32))
|
||||||
|
|
||||||
|
def construct(self, inputs):
|
||||||
|
idx = self.start
|
||||||
|
end = self.end
|
||||||
|
while idx < end:
|
||||||
|
xi = inputs[idx, :, :]
|
||||||
|
self.upd(self.max_output, idx + self.zero, xi)
|
||||||
|
idx = idx + 1
|
||||||
|
return self.max_output + 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_x():
|
||||||
|
x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
|
||||||
|
net = Net(x)
|
||||||
|
net(x)
|
Loading…
Reference in New Issue