forked from mindspore-Ecosystem/mindspore
Fix eliminate get ref parameter
This commit is contained in:
parent
625f2421b5
commit
15b3fba0ef
|
@ -82,6 +82,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
|
||||
// Ref eliminate
|
||||
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
|
||||
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
|
||||
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ class OptimizeIRPassLib {
|
|||
|
||||
// Ref eliminate
|
||||
SubstitutionPtr make_ref_eliminate_;
|
||||
SubstitutionPtr get_ref_param_eliminate_;
|
||||
SubstitutionPtr get_make_ref_eliminate_;
|
||||
SubstitutionPtr replace_refkey_by_param_;
|
||||
SubstitutionPtr replace_old_param_;
|
||||
|
|
|
@ -46,6 +46,26 @@ class MakeRefEliminater : public AnfVisitor {
|
|||
AnfNodePtr y_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimGetRefValue, Parameter} -> Parameter
|
||||
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
|
||||
class GetRefParamEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
x_ = nullptr;
|
||||
AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node);
|
||||
if (x_ != nullptr) {
|
||||
return x_;
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node);
|
||||
return x_;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override { x_ = node; }
|
||||
|
||||
private:
|
||||
AnfNodePtr x_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
|
||||
#include "debug/draw.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/anf_ir_utils.h"
|
||||
#include "debug/trace.h"
|
||||
#include "optimizer/opt.h"
|
||||
#include "pipeline/resource.h"
|
||||
|
@ -175,6 +176,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
ExportIR(fg_name + ".dat", "", func_graph);
|
||||
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
opt::OptPassConfig b_2 = opt::OptPassConfig({
|
||||
irpass.replace_refkey_by_param_,
|
||||
irpass.make_ref_eliminate_,
|
||||
irpass.get_ref_param_eliminate_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"b_1", b_1},
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import numpy as np
|
||||
from mindspore import context, nn, Tensor, Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, data):
|
||||
super(Net, self).__init__()
|
||||
self.start = Tensor(0, dtype=mstype.int32)
|
||||
self.end = Tensor(2, dtype=mstype.int32)
|
||||
self.max_output = Parameter(data, "output_x")
|
||||
self.upd = P.ScatterNdUpdate()
|
||||
self.zero = Tensor(np.ones([1], dtype=np.int32))
|
||||
|
||||
def construct(self, inputs):
|
||||
idx = self.start
|
||||
end = self.end
|
||||
while idx < end:
|
||||
xi = inputs[idx, :, :]
|
||||
self.upd(self.max_output, idx + self.zero, xi)
|
||||
idx = idx + 1
|
||||
return self.max_output + 0
|
||||
|
||||
|
||||
def test_x():
|
||||
x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
|
||||
net = Net(x)
|
||||
net(x)
|
Loading…
Reference in New Issue