!34706 Parallel Optimizer in Graph Kernel

Merge pull request !34706 from jiaoy1224/parallel_opt
i-robot 2022-05-24 07:34:24 +00:00 committed by Gitee
commit ed67d9d4e8
3 changed files with 187 additions and 0 deletions

@@ -55,6 +55,7 @@
#include "common/graph_kernel/reduce_fake_out_mem.h"
#include "common/graph_kernel/depend_elimination.h"
#include "common/graph_kernel/floatstatus_addn_fusion.h"
#include "common/graph_kernel/parallel_optimizer.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/graph_kernel_build.h"
@@ -85,6 +86,9 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
// Spread the MakeTuple input of UpdateState
pm->Add(std::make_shared<SpreadUpdateState>(), OptLevel_1);
// Parallel optimizer by UpdateState reorganization
pm->Add(std::make_shared<ParallelOptimizer>(), OptLevel_2);
// Eliminate the common nodes generated in SpreadUpdateState
pm->Add(std::make_shared<GraphKernelCSE>(), OptLevel_1);
return pm;

@@ -0,0 +1,140 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/parallel_optimizer.h"
#include <vector>
#include <utility>
#include <algorithm>
#include "include/common/utils/utils.h"
namespace mindspore::graphkernel {
constexpr auto USER_NUM = 2;
constexpr auto REAL_NODE_START_POS = 2;
constexpr auto UMONAD_POS = 1;
namespace {
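// If `node` has exactly two users, one UpdateState and one optimizer matching `opt_prim`,
// return the pair (UpdateState, optimizer); otherwise return (nullptr, nullptr).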
std::pair<AnfNodePtr, AnfNodePtr> IsTargetUpdateState(const AnfNodePtr &node, const PrimitivePtr &opt_prim,
const FuncGraphManagerPtr &mng) {
auto users = mng->node_users()[node];
if (users.size() != USER_NUM) return std::make_pair(nullptr, nullptr);
if (IsPrimitiveCNode(users.front().first, prim::kPrimUpdateState) && IsPrimitiveCNode(users.back().first, opt_prim)) {
return std::make_pair(users.front().first, users.back().first);
}
if (IsPrimitiveCNode(users.back().first, prim::kPrimUpdateState) && IsPrimitiveCNode(users.front().first, opt_prim)) {
return std::make_pair(users.back().first, users.front().first);
}
return std::make_pair(nullptr, nullptr);
}
// Originally, multiple optimizers are kept in a serial order by UpdateState nodes.
// If two optimizers are connected through a path that starts from one optimizer's Parameter input, the serial
// order must be kept; otherwise it is safe to run them in parallel.
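// For example, parallelization must be rejected when one optimizer's Parameter input can reach another
// optimizer (illustrative sketch; SomeOp stands for any op that reads the shared Parameter %p):
//   %1 = UpdateState(...)
//   %2 = AdamWeightDecay(%p, ..., %1)
//   %3 = UpdateState(%1, %2)
//   %4 = SomeOp(%p)
//   %5 = AdamWeightDecay(%4, ..., %3)
// Here %p reaches both %2 and %5, so CanParallel returns false and the serial order is kept.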
bool CanParallel(const mindspore::HashSet<AnfNodePtr> &opts_set, const FuncGraphManagerPtr &mng) {
mindspore::HashMap<AnfNodePtr, mindspore::HashSet<AnfNodePtr>> other_nodes_to_opts;
std::function<mindspore::HashSet<AnfNodePtr>(AnfNodePtr)> dfs;
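// dfs(node) returns the set of optimizers from `opts_set` that are reachable from `node` by following
// user edges; results are memoized in `other_nodes_to_opts`.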
dfs = [&dfs, &other_nodes_to_opts, &opts_set, &mng](const AnfNodePtr &cur_node) {
if (other_nodes_to_opts.count(cur_node) > 0) {
return other_nodes_to_opts[cur_node];
}
auto users = mng->node_users()[cur_node];
mindspore::HashSet<AnfNodePtr> tmp;
for (auto &i : users) {
if (opts_set.count(i.first) > 0) {
(void)tmp.insert(i.first);
} else {
auto res = dfs(i.first);
(void)tmp.insert(res.begin(), res.end());
}
}
other_nodes_to_opts[cur_node] = tmp;
return tmp;
};
for (auto &opt : opts_set) {
auto this_opt = opt->cast<CNodePtr>();
for (size_t i = 1; i < this_opt->inputs().size(); i++) {
if (auto inp = this_opt->input(i); inp->isa<Parameter>()) {
auto joint_opts = dfs(inp);
if (joint_opts.size() != 1 || joint_opts.find(this_opt) == joint_opts.end()) {
return false;
}
}
}
}
return true;
}
void DoParallel(const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &updatestate_opts, const FuncGraphManagerPtr &mng,
const AnfNodePtr &first_updatestate) {
std::vector<AnfNodePtr> additional_inputs;
for (size_t i = 0; i < updatestate_opts.size(); i++) {
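// Redirect each optimizer's monad input (its last input) to the first UpdateState, so the optimizers
// no longer depend on each other through the UpdateState chain.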
auto opt_cnode = updatestate_opts[i].second->cast<CNodePtr>();
opt_cnode->set_input(opt_cnode->inputs().size() - 1, first_updatestate);
if (i < updatestate_opts.size() - 1) {
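// Collect the real-node inputs of the intermediate UpdateStates; they will be re-attached to the
// last UpdateState so their side effects remain tracked.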
auto ups_cnode = updatestate_opts[i].first->cast<CNodePtr>();
additional_inputs.insert(additional_inputs.end(), ups_cnode->inputs().begin() + REAL_NODE_START_POS,
ups_cnode->inputs().end());
}
}
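// The last UpdateState keeps the first UpdateState as its monad input and gathers all the optimizers
// (plus the inputs collected above) as its real-node inputs; the intermediate UpdateStates become dead.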
auto last_updatestate = updatestate_opts.back().first->cast<CNodePtr>();
last_updatestate->set_input(UMONAD_POS, first_updatestate);
std::vector<AnfNodePtr> final_inputs = last_updatestate->inputs();
(void)final_inputs.insert(final_inputs.end(), additional_inputs.begin(), additional_inputs.end());
last_updatestate->set_inputs(final_inputs);
}
} // namespace
bool ParallelOptimizer::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(func_graph->get_return());
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<PrimitivePtr> opt_list = {prim::kPrimAdamWeightDecay};
bool graph_change = false;
for (auto &opt_prim : opt_list) {
auto todos = TopoSort(func_graph->get_return());
mindspore::HashSet<AnfNodePtr> visited_updatestates;
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> updatestate_opts;
bool changed = false;
for (auto &node : todos) {
// find pattern: updatestate->optimizer->updatestate->optimizer->...updatestate->optimizer->updatestate
if (IsPrimitiveCNode(node, prim::kPrimUpdateState) &&
visited_updatestates.find(node) == visited_updatestates.end()) {
updatestate_opts.clear();
(void)visited_updatestates.insert(node);
auto res = IsTargetUpdateState(node, opt_prim, mng);
while (res.first != nullptr) {
(void)visited_updatestates.insert(res.first);
(void)updatestate_opts.emplace_back(res);
res = IsTargetUpdateState(updatestate_opts.back().first, opt_prim, mng);
}
mindspore::HashSet<AnfNodePtr> opts_set;
(void)std::for_each(
updatestate_opts.begin(), updatestate_opts.end(),
[&opts_set](const std::pair<AnfNodePtr, AnfNodePtr> &p) { (void)opts_set.insert(p.second); });
if (opts_set.size() > 1 && CanParallel(opts_set, mng)) {
DoParallel(updatestate_opts, mng, node);
changed = true;
}
}
}
if (changed) {
graph_change = true;
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
}
return graph_change;
}
} // namespace mindspore::graphkernel

@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OPTIMIZER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OPTIMIZER_H_
#include "backend/common/optimizer/pass.h"
namespace mindspore::graphkernel {
/**
* @brief Parallelize optimizer nodes whose serial order is meaningless
* @example
* %1 = UpdateState(...)
* %2 = AdamWeightDecay(..., %1)
* %3 = UpdateState(%1, %2)
* %4 = AdamWeightDecay(..., %3)
* %5 = UpdateState(%3, %4)
* ---------->
* %1 = UpdateState(...)
* %2 = AdamWeightDecay(..., %1)
* %3 = AdamWeightDecay(..., %1)
* %4 = UpdateState(%1, %2, %3)
*/
class ParallelOptimizer : public opt::Pass {
public:
ParallelOptimizer() : Pass("parallel_optimizer") {}
~ParallelOptimizer() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OPTIMIZER_H_
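For reference, a minimal usage sketch of the new pass. This is illustrative only, written as if inside namespace mindspore::graphkernel; `pm` and `func_graph` stand for an existing PassManagerPtr and a FuncGraphPtr that already has a manager attached (Run() null-checks both the graph and its manager).

#include <memory>
#include "common/graph_kernel/parallel_optimizer.h"

// Registration in the graph-kernel pre-process pipeline, after SpreadUpdateState and before
// GraphKernelCSE, as in the PreProcess() hunk above:
pm->Add(std::make_shared<ParallelOptimizer>(), OptLevel_2);

// Standalone invocation; Run() returns true when at least one UpdateState chain was reorganized:
auto pass = std::make_shared<ParallelOptimizer>();
bool changed = pass->Run(func_graph);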