forked from OSSInnovation/mindspore
fix load eliminate bug
This commit is contained in:
parent
0d1d043d80
commit
daecc79e91
|
@ -24,6 +24,18 @@
|
|||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore::opt::irpass {
|
||||
// Covert:
|
||||
// load1 = load(para1, u1)
|
||||
// u2 = UpdateState(u1, load1)
|
||||
// ...
|
||||
// load2 = load(load1, u3)
|
||||
// u4 = UpdateState(u3, load2)
|
||||
// To:
|
||||
// load1 = load(para1, u1)
|
||||
// u2 = UpdateState(u1, load1)
|
||||
// ...
|
||||
// load2 = load(para1, u3) # load1 replaced by para1
|
||||
// u4 = UpdateState(u3, load2)
|
||||
AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto load_node = dyn_cast<CNode>(node);
|
||||
if (load_node == nullptr || load_node->inputs().empty()) {
|
||||
|
@ -32,8 +44,20 @@ AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &no
|
|||
}
|
||||
auto load_cnode = load_node->cast<CNodePtr>();
|
||||
constexpr size_t kFirstInputIndex = 1;
|
||||
if (IsPrimitiveCNode(load_cnode->input(kFirstInputIndex), prim::kPrimLoad)) {
|
||||
return load_cnode->input(kFirstInputIndex);
|
||||
constexpr size_t kSecondInputIndex = 2;
|
||||
auto &input_load = load_cnode->input(kFirstInputIndex);
|
||||
if (IsPrimitiveCNode(input_load, prim::kPrimLoad)) {
|
||||
auto load_prim = NewValueNode(prim::kPrimLoad);
|
||||
auto input_load_cnode = input_load->cast<CNodePtr>();
|
||||
auto replace_input = input_load_cnode->input(kFirstInputIndex);
|
||||
auto monad = load_cnode->input(kSecondInputIndex);
|
||||
std::vector<AnfNodePtr> new_load_inputs = {load_prim, replace_input, monad};
|
||||
auto fg = load_cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto new_load = fg->NewCNode(new_load_inputs);
|
||||
new_load->set_abstract(load_cnode->abstract());
|
||||
new_load->set_scope(load_cnode->scope());
|
||||
return new_load;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue