forked from mindspore-Ecosystem/mindspore
insert memcpy async if hccl op cascade
parent 48325dea3b
commit f1563d2d37
@@ -87,6 +87,7 @@
 #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
 #include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
+#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
 #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
 #include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h"
 #include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
@@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   other_pm->AddPass(std::make_shared<AllGatherFusion>());
   other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
   other_pm->AddPass(std::make_shared<BroadcastFusion>());
+  other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
   other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
   other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
   optimizer->AddPassManager(other_pm);
@@ -0,0 +1,114 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
+#include <vector>
+#include <set>
+#include <string>
+#include "utils/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/optimizer/opt.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(cur_hccl);
+  MS_EXCEPTION_IF_NULL(graph);
+  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
+  MS_EXCEPTION_IF_NULL(prev_node);
+  if (!AnfAlgo::IsCommunicationOp(prev_node)) {
+    return false;
+  }
+  auto prev_hccl_op = prev_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(prev_hccl_op);
+
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto &node_users = manager->node_users();
+  auto iter = node_users.find(prev_hccl_op);
+  if (iter == node_users.end()) {
+    MS_LOG(EXCEPTION) << "node has no output in manager";
+  }
+  for (const auto &node_index : iter->second) {
+    AnfNodePtr output = node_index.first;
+    MS_EXCEPTION_IF_NULL(output);
+    if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      bool is_contain = false;
+      for (size_t i = 1; i < cur_hccl->size(); ++i) {
+        if (cur_hccl->input(i) == output) {
+          is_contain = true;
+          break;
+        }
+      }
+      if (!is_contain) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+} // namespace
+
+AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(hccl_node);
+  std::vector<AnfNodePtr> memcpy_async_list;
+  std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
+  for (size_t i = 1; i < hccl_node->size(); ++i) {
+    auto input = hccl_node->input(i);
+    MS_EXCEPTION_IF_NULL(input);
+    // when input is also a hccl op and just part outputs of it linking with cur_hccl_op
+    if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
+      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      auto kernel_info = std::make_shared<device::KernelInfo>();
+      memcpy_async->set_kernel_info(kernel_info);
+      MS_EXCEPTION_IF_NULL(kernel_select_);
+      kernel_select_->SelectKernel(memcpy_async->cast<CNodePtr>());
+      new_inputs.push_back(memcpy_async);
+      memcpy_async_list.push_back(memcpy_async);
+    } else {
+      new_inputs.push_back(input);
+    }
+  }
+
+  if (!memcpy_async_list.empty()) {
+    CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
+    new_hccl_node->set_inputs(new_inputs);
+    return new_hccl_node;
+  }
+  return nullptr;
+}
+
+const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                      const EquivPtr &) const {
+  if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (!AnfAlgo::IsCommunicationOp(node)) {
+    return nullptr;
+  }
+  return InsertMemcpyAsync(func_graph, cnode);
+}
+} // namespace opt
+} // namespace mindspore
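For orientation, here is a small sketch (not part of this diff) of the cascade case the new pass targets, written in the same Primitive-based graph style as the Python test helpers at the end of this change. The function names cascade_before and cascade_after are hypothetical; the sketch assumes a Broadcast whose outputs only partly feed a following AllReduce, which is the condition IsPartOutputsOfHcclOp checks before a memcpy_async is inserted on that input.

# Illustrative sketch only -- not part of this commit. It mimics the
# graph DSL used by the tests below to show the cascaded-hccl-op case.
from mindspore.ops import Primitive
from mindspore.ops import operations as P

all_reduce = P.AllReduce()
broadcast = P.Broadcast(1)
memcpy_async = Primitive('memcpy_async')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


def cascade_before(a, b):
    # broadcast has two outputs, but only output 0 feeds the all_reduce,
    # so only part of the previous hccl op's outputs link to the next one
    y = broadcast((a, b))
    y0 = tuple_getitem(y, 0)
    z = all_reduce(y0)
    return make_tuple(z, tuple_getitem(y, 1))


def cascade_after(a, b):
    # the pass rebuilds the second hccl op so the cascaded input is routed
    # through an inserted memcpy_async
    y = broadcast((a, b))
    y0 = tuple_getitem(y, 0)
    z = all_reduce(memcpy_async(y0))
    return make_tuple(z, tuple_getitem(y, 1))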
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class InsertMemcpyAsyncForCascade : public PatternProcessPass {
+ public:
+  explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
+      : PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
+        kernel_select_(std::make_shared<KernelSelect>()) {}
+  ~InsertMemcpyAsyncForCascade() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
+  KernelSelectPtr kernel_select_;
+};
+} // namespace opt
+} // namespace mindspore
+#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_OP_CASCADE_H_
@@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe
 bool IsParameterOrValueNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
-  return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>();
+  auto real_node = kernel_with_index.first;
+  MS_EXCEPTION_IF_NULL(real_node);
+  if (real_node->isa<Parameter>()) {
+    return true;
+  }
+  return real_node->isa<ValueNode>();
 }

-void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) {
+void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
+                     const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(hccl_node);
-  MS_EXCEPTION_IF_NULL(memcpy_async);
   MS_EXCEPTION_IF_NULL(graph);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async,
   }
   // find hccl_node's output which is a control depend
   for (const auto &node_index : iter->second) {
-    AnfNodePtr output = node_index.first;
-    int output_index = node_index.second;
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
-      CNodePtr control_depend = output->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(control_depend);
-      std::vector<AnfNodePtr> new_inputs;
-      for (size_t i = 0; i < control_depend->size(); ++i) {
-        if (i == IntToSize(output_index)) {
-          new_inputs.push_back(memcpy_async);
-        } else {
-          new_inputs.push_back(control_depend->input(i));
-        }
-      }
-      control_depend->set_inputs(new_inputs);
+    if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) {
+      continue;
     }
+    CNodePtr control_depend = node_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(control_depend);
+    std::vector<AnfNodePtr> new_inputs;
+    for (size_t i = 0; i < control_depend->size(); ++i) {
+      if (i == IntToSize(node_index.second)) {
+        std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+        make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
+        make_tuple_inputs.emplace_back(hccl_node);
+        auto make_tuple = graph->NewCNode(make_tuple_inputs);
+        MS_EXCEPTION_IF_NULL(make_tuple);
+        new_inputs.push_back(make_tuple);
+      } else {
+        new_inputs.push_back(control_depend->input(i));
+      }
+    }
+    control_depend->set_inputs(new_inputs);
   }
 }
 } // namespace

-bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const {
+bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
+                                                  const CNodePtr &cur_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(cur_node);
   // when input is a parameter or is a value node
   if (IsParameterOrValueNode(input)) {
     return true;
   }

-  // when input is a Ref or some special cnodes
-  if (kernel_query_->IsTbeRef(input) ||
-      kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
-    return true;
-  }
+  if (input->isa<CNode>()) {
+    auto manager = graph->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    auto &node_users = manager->node_users();

-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto &node_users = manager->node_users();
-  auto iter = node_users.find(input);
-  if (iter == node_users.end()) {
-    MS_LOG(EXCEPTION) << "node has no output in manager";
-  }
-  // when input is used by others
-  if (iter->second.size() > 1) {
-    return true;
+    // when input is a Ref cnode
+    if (kernel_query_->IsTbeRef(input)) {
+      return true;
+    }
+
+    // when input is some special cnodes
+    if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
+      return true;
+    }
+
+    // when input is used by others
+    auto iter = node_users.find(input);
+    if (iter == node_users.end()) {
+      MS_LOG(EXCEPTION) << "node has no output in manager";
+    }
+    if (iter->second.size() > 1) {
+      return true;
+    }
   }
   return false;
 }
@@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
 void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(hccl_node);
-  bool has_insert_memcpy = false;
-  AnfNodePtr memcpy_async = nullptr;
+  std::vector<AnfNodePtr> memcpy_async_list;
   std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
   for (size_t i = 1; i < hccl_node->size(); ++i) {
     auto input = hccl_node->input(i);
-    if (NeedInsertMemcpy(graph, input)) {
-      memcpy_async = CreateMemcpyAsyncOp(graph, input);
-      has_insert_memcpy = true;
+    if (NeedInsertMemcpy(graph, input, hccl_node)) {
+      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
       new_inputs.push_back(memcpy_async);
+      memcpy_async_list.push_back(memcpy_async);
     } else {
       new_inputs.push_back(input);
     }
   }

-  if (has_insert_memcpy) {
+  if (!memcpy_async_list.empty()) {
     CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
     new_hccl_node->set_inputs(new_inputs);
     auto manager = graph->manager();
@@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     MS_LOG(DEBUG) << "end replace";

     // transer hccl op's control to the memcpy_async
-    if (hccl_node->size() == 2) {
-      TransferControl(new_hccl_node, memcpy_async, graph);
-    }
+    TransferControl(new_hccl_node, memcpy_async_list, graph);
   }
 }

@@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass {

  private:
   void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
-  bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const;
+  bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const;
   KernelQueryPtr kernel_query_;
 };
 } // namespace opt
@@ -22,6 +22,7 @@
 #include "utils/utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "backend/optimizer/common/optimizer.h"
+#include "ir/param_value.h"
 #define private public
 #define protected public
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
@@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
   ~MockInsertMemcpyForHcclKernelQuery() override = default;
   bool IsTbeRef(const AnfNodePtr &node) override {
     MS_EXCEPTION_IF_NULL(node);
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode == nullptr) {
+    if (!node->isa<CNode>()) {
       return false;
     }
-    auto name = AnfAlgo::GetCNodeName(cnode);
-    return name == "ApplyMomentum";
+    return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum";
   }
 };

@@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) {
   AbstractBasePtrList args_spec_list{x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
   EXPECT_NE(kg, nullptr);
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
@@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
   ASSERT_TRUE(g != nullptr);
   std::vector<int> shp_x{1, 64, 112, 112};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
-  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
   EXPECT_NE(kg, nullptr);

+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }
+
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
@@ -161,5 +171,34 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWInsertMemcpyForHccl, test_cond5) {
+  get_py_fun_.SetDoResolve(true);
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "before");
+  ASSERT_TRUE(g != nullptr);
+  std::vector<int> shp_x{1, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+  EXPECT_NE(kg, nullptr);
+
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
+  pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
+  pm->AddPass(pass);
+  optimizer->AddPassManager(pm);
+  auto new_graph = optimizer->Optimize(kg);
+  kg->SetExecOrderByDefault();
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 } // namespace opt
 } // namespace mindspore
@@ -17,6 +17,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P

 all_reduce = P.AllReduce()
+broadcast = P.Broadcast(1)
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
@@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag):
     fns = FnDict()

     @fns
-    def before(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = all_reduce(a)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res2)
+    def before(a, b):
+        x = relu(a)
+        y = all_reduce(b)
+        res = control_depend(x, y)
         return res

     @fns
-    def after(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = memcpy_async(a)
-        res3 = all_reduce(res2)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res3)
+    def after(a, b):
+        x = relu(a)
+        y1 = memcpy_async(b)
+        y2 = all_reduce(y1)
+        res = control_depend(x, make_tuple(y1, y2))
         return make_tuple(res)

     return fns[tag]
+
+
+def test_insert_memcpy_async_for_hccl_op_cond5(tag):
+    fns = FnDict()
+
+    @fns
+    def before(a, b, c):
+        x = relu(a)
+        y = broadcast((b, c))
+        res = control_depend(x, y)
+        return res
+
+    @fns
+    def after(a, b, c):
+        x = relu(a)
+        m1 = memcpy_async(b)
+        m2 = memcpy_async(c)
+        y = broadcast(m1, m2)
+        res = control_depend(x, make_tuple(m1, m2, y))
+        return make_tuple(res)
+
+    return fns[tag]