forked from mindspore-Ecosystem/mindspore
!30634 [ME][Auto_monad] Fix bug:Remove duplicate loads before Load node grouping.
Merge pull request !30634 from Margaret_wangrui/auto_monad_eliminate_2
This commit is contained in:
commit
68773a9f93
|
@ -66,12 +66,17 @@ bool HasSideEffect(const CNodePtr &cnode) {
|
|||
return false;
|
||||
}
|
||||
|
||||
LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
|
||||
LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, std::vector<AnfNodePtr> *toposet,
|
||||
std::vector<AnfNodePtr> *need_replace_loads, ParamUserMap *param_users,
|
||||
std::vector<size_t> *special_op_indexes) {
|
||||
LoadGraphMap load_groups;
|
||||
for (size_t i = 0; i < toposet.size(); i++) {
|
||||
auto cnode = dyn_cast<CNode>(toposet[i]);
|
||||
// Record inputs of load and id of load in toposort.
|
||||
// RefKey --> (Monad --> index).
|
||||
std::map<std::string, std::map<AnfNodePtr, size_t>> param_monads;
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
for (size_t i = 0; i < toposet->size(); i++) {
|
||||
auto cnode = dyn_cast<CNode>((*toposet)[i]);
|
||||
// Exclude free variable node.
|
||||
if (cnode == nullptr || cnode->func_graph() != fg) {
|
||||
continue;
|
||||
|
@ -85,7 +90,21 @@ LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNod
|
|||
}
|
||||
// Group load nodes by their input ref key.
|
||||
auto &group = load_groups[ref_key.value()];
|
||||
(void)group.emplace_back(i);
|
||||
constexpr size_t monad_index = 2;
|
||||
auto monad = cnode->input(monad_index);
|
||||
std::map<AnfNodePtr, size_t> &cur_param_monads = param_monads[ref_key.value()];
|
||||
auto iter = cur_param_monads.find(monad);
|
||||
// Remove duplicate load which has the same inputs, otherwise there may be an error in the load grouping.
|
||||
if (iter != cur_param_monads.end()) {
|
||||
auto id = iter->second;
|
||||
auto &first_load = (*toposet)[id];
|
||||
mgr->Replace(cnode, first_load);
|
||||
(*toposet)[i] = first_load;
|
||||
continue;
|
||||
} else {
|
||||
cur_param_monads[monad] = i;
|
||||
(void)group.emplace_back(i);
|
||||
}
|
||||
if (group.size() == 1) {
|
||||
// The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u),
|
||||
// Means there are not nodes which modify param before the load.
|
||||
|
@ -326,7 +345,7 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage
|
|||
ParamUserMap param_users;
|
||||
// Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
|
||||
std::vector<size_t> special_op_indexes;
|
||||
auto load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads, ¶m_users, &special_op_indexes);
|
||||
auto load_groups = GenerateLoadGroups(fg, &toposet, &need_replace_loads, ¶m_users, &special_op_indexes);
|
||||
// Split group if there is no-load node between two load nodes.
|
||||
std::vector<std::vector<size_t>> need_merge_loads;
|
||||
for (auto &load_group : load_groups) {
|
||||
|
|
|
@ -198,3 +198,33 @@ def test_load_convert_tensormove_2():
|
|||
graph_forword_net = ForwardNet2()
|
||||
forward_res = graph_forword_net()
|
||||
assert forward_res == 3
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_load_eliminate():
|
||||
"""
|
||||
Feature: Auto monad feature: test load eliminate.
|
||||
Description: test load eliminate.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.assign = P.Assign()
|
||||
self.variable = Parameter(Tensor(0, ms.float32), name="global")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.variable
|
||||
self.assign(self.variable, 0)
|
||||
out = x ** 2 + self.variable + out
|
||||
self.assign(self.variable, 1)
|
||||
out = self.variable + out
|
||||
return out
|
||||
|
||||
x = Tensor([2], ms.float32)
|
||||
net = Net()
|
||||
out = net(x)
|
||||
assert out == 5
|
||||
|
|
Loading…
Reference in New Issue