Fix order enforce for sub-graph calling

This commit is contained in:
He Wei 2021-08-30 08:54:21 +08:00
parent 9fa0b7840a
commit b37b85ab68
2 changed files with 117 additions and 137 deletions

View File

@ -69,27 +69,35 @@ class OrderEnforcer {
}
const size_t attach_index = 2;
auto &attach = update_state->input(attach_index);
if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
if (IsPrimitiveCNode(attach, prim::kPrimLoad) && IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
// Skip UpdateState for Loads.
return;
} else if (attach->isa<CNode>()) {
EnforceOrderForOtherCNode(attach->cast<CNodePtr>());
}
// Check previous update_state.
auto &prev_u = update_state->input(1);
if (!IsPrimitiveCNode(prev_u, prim::kPrimUpdateState)) {
// Skip if previous is not UpdateState (maybe a U).
return;
}
// Search side effect cnodes that use previous update_state as input.
auto side_effect_nodes = FindNodeUsers(prev_u, [&update_state](const AnfNodePtr &user_node) {
return (user_node != update_state) && !IsPrimitiveCNode(user_node, prim::kPrimLoad);
});
// For such side effect cnodes, try to enforce order for them.
for (auto &side_effect_node : side_effect_nodes) {
HandleSideEffectNode(side_effect_node->cast<CNodePtr>(), prev_u->cast<CNodePtr>());
}
}
bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
auto inputs = cnode->inputs();
for (size_t index = 1; index < inputs.size(); index++) {
auto input = cnode->input(index);
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
return true;
}
}
return false;
// Returns true when at least one input of `cnode` after index 0 is a Load primitive CNode.
bool HasLoadInput(const CNodePtr &cnode) {
  const auto &all_inputs = cnode->inputs();
  const auto is_load = [](const AnfNodePtr &candidate) { return IsPrimitiveCNode(candidate, prim::kPrimLoad); };
  // The element at index 0 is skipped; only the remaining inputs are examined.
  const auto found = std::find_if(all_inputs.begin() + 1, all_inputs.end(), is_load);
  return found != all_inputs.end();
}
std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
std::vector<AnfNodePtr> FindUpdateStateUsers(const AnfNodePtr &node) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(cnode);
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return {};
}
@ -99,81 +107,76 @@ class OrderEnforcer {
auto &user_node = user.first;
if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
update_states.emplace_back(user_node);
} else if (IsPrimitiveCNode(user_node, prim::kPrimMakeTuple)) {
auto make_tuple_users = FindUpdateStateUsers(user_node->cast<CNodePtr>());
for (auto make_tuple_user : make_tuple_users) {
if (IsPrimitiveCNode(make_tuple_user, prim::kPrimUpdateState)) {
update_states.emplace_back(make_tuple_user);
}
}
continue;
}
if (IsPrimitiveCNode(user_node, prim::kPrimMakeTuple)) {
auto make_tuple_users = FindUpdateStateUsers(user_node);
update_states.insert(update_states.end(), make_tuple_users.begin(), make_tuple_users.end());
}
}
return update_states;
}
AnfNodePtr FindLastUpdateState(const CNodePtr &cnode) {
auto inputs = cnode->inputs();
auto &inputs = cnode->inputs();
// Find all update_state nodes from the user of input load nodes.
std::vector<AnfNodePtr> all_update_states;
for (size_t index = 1; index < inputs.size(); index++) {
auto input = cnode->input(index);
auto &input = inputs[index];
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
std::vector<AnfNodePtr> update_states = FindUpdateStateUsers(input->cast<CNodePtr>());
std::copy(update_states.begin(), update_states.end(), std::back_inserter(all_update_states));
std::vector<AnfNodePtr> update_states = FindUpdateStateUsers(input);
all_update_states.insert(all_update_states.end(), update_states.begin(), update_states.end());
}
}
AnfNodePtr last_update_state = nullptr;
if (all_update_states.empty()) {
return last_update_state;
// Find the last update_state by topo sort order.
auto last_update_state =
std::max_element(all_update_states.begin(), all_update_states.end(),
[this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
if (last_update_state == all_update_states.end()) {
return nullptr;
}
if (all_update_states.size() == 1) {
return all_update_states[0];
}
for (size_t i = 0; i < all_update_states.size() - 1; i++) {
auto cur_update_state = all_update_states[i];
auto next_update_state = all_update_states[i + 1];
if (topo_sort_map_[cur_update_state] <= topo_sort_map_[next_update_state]) {
last_update_state = next_update_state;
}
}
return last_update_state;
return *last_update_state;
}
// Convert:
// load1 = Load(para1, u1)
// load2 = Load(para2, u2)
// maketuple1 = MakeTuple(inputs, load1, load2)
// addn = AddN(maketuple1) or other-op
// maketuple2 = MakeTuple(load1, load2)
// u3 = UpdateState(u', maketuple2)
// maketuple1 = MakeTuple(inputs, load1, load2) # the make_tuple we should handle.
// addn = AddN(maketuple1) # or other-op, user of the make_tuple
// maketuple2 = MakeTuple(load1, load2) # load user
// u3 = UpdateState(u', maketuple2) # the last update_state for load users.
// assign = Assign(para2, inputs, u3)
// To:
// load1 = Load(para1, u1)
// load2 = Load(para2, u2)
// maketuple1 = MakeTuple(inputs, load1, load2)
// addn = AddN(maketuple1) or other-op
// addn = AddN(maketuple1)
// maketuple2 = MakeTuple(load1, load2)
// u3 = UpdateState(u', maketuple2, addn) # need to put addn or other-op into u3 inputs
// assign = Assign(para2, inputs, u3)
void HandleMakeTupleUsers(const AnfNodePtr &node) {
auto maketuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple);
if (CheckMakeTupleHaveLoad(maketuple)) {
auto update_state = FindLastUpdateState(maketuple);
if (update_state != nullptr) {
std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
std::unordered_set<AnfNodePtr> no_push_maketuple_users;
// Push and Pull at the end of the execution order,
// In order to ensure push and pull operator cut into the same graph, do not put push operator into updatestate
for (auto maketuple_user : maketuple_users) {
if (!IsPrimitiveCNode(maketuple_user, prim::kPrimPush)) {
no_push_maketuple_users.insert(maketuple_user);
}
}
auto update_state_cnode = update_state->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(update_state_cnode);
AddInputEdges(update_state_cnode, no_push_maketuple_users);
}
if (!HasLoadInput(maketuple)) {
// MakeTuple without Load input.
return;
}
// Find the last update_state node from users of input Loads.
auto update_state = FindLastUpdateState(maketuple);
if (update_state == nullptr) {
return;
}
// Users of the make_tuple.
auto maketuple_users = FindNodeUsers(maketuple, [](const AnfNodePtr &user_node) {
// Push and Pull are at the end of the execution order.
// To ensure the push and pull operators are cut into the same graph,
// we do not put the push operator into the update_state.
return !IsPrimitiveCNode(user_node, prim::kPrimPush);
});
// Attach make_tuple users to the update_state.
auto update_state_cnode = update_state->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(update_state_cnode);
AddInputEdges(update_state_cnode, maketuple_users);
}
bool IsRef(const AnfNodePtr &node) {
@ -181,61 +184,36 @@ class OrderEnforcer {
return abs != nullptr && abs->isa<abstract::AbstractRef>();
}
// Find Load or parameter users as the candidate nodes to enforce order of execution.
std::unordered_set<AnfNodePtr> GetSpecialOperatorRealUsers(const AnfNodePtr &node) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return {};
}
std::unordered_set<AnfNodePtr> real_users;
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
real_users.insert(user_node);
}
return real_users;
// Tells whether `node` is an ExpandDims or a BatchNormGrad primitive CNode.
bool IsSpecialPrimitive(const AnfNodePtr &node) const {
  if (IsPrimitiveCNode(node, prim::kPrimExpandDims)) {
    return true;
  }
  return IsPrimitiveCNode(node, prim::kPrimBatchNormGrad);
}
bool IsOneOfPrimitive(const AnfNodePtr &node, const std::set<PrimitivePtr> &special_node_types) const {
for (const auto &type : special_node_types) {
if (IsPrimitiveCNode(node, type)) {
return true;
}
}
return false;
}
void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
void HandleSideEffectNode(const CNodePtr &cnode, const CNodePtr &update_state) {
MS_EXCEPTION_IF_NULL(cnode);
// Find refs from the cnode inputs.
auto &inputs = cnode->inputs();
const size_t last_index = inputs.size() - 1;
auto last_input = cnode->input(last_index);
if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) {
return;
}
const std::set<PrimitivePtr> special_operators = {prim::kPrimExpandDims, prim::kPrimBatchNormGrad};
for (size_t i = 1; i < inputs.size(); ++i) {
auto &input = inputs.at(i);
if (!IsRef(input)) {
auto &input = inputs[i];
// Skip non-ref input and update_state.
if (!IsRef(input) || input == update_state) {
continue;
}
// load ref users
auto loads = FindLoadUsers(input);
for (auto load : loads) {
std::unordered_set<AnfNodePtr> load_users = FindUsers(load);
// The input is a ref (of parameter), find load nodes for it.
auto loads = FindLoadNodes(input);
for (auto &load : loads) {
// Find user nodes of the Load.
auto load_users = FindLoadUsers(load);
std::unordered_set<AnfNodePtr> real_users;
for (auto load_user : load_users) {
// check the special operator, only one level of user is considered for now
if (IsOneOfPrimitive(load_user, special_operators)) {
std::unordered_set<AnfNodePtr> special_real_users = GetSpecialOperatorRealUsers(load_user);
for (auto &load_user : load_users) {
// Check the special operator, only one level of user is considered for now.
if (IsSpecialPrimitive(load_user)) {
auto special_real_users = FindNodeUsers(load_user);
real_users.insert(special_real_users.begin(), special_real_users.end());
} else {
real_users.insert(load_user);
}
}
AddInputEdges(last_input->cast<CNodePtr>(), real_users);
AddInputEdges(update_state, real_users);
}
}
}
@ -245,16 +223,15 @@ class OrderEnforcer {
const size_t attach_index = 2;
const size_t input_size = update_state->inputs().size();
for (size_t index = attach_index; index < input_size; index++) {
auto attach = update_state->input(attach_index);
auto &attach = update_state->input(attach_index);
if (attach == load_user) {
return true;
}
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
auto attach_cnode = attach->cast<CNodePtr>();
auto inputs = attach_cnode->inputs();
bool has_load_user =
std::any_of(inputs.begin() + 1, inputs.end(), [load_user](const auto &input) { return input == load_user; });
if (has_load_user) {
auto &inputs = attach_cnode->inputs();
auto iter = std::find(inputs.begin() + 1, inputs.end(), load_user);
if (iter != inputs.end()) {
return true;
}
}
@ -329,43 +306,39 @@ class OrderEnforcer {
return topo_sort_map_[node1] < topo_sort_map_[node2];
}
// Find Load or parameter users as the candidate nodes to enforce order of execution.
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param) {
using PredFunc = std::function<bool(const AnfNodePtr &)>;
// Find user nodes for the given node.
std::unordered_set<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, PredFunc pred = nullptr) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(load_or_param);
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return {};
}
std::unordered_set<AnfNodePtr> load_param_users;
auto &users = iter->second;
for (auto &user : users) {
std::unordered_set<AnfNodePtr> users;
for (auto &user : iter->second) {
auto &user_node = user.first;
if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
// Skip processed nodes.
continue;
if (pred == nullptr || pred(user_node)) {
users.emplace(user_node);
}
auto cnode = dyn_cast<CNode>(user_node);
MS_EXCEPTION_IF_NULL(cnode);
load_param_users.insert(cnode);
}
return load_param_users;
return users;
}
std::unordered_set<AnfNodePtr> FindLoadUsers(const AnfNodePtr &param) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(param);
if (iter == node_users.end()) {
return {};
}
std::unordered_set<AnfNodePtr> loads;
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) {
loads.insert(user_node);
}
}
return loads;
// Find Load or parameter users as the candidate nodes to enforce order of execution.
std::unordered_set<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) {
// Skip processed nodes.
return processed_nodes_.find(user_node) == processed_nodes_.end();
});
}
// Find Load nodes for a parameter.
// Returns the users of `param` that are Load primitive CNodes; the result is empty
// when `param` has no users or none of them is a Load.
std::unordered_set<AnfNodePtr> FindLoadNodes(const AnfNodePtr &param) {
  // The predicate uses no member state, so the lambda captures nothing
  // (the previous `[this]` capture was unused).
  return FindNodeUsers(param, [](const AnfNodePtr &user_node) {
    // Search for Load nodes only.
    return IsPrimitiveCNode(user_node, prim::kPrimLoad);
  });
}
const FuncGraphPtr &func_graph_;

View File

@ -149,6 +149,7 @@ def control_flow_if_after_if_in_if(input_net, x):
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -157,19 +158,25 @@ def test_if_after_if_in_if():
control_flow_if_after_if_in_if(IfAfterIfInIfNet, x)
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_after_if_in_if_01():
    """Run control_flow_if_after_if_in_if on IfAfterIfInIfNet1 with Tensor(2, int32)."""
    control_flow_if_after_if_in_if(IfAfterIfInIfNet1, Tensor(2, mstype.int32))
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_after_if_in_if_02():
    """Run control_flow_if_after_if_in_if on IfAfterIfInIfNet2 with Tensor(2, int32)."""
    control_flow_if_after_if_in_if(IfAfterIfInIfNet2, Tensor(2, mstype.int32))
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_after_if_in_if_03():
    """Run control_flow_if_after_if_in_if on IfAfterIfInIfNet3 with Tensor(2, int32)."""
    control_flow_if_after_if_in_if(IfAfterIfInIfNet3, Tensor(2, mstype.int32))