!30634 [ME][Auto_monad] Fix bug:Remove duplicate loads before Load node grouping.

Merge pull request !30634 from Margaret_wangrui/auto_monad_eliminate_2
This commit is contained in:
i-robot 2022-02-28 06:07:48 +00:00 committed by Gitee
commit 68773a9f93
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 54 additions and 5 deletions

View File

@ -66,12 +66,17 @@ bool HasSideEffect(const CNodePtr &cnode) {
return false;
}
LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, std::vector<AnfNodePtr> *toposet,
std::vector<AnfNodePtr> *need_replace_loads, ParamUserMap *param_users,
std::vector<size_t> *special_op_indexes) {
LoadGraphMap load_groups;
for (size_t i = 0; i < toposet.size(); i++) {
auto cnode = dyn_cast<CNode>(toposet[i]);
// Record inputs of load and id of load in toposort.
// RefKey --> (Monad --> index).
std::map<std::string, std::map<AnfNodePtr, size_t>> param_monads;
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
for (size_t i = 0; i < toposet->size(); i++) {
auto cnode = dyn_cast<CNode>((*toposet)[i]);
// Exclude free variable node.
if (cnode == nullptr || cnode->func_graph() != fg) {
continue;
@ -85,7 +90,21 @@ LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNod
}
// Group load nodes by their input ref key.
auto &group = load_groups[ref_key.value()];
(void)group.emplace_back(i);
constexpr size_t monad_index = 2;
auto monad = cnode->input(monad_index);
std::map<AnfNodePtr, size_t> &cur_param_monads = param_monads[ref_key.value()];
auto iter = cur_param_monads.find(monad);
// Remove duplicate load which has the same inputs, otherwise there may be an error in the load grouping.
if (iter != cur_param_monads.end()) {
auto id = iter->second;
auto &first_load = (*toposet)[id];
mgr->Replace(cnode, first_load);
(*toposet)[i] = first_load;
continue;
} else {
cur_param_monads[monad] = i;
(void)group.emplace_back(i);
}
if (group.size() == 1) {
// The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u),
// Means there are not nodes which modify param before the load.
@ -326,7 +345,7 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage
ParamUserMap param_users;
// Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
std::vector<size_t> special_op_indexes;
auto load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads, &param_users, &special_op_indexes);
auto load_groups = GenerateLoadGroups(fg, &toposet, &need_replace_loads, &param_users, &special_op_indexes);
// Split group if there is no-load node between two load nodes.
std::vector<std::vector<size_t>> need_merge_loads;
for (auto &load_group : load_groups) {

View File

@ -198,3 +198,33 @@ def test_load_convert_tensormove_2():
graph_forword_net = ForwardNet2()
forward_res = graph_forword_net()
assert forward_res == 3
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_load_eliminate():
"""
Feature: Auto monad feature: test load eliminate.
Description: test load eliminate.
Expectation: No exception.
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.assign = P.Assign()
self.variable = Parameter(Tensor(0, ms.float32), name="global")
def construct(self, x):
out = self.variable
self.assign(self.variable, 0)
out = x ** 2 + self.variable + out
self.assign(self.variable, 1)
out = self.variable + out
return out
x = Tensor([2], ms.float32)
net = Net()
out = net(x)
assert out == 5