diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
index 41497456377..d978ed033e9 100644
--- a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
@@ -21,70 +21,120 @@
 #include <memory>
 #include <unordered_map>
 #include <vector>
+#include <optional>
+#include <string>
 #include "base/core_ops.h"
+#include "abstract/abstract_value.h"
+#include "utils/ordered_map.h"
 
 namespace mindspore {
 namespace opt {
-using MapParamUserIndexs = std::unordered_map<AnfNodePtr, std::vector<size_t>>;
-std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
-                                                    std::vector<AnfNodePtr> *need_replace_loads,
-                                                    MapParamUserIndexs *unload_users_record,
-                                                    std::vector<size_t> *special_op_indexs) {
-  std::unordered_map<AnfNodePtr, size_t> load_groups_record;
-  std::vector<std::vector<size_t>> load_groups;
+namespace {
+
+using ParamUserMap = std::unordered_map<std::string, std::vector<size_t>>;
+using LoadGraphMap = OrderedMap<std::string, std::vector<size_t>>;
+
+std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
+  auto abs = node->abstract();
+  if (abs == nullptr) {
+    // The abstract of some Depend nodes is not properly set; follow the Depend input.
+    if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+      return GetRefKey(node->cast<CNodePtr>()->input(1));
+    }
+    // Abstract should be set except for UpdateState nodes.
+    if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
+      MS_LOG(WARNING) << "Abstract not set for " << node->DebugString();
+    }
+    return std::nullopt;
+  }
+  auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
+  if (abs_ref == nullptr) {
+    return std::nullopt;
+  }
+  auto ref_key = abs_ref->ref_key_value();
+  if (ref_key == nullptr) {
+    return std::nullopt;
+  }
+  return ref_key->name();
+}
+
+bool HasMemoryEffect(const CNodePtr &cnode) {
+  const auto &inputs = cnode->inputs();
+  if (HasAbstractUMonad(inputs.back())) {
+    // The last input is a UMonad.
+    return true;
+  }
+  constexpr size_t kRequiredArgs = 2;
+  if (inputs.size() > kRequiredArgs) {
+    // The last two inputs are a UMonad and an IOMonad.
+    return HasAbstractIOMonad(inputs.back()) && HasAbstractUMonad(inputs.rbegin()[1]);
+  }
+  return false;
+}
+
+LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
+                                std::vector<AnfNodePtr> *need_replace_loads, ParamUserMap *param_users,
+                                std::vector<size_t> *special_op_indexes) {
+  LoadGraphMap load_groups;
   for (size_t i = 0; i < toposet.size(); i++) {
-    auto &node = toposet[i];
-    auto cnode = node->cast<CNodePtr>();
+    auto cnode = dyn_cast<CNode>(toposet[i]);
     // Exclude free variable node.
     if (cnode == nullptr || cnode->func_graph() != fg) {
       continue;
     }
-    bool is_special_op = IsPrimitiveCNode(cnode, prim::kPrimCall) || IsValueNode<FuncGraph>(cnode->input(0)) ||
-                         IsPrimitiveCNode(cnode, prim::kPrimPartial) || IsPrimitiveCNode(cnode, prim::kPrimSwitch) ||
-                         IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer);
-    if (is_special_op) {
-      (void)special_op_indexs->emplace_back(i);
-    }
-
-    // Record param user in toposort nodes.
-    if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
-      for (const auto &input : cnode->inputs()) {
-        AnfNodePtr cur_param = nullptr;
-        if (input->isa<Parameter>()) {
-          cur_param = input;
-        } else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>()) {
-          cur_param = input->cast<CNodePtr>()->input(1);
-        }
-        if (cur_param != nullptr) {
-          (void)(*unload_users_record)[cur_param].emplace_back(i);
+    // Handle Load nodes.
+    if (cnode->IsApply(prim::kPrimLoad)) {
+      auto ref_key = GetRefKey(cnode->input(1));
+      if (!ref_key.has_value()) {
+        MS_LOG(WARNING) << "Load without ref key: " << cnode->DebugString();
+        continue;
+      }
+      // Group Load nodes by the ref key of their first input.
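+      // For example (hypothetical graph): Load(%para1, u1) and Load(%para1, u2)
+      // both resolve to ref key "para1" and land in the same group; the
+      // OrderedMap keeps groups in the order their first Load appears in toposort.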
+      auto &group = load_groups[ref_key.value()];
+      (void)group.emplace_back(i);
+      if (group.size() == 1) {
+        // This is the first Load user of the param in toposort: if load(param, ud)
+        // can be replaced with load(param, u), no node modifies the param before it.
+        const bool param_not_used = (param_users->find(ref_key.value()) == param_users->end());
+        const bool can_replace = (param_not_used && special_op_indexes->empty());
+        if (can_replace) {
+          (void)need_replace_loads->emplace_back(cnode);
         }
       }
       continue;
     }
-
-    auto load_param = cnode->input(1);
-    // first time get same input1 of load.
-    if (load_groups_record.find(load_param) == load_groups_record.end()) {
-      load_groups_record[load_param] = load_groups.size();
-      load_groups.push_back({i});
-      // The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u)
-      // Means there are not nodes which modify param before the load
-      bool can_replace = (*unload_users_record)[load_param].empty() && special_op_indexs->empty();
-      if (can_replace) {
-        need_replace_loads->emplace_back(cnode);
+    // Record special cnodes (call, partial, switch, switch_layer).
+    bool is_special_op = IsValueNode<FuncGraph>(cnode->input(0)) || cnode->IsApply(prim::kPrimCall) ||
+                         cnode->IsApply(prim::kPrimPartial) || cnode->IsApply(prim::kPrimSwitch) ||
+                         cnode->IsApply(prim::kPrimSwitchLayer);
+    if (is_special_op) {
+      (void)special_op_indexes->emplace_back(i);
+      continue;
+    }
+    // Record param users among the toposort nodes.
+    // Only cnodes with memory side effects, and Depend nodes, are checked.
+    if (HasMemoryEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
+      for (size_t n = 1; n < cnode->size(); ++n) {
+        const auto &input = cnode->input(n);
+        auto ref_key = GetRefKey(input);
+        if (ref_key.has_value()) {
+          (void)(*param_users)[ref_key.value()].emplace_back(i);
+        }
       }
-    } else {
-      // not first time get same input1 of load
-      load_groups[load_groups_record[load_param]].push_back(i);
     }
   }
   return load_groups;
 }
 
+bool HasIndexBetween(const std::vector<size_t> &indexes, size_t first, size_t second) {
+  return std::any_of(indexes.begin(), indexes.end(),
+                     [&first, &second](size_t index) { return index > first && index < second; });
+}
+
 std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
-                                            const std::vector<size_t> &unload_user_indexs,
-                                            const std::vector<size_t> &special_op_indexs) {
+                                            const std::vector<size_t> &param_user_indexes,
+                                            const std::vector<size_t> &special_op_indexes) {
   if (group.size() <= 1) {
     return {};
   }
@@ -93,19 +143,13 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
   std::vector<size_t> cur_group = {group[pre_load_index]};
   std::vector<std::vector<size_t>> split_groups;
   while (cur_load_index < group.size()) {
-    const auto &cur_load = group[cur_load_index];
-    const auto &prev_load = group[pre_load_index];
+    const auto cur_load = group[cur_load_index];
+    const auto prev_load = group[pre_load_index];
     // Exist node which is the user of load_param between prev_load and cur_load,
     // Do not divide into the same group.
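+    // Both index lists hold toposort positions, so "between" below means the
+    // param user or special op executes after prev_load and before cur_load.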
-    const auto param_used_by_other =
-      std::any_of(unload_user_indexs.begin(), unload_user_indexs.end(),
-                  [&cur_load, &prev_load](size_t index) { return index > prev_load && index < cur_load; });
-    const auto param_used_by_special_op =
-      std::any_of(special_op_indexs.begin(), special_op_indexs.end(),
-                  [&cur_load, &prev_load](size_t index) { return index > prev_load && index < cur_load; });
-    if (param_used_by_other || param_used_by_special_op) {
-      split_groups.push_back(cur_group);
-      cur_group.clear();
+    if (HasIndexBetween(param_user_indexes, prev_load, cur_load) ||
+        HasIndexBetween(special_op_indexes, prev_load, cur_load)) {
+      (void)split_groups.emplace_back(std::move(cur_group));
+      // Reset cur_group explicitly; the state of a moved-from vector is unspecified.
+      cur_group.clear();
     }
     cur_group.push_back(cur_load);
     pre_load_index++;
@@ -272,6 +316,7 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
+}  // namespace
 
 // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,...
@@ -282,17 +327,17 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage
   // Record the set of the first load of param which no nodes modify param before the load in toposort.
   std::vector<AnfNodePtr> need_replace_loads;
   // Record the param and the toposort id of the unload user of param, they may modify the value of param.
-  MapParamUserIndexs unload_users_record;
+  ParamUserMap param_users;
   // Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
-  std::vector<size_t> special_op_indexs;
-  std::vector<std::vector<size_t>> load_groups =
-    GenerateLoadGroups(fg, toposet, &need_replace_loads, &unload_users_record, &special_op_indexs);
-  // split group if there is no-load node between two load nodes.
+  std::vector<size_t> special_op_indexes;
+  auto load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads, &param_users, &special_op_indexes);
+  // Split the group if there is a non-Load user node between two Load nodes.
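+  // Loads that remain in one group after splitting observe the same parameter
+  // value, so each group can later be merged into a single Load node.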
   std::vector<std::vector<size_t>> need_merge_loads;
-  for (auto &group : load_groups) {
-    auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1);
-    const auto &unload_user_indexs = unload_users_record[load_param];
-    auto groups = SplitGroup(group, unload_user_indexs, special_op_indexs);
+  for (auto &load_group : load_groups) {
+    auto &ref_key = load_group.first;
+    auto &group = load_group.second;
+    const auto &param_user_indexes = param_users[ref_key];
+    auto groups = SplitGroup(group, param_user_indexes, special_op_indexes);
     need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
   }
   for (auto &group : need_merge_loads) {
diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc
index b77112022a6..65cf3875a60 100644
--- a/mindspore/core/ir/tensor.cc
+++ b/mindspore/core/ir/tensor.cc
@@ -563,6 +563,12 @@ bool Tensor::operator==(const Tensor &tensor) const {
 }
 
 bool Tensor::ValueEqual(const Tensor &tensor) const {
+  if (is_parameter_ != tensor.is_parameter_) {
+    return false;
+  }
+  if (is_parameter_ && param_info_->name() != tensor.param_info_->name()) {
+    return false;
+  }
   return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
 }
 
diff --git a/tests/st/auto_monad/test_auto_monad_mindtester.py b/tests/st/auto_monad/test_auto_monad_mindtester.py
index 76df8175729..01742f87146 100644
--- a/tests/st/auto_monad/test_auto_monad_mindtester.py
+++ b/tests/st/auto_monad/test_auto_monad_mindtester.py
@@ -695,3 +695,40 @@ def test_side_effect_grad_control_flow_assign_depend_while_net():
         allclose_nparray(out1[1][0].asnumpy(), expect2, 0.001, 0.001)
     finally:
         context.set_context(mode=context.GRAPH_MODE)
+
+
+class AssignInZipLoop(Cell):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
+        self.conv2 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
+        self.params1 = self.conv1.trainable_params()
+        self.params2 = self.conv2.trainable_params()
+
+    def construct(self, x):
+        for p1, p2 in zip(self.params1, self.params2):
+            P.Assign()(p2, p1 + x)
+
+        out = 0
+        for p1, p2 in zip(self.params1, self.params2):
+            out = p1 + p2
+            print(p1)
+            print(p2)
+
+        return out
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_assign_in_zip_loop():
+    """
+    Feature: Auto-monad load grouping and merge.
+    Description: Assign/Load inside a zip loop.
+    Expectation: 'p1 + p2' should be executed after the Assign, and out is 1.
+    """
+    x = Tensor.from_numpy(np.ones([1], np.float32))
+    net = AssignInZipLoop()
+    out = net(x)
+    assert np.all(out.asnumpy() == 1)
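+    # If the Loads in the second loop were grouped and merged with Loads issued
+    # before the Assign, out would read the zero-initialized weights and be 0.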