forked from mindspore-Ecosystem/mindspore
!21958 Fix the execution sequence problem of the load in maketuple
Merge pull request !21958 from Margaret_wangrui/load_in_maketuple
This commit is contained in:
commit
9f08cdc4ab
|
@ -122,7 +122,7 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
|
|||
return nullptr;
|
||||
}
|
||||
auto bn_infer = CreateBNInfer(graph, batchnorm, node);
|
||||
TransferDepend(batchnorm, graph, bn_infer);
|
||||
TransferDependOrUpdateState(batchnorm, graph, bn_infer);
|
||||
return bn_infer;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -125,7 +125,7 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
|
|||
return nullptr;
|
||||
}
|
||||
auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
|
||||
TransferDepend(batchnorm_grad, graph, bn_infer_grad);
|
||||
TransferDependOrUpdateState(batchnorm_grad, graph, bn_infer_grad);
|
||||
return bn_infer_grad;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -916,21 +916,34 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
|
|||
return new_value_node;
|
||||
}
|
||||
|
||||
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
|
||||
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
|
||||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
// Find BatchNorm's output which is a Depend or UpdateState.
|
||||
for (const auto &node_index : manager->node_users()[old_node]) {
|
||||
auto node_users = manager->node_users()[old_node];
|
||||
for (const auto &node_index : node_users) {
|
||||
AnfNodePtr output = node_index.first;
|
||||
size_t index = IntToSize(node_index.second);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
|
||||
AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
|
||||
auto depend = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_input(index, new_node);
|
||||
auto output_cnode = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_cnode);
|
||||
auto inputs = output_cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs{output_cnode->input(0)};
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
auto input = inputs[i];
|
||||
if (input == old_node) {
|
||||
new_inputs.emplace_back(new_node);
|
||||
} else {
|
||||
new_inputs.emplace_back(input);
|
||||
}
|
||||
}
|
||||
auto new_output = graph->NewCNode(new_inputs);
|
||||
new_output->set_abstract(output->abstract());
|
||||
new_output->set_scope(output->scope());
|
||||
manager->Replace(output, new_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -213,8 +213,8 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
|
|||
// Create a new value node of func graph,not kernel graph
|
||||
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
|
||||
|
||||
// Transfer depend to the new node
|
||||
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
|
||||
// Transfer depend or updatestate to the new node
|
||||
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
|
||||
|
||||
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
|
|
|
@ -76,15 +76,17 @@ class OrderEnforcer {
|
|||
}
|
||||
}
|
||||
|
||||
bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
|
||||
std::unordered_set<AnfNodePtr> CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::unordered_set<AnfNodePtr> loads;
|
||||
auto inputs = cnode->inputs();
|
||||
for (size_t index = 1; index < inputs.size(); index++) {
|
||||
auto input = cnode->input(index);
|
||||
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
return true;
|
||||
loads.insert(input);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return loads;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
|
||||
|
@ -155,23 +157,31 @@ class OrderEnforcer {
|
|||
// u3 = UpdateState(u', maketuple2, addn) # need put addn or other-op into u3 inputs
|
||||
// assign = Assign(para2, inputs, u3)
|
||||
void HandleMakeTupleUsers(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto maketuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(maketuple);
|
||||
if (CheckMakeTupleHaveLoad(maketuple)) {
|
||||
std::unordered_set<AnfNodePtr> loads = CheckMakeTupleHaveLoad(maketuple);
|
||||
if (!loads.empty()) {
|
||||
auto update_state = FindLastUpdateState(maketuple);
|
||||
if (update_state != nullptr) {
|
||||
std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
|
||||
std::unordered_set<AnfNodePtr> no_push_maketuple_users;
|
||||
std::unordered_set<AnfNodePtr> no_push_all_users;
|
||||
// Push and Pull at the end of the execution order,
|
||||
// In order to ensure push and pull operator cut into the same graph, do not put push operator into updatestate
|
||||
for (auto maketuple_user : maketuple_users) {
|
||||
if (!IsPrimitiveCNode(maketuple_user, prim::kPrimPush)) {
|
||||
no_push_maketuple_users.insert(maketuple_user);
|
||||
no_push_all_users.insert(maketuple_user);
|
||||
}
|
||||
}
|
||||
for (auto load : loads) {
|
||||
std::unordered_set<AnfNodePtr> load_users = GetSpecialOperatorRealUsers(load);
|
||||
for (auto load_user : load_users) {
|
||||
no_push_all_users.insert(load_user);
|
||||
}
|
||||
}
|
||||
auto update_state_cnode = update_state->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(update_state_cnode);
|
||||
AddInputEdges(update_state_cnode, no_push_maketuple_users);
|
||||
AddInputEdges(update_state_cnode, no_push_all_users);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -265,6 +275,8 @@ class OrderEnforcer {
|
|||
// Add load users as input edges of the update_state node.
|
||||
void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
|
||||
auto sorted_load_users = SortLoadUsers(load_users);
|
||||
auto inputs = update_state->inputs();
|
||||
size_t origin_size = inputs.size();
|
||||
for (auto &load_user : sorted_load_users) {
|
||||
if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
|
@ -272,10 +284,16 @@ class OrderEnforcer {
|
|||
if (!IsDependOn(load_user, update_state)) {
|
||||
processed_nodes_.insert(load_user);
|
||||
if (!IsInUpdateState(load_user, update_state)) {
|
||||
manager_->AddEdge(update_state, load_user);
|
||||
inputs.emplace_back(load_user);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (inputs.size() > origin_size) {
|
||||
auto new_update_state = func_graph_->NewCNode(inputs);
|
||||
new_update_state->set_abstract(update_state->abstract());
|
||||
new_update_state->set_scope(update_state->scope());
|
||||
manager_->Replace(update_state, new_update_state);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort load users by their topo sort order.
|
||||
|
|
|
@ -157,7 +157,9 @@ def test_if_after_if_in_if():
|
|||
control_flow_if_after_if_in_if(IfAfterIfInIfNet, x)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported side effect")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_after_if_in_if_01():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet1, x)
|
||||
|
|
Loading…
Reference in New Issue