forked from mindspore-Ecosystem/mindspore
Fix order enforce for sub-graph calling
This commit is contained in:
parent
9fa0b7840a
commit
b37b85ab68
|
@ -69,27 +69,35 @@ class OrderEnforcer {
|
|||
}
|
||||
const size_t attach_index = 2;
|
||||
auto &attach = update_state->input(attach_index);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimLoad) && IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
// Skip UpdateState for Loads.
|
||||
return;
|
||||
} else if (attach->isa<CNode>()) {
|
||||
EnforceOrderForOtherCNode(attach->cast<CNodePtr>());
|
||||
}
|
||||
// Check previous update_state.
|
||||
auto &prev_u = update_state->input(1);
|
||||
if (!IsPrimitiveCNode(prev_u, prim::kPrimUpdateState)) {
|
||||
// Skip if previous is not UpdateState (maybe a U).
|
||||
return;
|
||||
}
|
||||
// Search side effect cnodes that use previous update_state as input.
|
||||
auto side_effect_nodes = FindNodeUsers(prev_u, [&update_state](const AnfNodePtr &user_node) {
|
||||
return (user_node != update_state) && !IsPrimitiveCNode(user_node, prim::kPrimLoad);
|
||||
});
|
||||
// For such side effect cnodes, try enfore order for them.
|
||||
for (auto &side_effect_node : side_effect_nodes) {
|
||||
HandleSideEffectNode(side_effect_node->cast<CNodePtr>(), prev_u->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
|
||||
auto inputs = cnode->inputs();
|
||||
for (size_t index = 1; index < inputs.size(); index++) {
|
||||
auto input = cnode->input(index);
|
||||
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
bool HasLoadInput(const CNodePtr &cnode) {
|
||||
auto &inputs = cnode->inputs();
|
||||
return std::any_of(inputs.begin() + 1, inputs.end(),
|
||||
[](const AnfNodePtr &input) { return IsPrimitiveCNode(input, prim::kPrimLoad); });
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
|
||||
std::vector<AnfNodePtr> FindUpdateStateUsers(const AnfNodePtr &node) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(cnode);
|
||||
auto iter = node_users.find(node);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
|
@ -99,81 +107,76 @@ class OrderEnforcer {
|
|||
auto &user_node = user.first;
|
||||
if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
|
||||
update_states.emplace_back(user_node);
|
||||
} else if (IsPrimitiveCNode(user_node, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple_users = FindUpdateStateUsers(user_node->cast<CNodePtr>());
|
||||
for (auto make_tuple_user : make_tuple_users) {
|
||||
if (IsPrimitiveCNode(make_tuple_user, prim::kPrimUpdateState)) {
|
||||
update_states.emplace_back(make_tuple_user);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (IsPrimitiveCNode(user_node, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple_users = FindUpdateStateUsers(user_node);
|
||||
update_states.insert(update_states.end(), make_tuple_users.begin(), make_tuple_users.end());
|
||||
}
|
||||
}
|
||||
return update_states;
|
||||
}
|
||||
|
||||
AnfNodePtr FindLastUpdateState(const CNodePtr &cnode) {
|
||||
auto inputs = cnode->inputs();
|
||||
auto &inputs = cnode->inputs();
|
||||
// Find all update_state nodes from the user of input load nodes.
|
||||
std::vector<AnfNodePtr> all_update_states;
|
||||
for (size_t index = 1; index < inputs.size(); index++) {
|
||||
auto input = cnode->input(index);
|
||||
auto &input = inputs[index];
|
||||
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
std::vector<AnfNodePtr> update_states = FindUpdateStateUsers(input->cast<CNodePtr>());
|
||||
std::copy(update_states.begin(), update_states.end(), std::back_inserter(all_update_states));
|
||||
std::vector<AnfNodePtr> update_states = FindUpdateStateUsers(input);
|
||||
all_update_states.insert(all_update_states.end(), update_states.begin(), update_states.end());
|
||||
}
|
||||
}
|
||||
AnfNodePtr last_update_state = nullptr;
|
||||
if (all_update_states.empty()) {
|
||||
return last_update_state;
|
||||
// Find the last update_state by topo sort order.
|
||||
auto last_update_state =
|
||||
std::max_element(all_update_states.begin(), all_update_states.end(),
|
||||
[this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
|
||||
if (last_update_state == all_update_states.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (all_update_states.size() == 1) {
|
||||
return all_update_states[0];
|
||||
}
|
||||
for (size_t i = 0; i < all_update_states.size() - 1; i++) {
|
||||
auto cur_update_state = all_update_states[i];
|
||||
auto next_update_state = all_update_states[i + 1];
|
||||
if (topo_sort_map_[cur_update_state] <= topo_sort_map_[next_update_state]) {
|
||||
last_update_state = next_update_state;
|
||||
}
|
||||
}
|
||||
return last_update_state;
|
||||
return *last_update_state;
|
||||
}
|
||||
|
||||
// Convert:
|
||||
// load1 = Load(para1, u1)
|
||||
// load2 = Load(para2, u2)
|
||||
// maketuple1 = MakeTuple(inputs, load1, load2)
|
||||
// addn = AddN(maketupe1) or other-op
|
||||
// maketuple2 = MakeTuple(load1, load2)
|
||||
// u3 = UpdateState(u', maketuple2)
|
||||
// maketuple1 = MakeTuple(inputs, load1, load2) # the make_tuple we should handle.
|
||||
// addn = AddN(maketupe1) # or other-op, user of the make_tuple
|
||||
// maketuple2 = MakeTuple(load1, load2) # load user
|
||||
// u3 = UpdateState(u', maketuple2) # the last update_state for load users.
|
||||
// assign = Assign(para2, inputs, u3)
|
||||
// To:
|
||||
// load1 = Load(para1, u1)
|
||||
// load2 = Load(para2, u2)
|
||||
// maketuple1 = MakeTuple(inputs, load1, load2)
|
||||
// addn = AddN(maketupe1) or other-op
|
||||
// addn = AddN(maketupe1)
|
||||
// maketuple2 = MakeTuple(load1, load2)
|
||||
// u3 = UpdateState(u', maketuple2, addn) # need put addn or other-op into u3 inputs
|
||||
// assign = Assign(para2, inputs, u3)
|
||||
void HandleMakeTupleUsers(const AnfNodePtr &node) {
|
||||
auto maketuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(maketuple);
|
||||
if (CheckMakeTupleHaveLoad(maketuple)) {
|
||||
auto update_state = FindLastUpdateState(maketuple);
|
||||
if (update_state != nullptr) {
|
||||
std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
|
||||
std::unordered_set<AnfNodePtr> no_push_maketuple_users;
|
||||
// Push and Pull at the end of the execution order,
|
||||
// In order to ensure push and pull operator cut into the same graph, do not put push operator into updatestate
|
||||
for (auto maketuple_user : maketuple_users) {
|
||||
if (!IsPrimitiveCNode(maketuple_user, prim::kPrimPush)) {
|
||||
no_push_maketuple_users.insert(maketuple_user);
|
||||
}
|
||||
}
|
||||
auto update_state_cnode = update_state->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(update_state_cnode);
|
||||
AddInputEdges(update_state_cnode, no_push_maketuple_users);
|
||||
}
|
||||
if (!HasLoadInput(maketuple)) {
|
||||
// MakeTuple without Load input.
|
||||
return;
|
||||
}
|
||||
// Find the last update_state node from users of input Loads.
|
||||
auto update_state = FindLastUpdateState(maketuple);
|
||||
if (update_state == nullptr) {
|
||||
return;
|
||||
}
|
||||
// Users of the make_tuple.
|
||||
auto maketuple_users = FindNodeUsers(maketuple, [](const AnfNodePtr &user_node) {
|
||||
// Push and Pull at the end of the execution order,
|
||||
// In order to ensure push and pull operator cut into the same graph,
|
||||
// we do not put push operator into updatestate.
|
||||
return !IsPrimitiveCNode(user_node, prim::kPrimPush);
|
||||
});
|
||||
// Attach make_tuple users to the update_state.
|
||||
auto update_state_cnode = update_state->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(update_state_cnode);
|
||||
AddInputEdges(update_state_cnode, maketuple_users);
|
||||
}
|
||||
|
||||
bool IsRef(const AnfNodePtr &node) {
|
||||
|
@ -181,61 +184,36 @@ class OrderEnforcer {
|
|||
return abs != nullptr && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
|
||||
// Find Load or parameter users as the candidate nodes to enforce order of execution.
|
||||
std::unordered_set<AnfNodePtr> GetSpecialOperatorRealUsers(const AnfNodePtr &node) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(node);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
std::unordered_set<AnfNodePtr> real_users;
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
auto &user_node = user.first;
|
||||
real_users.insert(user_node);
|
||||
}
|
||||
return real_users;
|
||||
bool IsSpecialPrimitive(const AnfNodePtr &node) const {
|
||||
return IsPrimitiveCNode(node, prim::kPrimExpandDims) || IsPrimitiveCNode(node, prim::kPrimBatchNormGrad);
|
||||
}
|
||||
|
||||
bool IsOneOfPrimitive(const AnfNodePtr &node, const std::set<PrimitivePtr> &special_node_types) const {
|
||||
for (const auto &type : special_node_types) {
|
||||
if (IsPrimitiveCNode(node, type)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
|
||||
void HandleSideEffectNode(const CNodePtr &cnode, const CNodePtr &update_state) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Find refs from the cnode inputs.
|
||||
auto &inputs = cnode->inputs();
|
||||
const size_t last_index = inputs.size() - 1;
|
||||
auto last_input = cnode->input(last_index);
|
||||
if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) {
|
||||
return;
|
||||
}
|
||||
const std::set<PrimitivePtr> special_operators = {prim::kPrimExpandDims, prim::kPrimBatchNormGrad};
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto &input = inputs.at(i);
|
||||
if (!IsRef(input)) {
|
||||
auto &input = inputs[i];
|
||||
// Skip non-ref input and update_state.
|
||||
if (!IsRef(input) || input == update_state) {
|
||||
continue;
|
||||
}
|
||||
// load ref users
|
||||
auto loads = FindLoadUsers(input);
|
||||
for (auto load : loads) {
|
||||
std::unordered_set<AnfNodePtr> load_users = FindUsers(load);
|
||||
// The input is a ref (of parameter), find load nodes for it.
|
||||
auto loads = FindLoadNodes(input);
|
||||
for (auto &load : loads) {
|
||||
// Find user nodes of the Load.
|
||||
auto load_users = FindLoadUsers(load);
|
||||
std::unordered_set<AnfNodePtr> real_users;
|
||||
for (auto load_user : load_users) {
|
||||
// check the special operator, only one level of user is considered for now
|
||||
if (IsOneOfPrimitive(load_user, special_operators)) {
|
||||
std::unordered_set<AnfNodePtr> special_real_users = GetSpecialOperatorRealUsers(load_user);
|
||||
for (auto &load_user : load_users) {
|
||||
// Check the special operator, only one level of user is considered for now.
|
||||
if (IsSpecialPrimitive(load_user)) {
|
||||
auto special_real_users = FindNodeUsers(load_user);
|
||||
real_users.insert(special_real_users.begin(), special_real_users.end());
|
||||
} else {
|
||||
real_users.insert(load_user);
|
||||
}
|
||||
}
|
||||
AddInputEdges(last_input->cast<CNodePtr>(), real_users);
|
||||
AddInputEdges(update_state, real_users);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -245,16 +223,15 @@ class OrderEnforcer {
|
|||
const size_t attach_index = 2;
|
||||
const size_t input_size = update_state->inputs().size();
|
||||
for (size_t index = attach_index; index < input_size; index++) {
|
||||
auto attach = update_state->input(attach_index);
|
||||
auto &attach = update_state->input(attach_index);
|
||||
if (attach == load_user) {
|
||||
return true;
|
||||
}
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
auto attach_cnode = attach->cast<CNodePtr>();
|
||||
auto inputs = attach_cnode->inputs();
|
||||
bool has_load_user =
|
||||
std::any_of(inputs.begin() + 1, inputs.end(), [load_user](const auto &input) { return input == load_user; });
|
||||
if (has_load_user) {
|
||||
auto &inputs = attach_cnode->inputs();
|
||||
auto iter = std::find(inputs.begin() + 1, inputs.end(), load_user);
|
||||
if (iter != inputs.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -329,43 +306,39 @@ class OrderEnforcer {
|
|||
return topo_sort_map_[node1] < topo_sort_map_[node2];
|
||||
}
|
||||
|
||||
// Find Load or parameter users as the candidate nodes to enforce order of execution.
|
||||
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param) {
|
||||
using PredFunc = std::function<bool(const AnfNodePtr &)>;
|
||||
|
||||
// Find user nodes for the given node.
|
||||
std::unordered_set<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, PredFunc pred = nullptr) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(load_or_param);
|
||||
auto iter = node_users.find(node);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
std::unordered_set<AnfNodePtr> load_param_users;
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
std::unordered_set<AnfNodePtr> users;
|
||||
for (auto &user : iter->second) {
|
||||
auto &user_node = user.first;
|
||||
if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
|
||||
// Skip processed nodes.
|
||||
continue;
|
||||
if (pred == nullptr || pred(user_node)) {
|
||||
users.emplace(user_node);
|
||||
}
|
||||
auto cnode = dyn_cast<CNode>(user_node);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
load_param_users.insert(cnode);
|
||||
}
|
||||
return load_param_users;
|
||||
return users;
|
||||
}
|
||||
|
||||
std::unordered_set<AnfNodePtr> FindLoadUsers(const AnfNodePtr ¶m) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(param);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
std::unordered_set<AnfNodePtr> loads;
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
auto &user_node = user.first;
|
||||
if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) {
|
||||
loads.insert(user_node);
|
||||
}
|
||||
}
|
||||
return loads;
|
||||
// Find Load or parameter users as the candidate nodes to enforce order of execution.
|
||||
std::unordered_set<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
|
||||
return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) {
|
||||
// Skip processed nodes.
|
||||
return processed_nodes_.find(user_node) == processed_nodes_.end();
|
||||
});
|
||||
}
|
||||
|
||||
// Find Load nodes for a parameter.
|
||||
std::unordered_set<AnfNodePtr> FindLoadNodes(const AnfNodePtr ¶m) {
|
||||
return FindNodeUsers(param, [this](const AnfNodePtr &user_node) {
|
||||
// Search for Load nodes only.
|
||||
return IsPrimitiveCNode(user_node, prim::kPrimLoad);
|
||||
});
|
||||
}
|
||||
|
||||
const FuncGraphPtr &func_graph_;
|
||||
|
|
|
@ -149,6 +149,7 @@ def control_flow_if_after_if_in_if(input_net, x):
|
|||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -157,19 +158,25 @@ def test_if_after_if_in_if():
|
|||
control_flow_if_after_if_in_if(IfAfterIfInIfNet, x)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported side effect")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_after_if_in_if_01():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet1, x)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported side effect")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_after_if_in_if_02():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet2, x)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported side effect")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_after_if_in_if_03():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet3, x)
|
||||
|
|
Loading…
Reference in New Issue