forked from mindspore-Ecosystem/mindspore
Revert "Dealing with the random accuracy problem: parameter and load are equivalent."
This reverts commit 1a2f7e2663
.
This commit is contained in:
parent
54cd78d25c
commit
a903cad5b3
|
@ -31,13 +31,12 @@ std::unordered_set<FuncGraphPtr> GetAllSubGraphs(const std::unordered_set<AnfNod
|
|||
for (auto node : call_partial) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto fg_idx = IsValueNode<FuncGraph>(cnode->input(0)) ? 0 : 1;
|
||||
auto fg_idx = IsPrimitiveCNode(cnode, prim::kPrimCall) ? 0 : 1;
|
||||
auto fg_value_node = cnode->input(fg_idx)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(fg_value_node);
|
||||
auto value = fg_value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto caller_fg = value->cast<FuncGraphPtr>();
|
||||
graphs.insert(caller_fg);
|
||||
auto sub_graphs = caller_fg->func_graphs_used_total();
|
||||
for (auto sub_graph : sub_graphs) {
|
||||
graphs.insert(sub_graph);
|
||||
|
@ -47,28 +46,18 @@ std::unordered_set<FuncGraphPtr> GetAllSubGraphs(const std::unordered_set<AnfNod
|
|||
}
|
||||
|
||||
bool HasUMonadCallNodeUser(const FuncGraphPtr &fg, const std::unordered_set<AnfNodePtr> &call_partial,
|
||||
const AnfNodePtr ¶m) {
|
||||
const AnfNodePtr &load_param) {
|
||||
if (call_partial.empty()) {
|
||||
return false;
|
||||
}
|
||||
auto manager = fg->manager();
|
||||
auto param_users = manager->node_users()[param];
|
||||
std::unordered_set<AnfNodePtr> all_users;
|
||||
for (auto param_user : param_users) {
|
||||
all_users.insert(param_user.first);
|
||||
if (IsPrimitiveCNode(param_user.first, prim::kPrimLoad)) {
|
||||
auto load_users = manager->node_users()[param_user.first];
|
||||
for (auto load_user : load_users) {
|
||||
all_users.insert(load_user.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto load_param_users = manager->node_users()[load_param];
|
||||
std::unordered_set<FuncGraphPtr> sub_graphs = GetAllSubGraphs(call_partial);
|
||||
for (auto user : all_users) {
|
||||
if (!user->isa<CNode>()) {
|
||||
for (auto user : load_param_users) {
|
||||
if (!user.first->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto node = user->cast<CNodePtr>();
|
||||
auto node = user.first->cast<CNodePtr>();
|
||||
auto user_graph = node->func_graph();
|
||||
// Check if user graph is in sub graphs.
|
||||
bool exist_user_graph = std::any_of(sub_graphs.begin(), sub_graphs.end(),
|
||||
|
@ -87,7 +76,7 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
|
|||
// Record the param and the user set of param in toposet nodes
|
||||
std::unordered_map<AnfNodePtr, std::unordered_set<AnfNodePtr>> unload_users_record;
|
||||
std::unordered_set<AnfNodePtr> call_partial_nodes;
|
||||
bool has_umonad_call_node_users = false;
|
||||
std::unordered_map<AnfNodePtr, bool> has_umonad_call_node_users;
|
||||
for (size_t i = 0; i < toposet.size(); i++) {
|
||||
auto &node = toposet[i];
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -95,23 +84,20 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
|
|||
continue;
|
||||
}
|
||||
// Record param user in toposort nodes.
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
AnfNodePtr cur_param = nullptr;
|
||||
if (input->isa<Parameter>()) {
|
||||
cur_param = input;
|
||||
} else if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
cur_param = input->cast<CNodePtr>()->input(1);
|
||||
} else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>()) {
|
||||
cur_param = input->cast<CNodePtr>()->input(1);
|
||||
}
|
||||
if (cur_param != nullptr) {
|
||||
unload_users_record[cur_param].insert(cnode);
|
||||
}
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
AnfNodePtr cur_param = nullptr;
|
||||
if (input->isa<Parameter>()) {
|
||||
cur_param = input;
|
||||
} else if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
cur_param = input->cast<CNodePtr>()->input(1);
|
||||
} else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>()) {
|
||||
cur_param = input->cast<CNodePtr>()->input(1);
|
||||
}
|
||||
if (cur_param != nullptr) {
|
||||
unload_users_record[cur_param].insert(cnode);
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimCall) || IsValueNode<FuncGraph>(cnode->input(0)) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimCall) || IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
|
||||
call_partial_nodes.insert(cnode);
|
||||
}
|
||||
|
||||
|
@ -129,11 +115,11 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
|
|||
load_groups.push_back({i});
|
||||
// If had not user in toposort, should check if has call or partial user
|
||||
// If already has call node user, do not need check again.
|
||||
if (!unload_users_record[load_param].empty()) {
|
||||
if (!unload_users_record[load_param].empty() || has_umonad_call_node_users[load_param] == true) {
|
||||
continue;
|
||||
}
|
||||
has_umonad_call_node_users = HasUMonadCallNodeUser(fg, call_partial_nodes, load_param);
|
||||
if (!has_umonad_call_node_users) {
|
||||
has_umonad_call_node_users[load_param] = HasUMonadCallNodeUser(fg, call_partial_nodes, load_param);
|
||||
if (has_umonad_call_node_users[load_param] == false) {
|
||||
need_replace_loads->emplace_back(cnode);
|
||||
}
|
||||
} else {
|
||||
|
@ -165,8 +151,8 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
|
|||
return false;
|
||||
}
|
||||
// if Call/Switch/SwitchLayer, do not replace load.
|
||||
if (IsPrimitiveCNode(node, prim::kPrimCall) || IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0)) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
|
||||
return true;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
|
Loading…
Reference in New Issue