!21958 Fix the execution sequence problem of the load in maketuple

Merge pull request !21958 from Margaret_wangrui/load_in_maketuple
This commit is contained in:
i-robot 2021-08-21 15:18:54 +00:00 committed by Gitee
commit 9f08cdc4ab
6 changed files with 52 additions and 19 deletions

View File

@ -122,7 +122,7 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
return nullptr;
}
auto bn_infer = CreateBNInfer(graph, batchnorm, node);
TransferDepend(batchnorm, graph, bn_infer);
TransferDependOrUpdateState(batchnorm, graph, bn_infer);
return bn_infer;
}
} // namespace opt

View File

@ -125,7 +125,7 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
return nullptr;
}
auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
TransferDepend(batchnorm_grad, graph, bn_infer_grad);
TransferDependOrUpdateState(batchnorm_grad, graph, bn_infer_grad);
return bn_infer_grad;
}
} // namespace opt

View File

@ -916,21 +916,34 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
return new_value_node;
}
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
// Find BatchNorm's output which is a Depend or UpdateState.
for (const auto &node_index : manager->node_users()[old_node]) {
auto node_users = manager->node_users()[old_node];
for (const auto &node_index : node_users) {
AnfNodePtr output = node_index.first;
size_t index = IntToSize(node_index.second);
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
auto depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend);
depend->set_input(index, new_node);
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
auto inputs = output_cnode->inputs();
std::vector<AnfNodePtr> new_inputs{output_cnode->input(0)};
for (size_t i = 1; i < inputs.size(); i++) {
auto input = inputs[i];
if (input == old_node) {
new_inputs.emplace_back(new_node);
} else {
new_inputs.emplace_back(input);
}
}
auto new_output = graph->NewCNode(new_inputs);
new_output->set_abstract(output->abstract());
new_output->set_scope(output->scope());
manager->Replace(output, new_output);
}
}
}

View File

@ -213,8 +213,8 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
// Create a new value node of func graph,not kernel graph
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
// Transfer depend to the new node
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
// Transfer depend or updatestate to the new node
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);

View File

@ -76,15 +76,17 @@ class OrderEnforcer {
}
}
bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
std::unordered_set<AnfNodePtr> CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
std::unordered_set<AnfNodePtr> loads;
auto inputs = cnode->inputs();
for (size_t index = 1; index < inputs.size(); index++) {
auto input = cnode->input(index);
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
return true;
loads.insert(input);
}
}
return false;
return loads;
}
std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
@ -155,23 +157,31 @@ class OrderEnforcer {
// u3 = UpdateState(u', maketuple2, addn) # need put addn or other-op into u3 inputs
// assign = Assign(para2, inputs, u3)
void HandleMakeTupleUsers(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto maketuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple);
if (CheckMakeTupleHaveLoad(maketuple)) {
std::unordered_set<AnfNodePtr> loads = CheckMakeTupleHaveLoad(maketuple);
if (!loads.empty()) {
auto update_state = FindLastUpdateState(maketuple);
if (update_state != nullptr) {
std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
std::unordered_set<AnfNodePtr> no_push_maketuple_users;
std::unordered_set<AnfNodePtr> no_push_all_users;
// Push and Pull at the end of the execution order,
// In order to ensure push and pull operator cut into the same graph, do not put push operator into updatestate
for (auto maketuple_user : maketuple_users) {
if (!IsPrimitiveCNode(maketuple_user, prim::kPrimPush)) {
no_push_maketuple_users.insert(maketuple_user);
no_push_all_users.insert(maketuple_user);
}
}
for (auto load : loads) {
std::unordered_set<AnfNodePtr> load_users = GetSpecialOperatorRealUsers(load);
for (auto load_user : load_users) {
no_push_all_users.insert(load_user);
}
}
auto update_state_cnode = update_state->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(update_state_cnode);
AddInputEdges(update_state_cnode, no_push_maketuple_users);
AddInputEdges(update_state_cnode, no_push_all_users);
}
}
}
@ -265,6 +275,8 @@ class OrderEnforcer {
// Add load users as input edges of the update_state node.
void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
auto sorted_load_users = SortLoadUsers(load_users);
auto inputs = update_state->inputs();
size_t origin_size = inputs.size();
for (auto &load_user : sorted_load_users) {
if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
continue;
@ -272,10 +284,16 @@ class OrderEnforcer {
if (!IsDependOn(load_user, update_state)) {
processed_nodes_.insert(load_user);
if (!IsInUpdateState(load_user, update_state)) {
manager_->AddEdge(update_state, load_user);
inputs.emplace_back(load_user);
}
}
}
if (inputs.size() > origin_size) {
auto new_update_state = func_graph_->NewCNode(inputs);
new_update_state->set_abstract(update_state->abstract());
new_update_state->set_scope(update_state->scope());
manager_->Replace(update_state, new_update_state);
}
}
// Sort load users by their topo sort order.

View File

@ -157,7 +157,9 @@ def test_if_after_if_in_if():
control_flow_if_after_if_in_if(IfAfterIfInIfNet, x)
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_after_if_in_if_01():
x = Tensor(2, mstype.int32)
control_flow_if_after_if_in_if(IfAfterIfInIfNet1, x)