insert memcpy async if hccl op cascade

This commit is contained in:
huanghui 2020-07-21 20:37:41 +08:00
parent 48325dea3b
commit f1563d2d37
7 changed files with 289 additions and 59 deletions

View File

@ -87,6 +87,7 @@
#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h"
#include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
optimizer->AddPassManager(other_pm);

View File

@ -0,0 +1,114 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include <vector>
#include <set>
#include <string>
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
namespace {
bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(cur_hccl);
MS_EXCEPTION_IF_NULL(graph);
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(prev_node);
if (!AnfAlgo::IsCommunicationOp(prev_node)) {
return false;
}
auto prev_hccl_op = prev_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(prev_hccl_op);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(prev_hccl_op);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
for (const auto &node_index : iter->second) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
bool is_contain = false;
for (size_t i = 1; i < cur_hccl->size(); ++i) {
if (cur_hccl->input(i) == output) {
is_contain = true;
break;
}
}
if (!is_contain) {
return true;
}
}
}
return false;
}
} // namespace
// Creates a copy of `hccl_node` in which every input that is a partial output of another communication op
// (see IsPartOutputsOfHcclOp) is routed through a freshly created memcpy_async node with a selected kernel.
// Returns the new node, or nullptr when no input required a memcpy_async.
AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(hccl_node);
  bool any_inserted = false;
  // Input 0 is carried over unchanged; the remaining inputs are inspected one by one.
  std::vector<AnfNodePtr> rebuilt_inputs;
  rebuilt_inputs.push_back(hccl_node->input(0));
  for (size_t idx = 1; idx < hccl_node->size(); ++idx) {
    auto origin_input = hccl_node->input(idx);
    MS_EXCEPTION_IF_NULL(origin_input);
    if (!IsPartOutputsOfHcclOp(origin_input, hccl_node, graph)) {
      rebuilt_inputs.push_back(origin_input);
      continue;
    }
    // The input comes from a hccl op that only partly links to this op: put a memcpy_async in between.
    auto memcpy_node = CreateMemcpyAsyncOp(graph, origin_input);
    memcpy_node->set_kernel_info(std::make_shared<device::KernelInfo>());
    MS_EXCEPTION_IF_NULL(kernel_select_);
    kernel_select_->SelectKernel(memcpy_node->cast<CNodePtr>());
    rebuilt_inputs.push_back(memcpy_node);
    any_inserted = true;
  }
  if (!any_inserted) {
    return nullptr;
  }
  CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
  new_hccl_node->set_inputs(rebuilt_inputs);
  return new_hccl_node;
}
// Pass entry point: forwards communication-op CNodes to InsertMemcpyAsync and returns its result.
// Returns nullptr for a null graph or node, for non-CNode nodes and for non-communication ops.
const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  const bool has_graph_and_node = (func_graph != nullptr) && (node != nullptr);
  if (!has_graph_and_node || !node->isa<CNode>()) {
    return nullptr;
  }
  if (!AnfAlgo::IsCommunicationOp(node)) {
    return nullptr;
  }
  return InsertMemcpyAsync(func_graph, node->cast<CNodePtr>());
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
// Optimizer pass registered under the name "insert_memcpy_async_for_cascade".
// For a communication (hccl) op, an input that is a TupleGetItem of another communication op whose outputs
// are only partly consumed by the current op is replaced by a memcpy_async of that input.
class InsertMemcpyAsyncForCascade : public PatternProcessPass {
public:
// `multigraph` is forwarded unchanged to PatternProcessPass; a default KernelSelect is created for the pass.
explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
: PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
kernel_select_(std::make_shared<KernelSelect>()) {}
~InsertMemcpyAsyncForCascade() override = default;
// Returns the rewritten communication op, or nullptr when the node is left untouched.
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
// Builds the replacement for `hccl_node`; returns nullptr when no memcpy_async had to be inserted.
AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
// Used to select a kernel for each newly created memcpy_async node.
KernelSelectPtr kernel_select_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_

View File

@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe
bool IsParameterOrValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>();
auto real_node = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_node);
if (real_node->isa<Parameter>()) {
return true;
}
return real_node->isa<ValueNode>();
}
void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) {
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(hccl_node);
MS_EXCEPTION_IF_NULL(memcpy_async);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async,
}
// find hccl_node's output which is a control depend
for (const auto &node_index : iter->second) {
AnfNodePtr output = node_index.first;
int output_index = node_index.second;
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
CNodePtr control_depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(control_depend);
std::vector<AnfNodePtr> new_inputs;
for (size_t i = 0; i < control_depend->size(); ++i) {
if (i == IntToSize(output_index)) {
new_inputs.push_back(memcpy_async);
} else {
new_inputs.push_back(control_depend->input(i));
}
}
control_depend->set_inputs(new_inputs);
if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) {
continue;
}
CNodePtr control_depend = node_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(control_depend);
std::vector<AnfNodePtr> new_inputs;
for (size_t i = 0; i < control_depend->size(); ++i) {
if (i == IntToSize(node_index.second)) {
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
make_tuple_inputs.emplace_back(hccl_node);
auto make_tuple = graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple);
new_inputs.push_back(make_tuple);
} else {
new_inputs.push_back(control_depend->input(i));
}
}
control_depend->set_inputs(new_inputs);
}
}
} // namespace
bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const {
bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
const CNodePtr &cur_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(cur_node);
// when input is a parameter or is a value node
if (IsParameterOrValueNode(input)) {
return true;
}
// when input is a Ref or some special cnodes
if (kernel_query_->IsTbeRef(input) ||
kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
return true;
}
if (input->isa<CNode>()) {
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(input);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
// when input is used by others
if (iter->second.size() > 1) {
return true;
// when input is a Ref cnode
if (kernel_query_->IsTbeRef(input)) {
return true;
}
// when input is some special cnodes
if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
return true;
}
// when input is used by others
auto iter = node_users.find(input);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
if (iter->second.size() > 1) {
return true;
}
}
return false;
}
@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(hccl_node);
bool has_insert_memcpy = false;
AnfNodePtr memcpy_async = nullptr;
std::vector<AnfNodePtr> memcpy_async_list;
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
for (size_t i = 1; i < hccl_node->size(); ++i) {
auto input = hccl_node->input(i);
if (NeedInsertMemcpy(graph, input)) {
memcpy_async = CreateMemcpyAsyncOp(graph, input);
has_insert_memcpy = true;
if (NeedInsertMemcpy(graph, input, hccl_node)) {
auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
new_inputs.push_back(memcpy_async);
memcpy_async_list.push_back(memcpy_async);
} else {
new_inputs.push_back(input);
}
}
if (has_insert_memcpy) {
if (!memcpy_async_list.empty()) {
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
new_hccl_node->set_inputs(new_inputs);
auto manager = graph->manager();
@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
MS_LOG(DEBUG) << "end replace";
// transfer hccl op's control to the memcpy_async
if (hccl_node->size() == 2) {
TransferControl(new_hccl_node, memcpy_async, graph);
}
TransferControl(new_hccl_node, memcpy_async_list, graph);
}
}

View File

@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass {
private:
void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const;
bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const;
KernelQueryPtr kernel_query_;
};
} // namespace opt

View File

@ -22,6 +22,7 @@
#include "utils/utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/common/optimizer.h"
#include "ir/param_value.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
~MockInsertMemcpyForHcclKernelQuery() override = default;
bool IsTbeRef(const AnfNodePtr &node) override {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
if (!node->isa<CNode>()) {
return false;
}
auto name = AnfAlgo::GetCNodeName(cnode);
return name == "ApplyMomentum";
return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum";
}
};
@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) {
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
ASSERT_TRUE(g != nullptr);
std::vector<int> shp_x{1, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
@ -161,5 +171,34 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
// Runs InsertMemcpyAsyncForHcclOp on the "before" graph of test_insert_memcpy_async_for_hccl_op_cond5
// (built with three tensor arguments) and checks that the result equals the "after" graph.
TEST_F(TestHWInsertMemcpyForHccl, test_cond5) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "before");
ASSERT_TRUE(g != nullptr);
// All three graph arguments share the same float32 tensor abstract.
std::vector<int> shp_x{1, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
// Give every graph parameter a default ParamValue.
// NOTE(review): presumably this makes the parameters look like weights to the pass - confirm.
for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
// Replace the kernel query with the mock so IsTbeRef does not depend on a real kernel library.
pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
pm->AddPass(pass);
optimizer->AddPassManager(pm);
auto new_graph = optimizer->Optimize(kg);
kg->SetExecOrderByDefault();
// The optimized graph must match the expected "after" graph.
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

View File

@ -17,6 +17,7 @@ from mindspore.ops import Primitive
from mindspore.ops import operations as P
all_reduce = P.AllReduce()
broadcast = P.Broadcast(1)
memcpy_async = Primitive('memcpy_async')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag):
fns = FnDict()
@fns
def before(a, b, c, d, e):
res1 = apply_momentun(a, b, c, d, e)
res2 = all_reduce(a)
res = control_depend(res1, res2)
res = make_tuple(res, res2)
def before(a, b):
x = relu(a)
y = all_reduce(b)
res = control_depend(x, y)
return res
@fns
def after(a, b, c, d, e):
res1 = apply_momentun(a, b, c, d, e)
res2 = memcpy_async(a)
res3 = all_reduce(res2)
res = control_depend(res1, res2)
res = make_tuple(res, res3)
def after(a, b):
x = relu(a)
y1 = memcpy_async(b)
y2 = all_reduce(y1)
res = control_depend(x, make_tuple(y1, y2))
return make_tuple(res)
return fns[tag]
# Graph fixtures for the cond5 case: a broadcast over two inputs that is also an operand of control_depend.
# Returns the function registered under `tag` ("before" or "after").
def test_insert_memcpy_async_for_hccl_op_cond5(tag):
fns = FnDict()
# Input graph: broadcast takes the tuple (b, c) directly.
@fns
def before(a, b, c):
x = relu(a)
y = broadcast((b, c))
res = control_depend(x, y)
return res
# Expected graph: each broadcast input goes through memcpy_async, and control_depend's second operand
# becomes a make_tuple of both memcpy_async nodes followed by the broadcast.
@fns
def after(a, b, c):
x = relu(a)
m1 = memcpy_async(b)
m2 = memcpy_async(c)
y = broadcast(m1, m2)
res = control_depend(x, make_tuple(m1, m2, y))
return make_tuple(res)
return fns[tag]