Revert "Dealing with the random accuracy problem: parameter and load are equivalent."

This reverts commit 1a2f7e2663.
This commit is contained in:
Margaret_wangrui 2021-09-03 10:12:43 +08:00
parent 54cd78d25c
commit a903cad5b3
1 changed files with 24 additions and 38 deletions

View File

@ -31,13 +31,12 @@ std::unordered_set<FuncGraphPtr> GetAllSubGraphs(const std::unordered_set<AnfNod
for (auto node : call_partial) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto fg_idx = IsValueNode<FuncGraph>(cnode->input(0)) ? 0 : 1;
auto fg_idx = IsPrimitiveCNode(cnode, prim::kPrimCall) ? 0 : 1;
auto fg_value_node = cnode->input(fg_idx)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(fg_value_node);
auto value = fg_value_node->value();
MS_EXCEPTION_IF_NULL(value);
auto caller_fg = value->cast<FuncGraphPtr>();
graphs.insert(caller_fg);
auto sub_graphs = caller_fg->func_graphs_used_total();
for (auto sub_graph : sub_graphs) {
graphs.insert(sub_graph);
@ -47,28 +46,18 @@ std::unordered_set<FuncGraphPtr> GetAllSubGraphs(const std::unordered_set<AnfNod
}
bool HasUMonadCallNodeUser(const FuncGraphPtr &fg, const std::unordered_set<AnfNodePtr> &call_partial,
const AnfNodePtr &param) {
const AnfNodePtr &load_param) {
if (call_partial.empty()) {
return false;
}
auto manager = fg->manager();
auto param_users = manager->node_users()[param];
std::unordered_set<AnfNodePtr> all_users;
for (auto param_user : param_users) {
all_users.insert(param_user.first);
if (IsPrimitiveCNode(param_user.first, prim::kPrimLoad)) {
auto load_users = manager->node_users()[param_user.first];
for (auto load_user : load_users) {
all_users.insert(load_user.first);
}
}
}
auto load_param_users = manager->node_users()[load_param];
std::unordered_set<FuncGraphPtr> sub_graphs = GetAllSubGraphs(call_partial);
for (auto user : all_users) {
if (!user->isa<CNode>()) {
for (auto user : load_param_users) {
if (!user.first->isa<CNode>()) {
continue;
}
auto node = user->cast<CNodePtr>();
auto node = user.first->cast<CNodePtr>();
auto user_graph = node->func_graph();
// Check if user graph is in sub graphs.
bool exist_user_graph = std::any_of(sub_graphs.begin(), sub_graphs.end(),
@ -87,7 +76,7 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
// Record the param and the user set of param in toposet nodes
std::unordered_map<AnfNodePtr, std::unordered_set<AnfNodePtr>> unload_users_record;
std::unordered_set<AnfNodePtr> call_partial_nodes;
bool has_umonad_call_node_users = false;
std::unordered_map<AnfNodePtr, bool> has_umonad_call_node_users;
for (size_t i = 0; i < toposet.size(); i++) {
auto &node = toposet[i];
auto cnode = node->cast<CNodePtr>();
@ -95,7 +84,6 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
continue;
}
// Record param user in toposort nodes.
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
for (const auto &input : cnode->inputs()) {
AnfNodePtr cur_param = nullptr;
if (input->isa<Parameter>()) {
@ -109,9 +97,7 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
unload_users_record[cur_param].insert(cnode);
}
}
}
if (IsPrimitiveCNode(cnode, prim::kPrimCall) || IsValueNode<FuncGraph>(cnode->input(0)) ||
IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
if (IsPrimitiveCNode(cnode, prim::kPrimCall) || IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
call_partial_nodes.insert(cnode);
}
@ -129,11 +115,11 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
load_groups.push_back({i});
// If had not user in toposort, should check if has call or partial user
// If already has call node user, do not need check again.
if (!unload_users_record[load_param].empty()) {
if (!unload_users_record[load_param].empty() || has_umonad_call_node_users[load_param] == true) {
continue;
}
has_umonad_call_node_users = HasUMonadCallNodeUser(fg, call_partial_nodes, load_param);
if (!has_umonad_call_node_users) {
has_umonad_call_node_users[load_param] = HasUMonadCallNodeUser(fg, call_partial_nodes, load_param);
if (has_umonad_call_node_users[load_param] == false) {
need_replace_loads->emplace_back(cnode);
}
} else {
@ -165,8 +151,8 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
return false;
}
// if Call/Switch/SwitchLayer, do not replace load.
if (IsPrimitiveCNode(node, prim::kPrimCall) || IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0)) ||
IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
return true;
}
auto cnode = node->cast<CNodePtr>();