forked from mindspore-Ecosystem/mindspore
Group load nodes by ref key
1. Group load nodes by ref key; 2. Compare parameter name in Tensor::ValueEqual().
This commit is contained in:
parent
b3096f3e4f
commit
dcc19144ff
|
@ -21,70 +21,120 @@
|
|||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <optional>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/ordered_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
using MapParamUserIndexs = std::unordered_map<AnfNodePtr, std::vector<size_t>>;
|
||||
std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
|
||||
std::vector<AnfNodePtr> *need_replace_loads,
|
||||
MapParamUserIndexs *unload_users_record,
|
||||
std::vector<size_t> *special_op_indexs) {
|
||||
std::unordered_map<AnfNodePtr, size_t> load_groups_record;
|
||||
std::vector<std::vector<size_t>> load_groups;
|
||||
namespace {
|
||||
|
||||
using ParamUserMap = std::unordered_map<std::string, std::vector<size_t>>;
|
||||
using LoadGraphMap = OrderedMap<std::string, std::vector<size_t>>;
|
||||
|
||||
std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
|
||||
auto abs = node->abstract();
|
||||
if (abs == nullptr) {
|
||||
// Abstract for some Depends node are not proper set, we follow its input.
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
return GetRefKey(node->cast<CNodePtr>()->input(1));
|
||||
}
|
||||
// Abstract should be set except UpdateState nodes.
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
MS_LOG(WARNING) << "Abstract not set for " << node->DebugString();
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
|
||||
if (abs_ref == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto ref_key = abs_ref->ref_key_value();
|
||||
if (ref_key == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return ref_key->name();
|
||||
}
|
||||
|
||||
bool HasMemoryEffect(const CNodePtr &cnode) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (HasAbstractUMonad(inputs.back())) {
|
||||
// The last input is UMonad.
|
||||
return true;
|
||||
}
|
||||
constexpr size_t kRequiredArgs = 2;
|
||||
if (inputs.size() > kRequiredArgs) {
|
||||
// The last two inputs are UMonad and IOMonad.
|
||||
return HasAbstractIOMonad(inputs.back()) && HasAbstractUMonad(inputs.rbegin()[1]);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
|
||||
std::vector<AnfNodePtr> *need_replace_loads, ParamUserMap *param_users,
|
||||
std::vector<size_t> *special_op_indexes) {
|
||||
LoadGraphMap load_groups;
|
||||
for (size_t i = 0; i < toposet.size(); i++) {
|
||||
auto &node = toposet[i];
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = dyn_cast<CNode>(toposet[i]);
|
||||
// Exclude free variable node.
|
||||
if (cnode == nullptr || cnode->func_graph() != fg) {
|
||||
continue;
|
||||
}
|
||||
bool is_special_op = IsPrimitiveCNode(cnode, prim::kPrimCall) || IsValueNode<FuncGraph>(cnode->input(0)) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimPartial) || IsPrimitiveCNode(cnode, prim::kPrimSwitch) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer);
|
||||
if (is_special_op) {
|
||||
(void)special_op_indexs->emplace_back(i);
|
||||
}
|
||||
|
||||
// Record param user in toposort nodes.
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
AnfNodePtr cur_param = nullptr;
|
||||
if (input->isa<Parameter>()) {
|
||||
cur_param = input;
|
||||
} else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>()) {
|
||||
cur_param = input->cast<CNodePtr>()->input(1);
|
||||
}
|
||||
if (cur_param != nullptr) {
|
||||
(void)(*unload_users_record)[cur_param].emplace_back(i);
|
||||
// Handle Load node.
|
||||
if (cnode->IsApply(prim::kPrimLoad)) {
|
||||
auto ref_key = GetRefKey(cnode->input(1));
|
||||
if (!ref_key.has_value()) {
|
||||
MS_LOG(WARNING) << "Load without ref key: " << cnode->DebugString();
|
||||
continue;
|
||||
}
|
||||
// Group load nodes by their input ref key.
|
||||
auto &group = load_groups[ref_key.value()];
|
||||
(void)group.emplace_back(i);
|
||||
if (group.size() == 1) {
|
||||
// The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u),
|
||||
// Means there are not nodes which modify param before the load.
|
||||
const bool param_not_used = (param_users->find(ref_key.value()) == param_users->end());
|
||||
const bool can_replace = (param_not_used && special_op_indexes->empty());
|
||||
if (can_replace) {
|
||||
(void)need_replace_loads->emplace_back(cnode);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto load_param = cnode->input(1);
|
||||
// first time get same input1 of load.
|
||||
if (load_groups_record.find(load_param) == load_groups_record.end()) {
|
||||
load_groups_record[load_param] = load_groups.size();
|
||||
load_groups.push_back({i});
|
||||
// The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u)
|
||||
// Means there are not nodes which modify param before the load
|
||||
bool can_replace = (*unload_users_record)[load_param].empty() && special_op_indexs->empty();
|
||||
if (can_replace) {
|
||||
need_replace_loads->emplace_back(cnode);
|
||||
// Record special cnode.
|
||||
bool is_special_op = IsValueNode<FuncGraph>(cnode->input(0)) || cnode->IsApply(prim::kPrimCall) ||
|
||||
cnode->IsApply(prim::kPrimPartial) || cnode->IsApply(prim::kPrimSwitch) ||
|
||||
cnode->IsApply(prim::kPrimSwitchLayer);
|
||||
if (is_special_op) {
|
||||
(void)special_op_indexes->emplace_back(i);
|
||||
continue;
|
||||
}
|
||||
// Record param user in toposort nodes.
|
||||
// We only check memory side effect cnodes or Depend nodes.
|
||||
if (HasMemoryEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
|
||||
for (size_t n = 1; n < cnode->size(); ++n) {
|
||||
const auto &input = cnode->input(n);
|
||||
auto ref_key = GetRefKey(input);
|
||||
if (ref_key.has_value()) {
|
||||
(void)(*param_users)[ref_key.value()].emplace_back(i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// not first time get same input1 of load
|
||||
load_groups[load_groups_record[load_param]].push_back(i);
|
||||
}
|
||||
}
|
||||
return load_groups;
|
||||
}
|
||||
|
||||
// Returns true when at least one entry of `indexes` lies strictly between `first` and
// `second` (both bounds exclusive); returns false for an empty list.
bool HasIndexBetween(const std::vector<size_t> &indexes, size_t first, size_t second) {
  for (const size_t candidate : indexes) {
    if (first < candidate && candidate < second) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
|
||||
const std::vector<size_t> &unload_user_indexs,
|
||||
const std::vector<size_t> &special_op_indexs) {
|
||||
const std::vector<size_t> ¶m_user_indexes,
|
||||
const std::vector<size_t> &special_op_indexes) {
|
||||
if (group.size() <= 1) {
|
||||
return {};
|
||||
}
|
||||
|
@ -93,19 +143,13 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
|
|||
std::vector<size_t> cur_group = {group[pre_load_index]};
|
||||
std::vector<std::vector<size_t>> split_groups;
|
||||
while (cur_load_index < group.size()) {
|
||||
const auto &cur_load = group[cur_load_index];
|
||||
const auto &prev_load = group[pre_load_index];
|
||||
const auto cur_load = group[cur_load_index];
|
||||
const auto prev_load = group[pre_load_index];
|
||||
// Exist node which is the user of load_param between prev_load and cur_load,
|
||||
// Do not divide into the same group.
|
||||
const auto param_used_by_other =
|
||||
std::any_of(unload_user_indexs.begin(), unload_user_indexs.end(),
|
||||
[&cur_load, &prev_load](size_t index) { return index > prev_load && index < cur_load; });
|
||||
const auto param_used_by_special_op =
|
||||
std::any_of(special_op_indexs.begin(), special_op_indexs.end(),
|
||||
[&cur_load, &prev_load](size_t index) { return index > prev_load && index < cur_load; });
|
||||
if (param_used_by_other || param_used_by_special_op) {
|
||||
split_groups.push_back(cur_group);
|
||||
cur_group.clear();
|
||||
if (HasIndexBetween(param_user_indexes, prev_load, cur_load) ||
|
||||
HasIndexBetween(special_op_indexes, prev_load, cur_load)) {
|
||||
(void)split_groups.emplace_back(std::move(cur_group));
|
||||
}
|
||||
cur_group.push_back(cur_load);
|
||||
pre_load_index++;
|
||||
|
@ -272,6 +316,7 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
|
|||
}
|
||||
return change;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
|
||||
// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node1,...
|
||||
|
@ -282,17 +327,17 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage
|
|||
// Record the set of the first load of param which no nodes modify param before the load in toposort.
|
||||
std::vector<AnfNodePtr> need_replace_loads;
|
||||
// Record the param and the toposort id of the unload user of param, they may modify the value of param.
|
||||
MapParamUserIndexs unload_users_record;
|
||||
ParamUserMap param_users;
|
||||
// Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
|
||||
std::vector<size_t> special_op_indexs;
|
||||
std::vector<std::vector<size_t>> load_groups =
|
||||
GenerateLoadGroups(fg, toposet, &need_replace_loads, &unload_users_record, &special_op_indexs);
|
||||
// split group if there is no-load node between two load nodes.
|
||||
std::vector<size_t> special_op_indexes;
|
||||
auto load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads, ¶m_users, &special_op_indexes);
|
||||
// Split group if there is no-load node between two load nodes.
|
||||
std::vector<std::vector<size_t>> need_merge_loads;
|
||||
for (auto &group : load_groups) {
|
||||
auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1);
|
||||
const auto &unload_user_indexs = unload_users_record[load_param];
|
||||
auto groups = SplitGroup(group, unload_user_indexs, special_op_indexs);
|
||||
for (auto &load_group : load_groups) {
|
||||
auto &ref_key = load_group.first;
|
||||
auto &group = load_group.second;
|
||||
const auto ¶m_user_indexes = param_users[ref_key];
|
||||
auto groups = SplitGroup(group, param_user_indexes, special_op_indexes);
|
||||
need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
|
||||
}
|
||||
for (auto &group : need_merge_loads) {
|
||||
|
|
|
@ -563,6 +563,12 @@ bool Tensor::operator==(const Tensor &tensor) const {
|
|||
}
|
||||
|
||||
bool Tensor::ValueEqual(const Tensor &tensor) const {
|
||||
if (is_parameter_ != tensor.is_parameter_) {
|
||||
return false;
|
||||
}
|
||||
if (is_parameter_ && param_info_->name() != tensor.param_info_->name()) {
|
||||
return false;
|
||||
}
|
||||
return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
|
||||
}
|
||||
|
||||
|
|
|
@ -695,3 +695,40 @@ def test_side_effect_grad_control_flow_assign_depend_while_net():
|
|||
allclose_nparray(out1[1][0].asnumpy(), expect2, 0.001, 0.001)
|
||||
finally:
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class AssignInZipLoop(Cell):
    """Cell that assigns to parameters in one zip loop and reads them back in a second zip loop."""

    def __init__(self):
        super().__init__()
        # Two Conv2d layers built with identical arguments and weight_init="zero".
        self.conv1 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
        self.conv2 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
        # Trainable parameters of each layer, paired up by zip() in construct().
        self.params1 = self.conv1.trainable_params()
        self.params2 = self.conv2.trainable_params()

    def construct(self, x):
        # First loop: write `p1 + x` into each conv2 parameter.
        for p1, p2 in zip(self.params1, self.params2):
            P.Assign()(p2, p1 + x)

        out = 0
        # Second loop: read both parameters back; `out` keeps the sum from the last pair.
        for p1, p2 in zip(self.params1, self.params2):
            out = p1 + p2
            print(p1)
            print(p2)

        return out
|
||||
|
||||
|
||||
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_assign_in_zip_loop():
    """
    Feature: Auto-monad load grouping and merge.
    Description: Run a cell that assigns parameters in one zip loop and loads them in a second zip loop.
    Expectation: 'p1 + p2' is evaluated after the Assign, so every element of the output equals 1.
    """
    ones_input = Tensor.from_numpy(np.ones([1], np.float32))
    result = AssignInZipLoop()(ones_input).asnumpy()
    assert np.all(result == 1)
|
||||
|
|
Loading…
Reference in New Issue