forked from mindspore-Ecosystem/mindspore
!346 getnext parallel optimization part II: Eliminate Memcpy in specific scenarios
Merge pull request !346 from laiyongqiang/develop
This commit is contained in:
commit
58a70b5f82
|
@ -21,7 +21,7 @@
|
|||
#include "pre_activate/ascend/ir_fission/bn_grad_split.h"
|
||||
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
|
||||
#include "pre_activate/common/ir_fusion/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
|
||||
|
@ -58,8 +58,10 @@
|
|||
#include "pre_activate/ascend/ir_fission/add_memcpy_async.h"
|
||||
#include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
|
||||
#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
|
||||
#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
|
||||
#include "pre_activate/ascend/ir_fission/addn_fission.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/anf_ir_utils.h"
|
||||
|
||||
|
@ -244,6 +246,9 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
other_pm->AddPass(std::make_shared<BufferFusion>());
|
||||
other_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
other_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
|
||||
other_pm->AddPass(std::make_shared<GetnextMemcpyElimination>());
|
||||
}
|
||||
other_pm->AddPass(std::make_shared<CheckConsistency>());
|
||||
optimizer->AddPassManager(other_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
|
||||
#include <memory>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "optimizer/opt.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
||||
// Describes the sub-graph this pass looks for: a node whose primitive is
// kMemCpyAsyncOpName, followed by an arbitrary input sequence captured by a SeqVar.
const BaseRef GetnextMemcpyElimination::DefinePattern() const {
  auto memcpy_prim = std::make_shared<Primitive>(kMemCpyAsyncOpName);
  VarPtr any_inputs = std::make_shared<SeqVar>();
  return VectorRef({memcpy_prim, any_inputs});
}
|
||||
|
||||
// Decides whether a matched memcpy node can be removed and, if so, returns the node
// that should replace it (the memcpy's first data input).
//
// The memcpy is only eliminated when all of the following hold:
//   1. it carries the kAttrLabelForInsertStreamActive attribute;
//   2. its output has at most one user;
//   3. that user is a CNode whose only data input is the memcpy's output.
// When the rewrite applies, the kAttrLabelForInsertStreamActive attribute is
// set on the user node before returning.
//
// Returns nullptr when any argument is null or a condition is not met, i.e. no
// replacement is proposed. Raises through MS_LOG(EXCEPTION) when the manager has no
// user entry for the memcpy node.
const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                   const EquivPtr &equiv) const {
  if (graph == nullptr || node == nullptr || equiv == nullptr) {
    return nullptr;
  }
  auto memcpy_cnode = node->cast<CNodePtr>();
  if (memcpy_cnode == nullptr) {
    return nullptr;
  }
  // The pattern uses a SeqVar, which may match an empty input list; without a data
  // input there is nothing to forward, so reading input(1) below would be invalid.
  if (memcpy_cnode->inputs().size() < 2) {
    MS_LOG(DEBUG) << "memcpy has no data input";
    return nullptr;
  }

  // 1. memcpy has attr kAttrLabelForInsertStreamActive
  if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, node)) {
    MS_LOG(DEBUG) << "node has no label_for_insert_stream_active attr";
    return nullptr;
  }

  // 2. memcpy's output has only one user next_node
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (manager->node_users().find(memcpy_cnode) == manager->node_users().end()) {
    MS_LOG(EXCEPTION) << "memcpy has no output in manager";
  }
  // Bind by reference: the user set is only inspected, so copying it is unnecessary.
  auto &next_nodes = manager->node_users()[memcpy_cnode];
  if (next_nodes.size() > 1) {
    MS_LOG(DEBUG) << "node's output has more than one users";
    return nullptr;
  }

  // 3. next_node has only one input which is memcpy's output
  for (auto &item : next_nodes) {
    auto next_node = item.first->cast<CNodePtr>();
    // cast<> yields nullptr for a non-CNode user; bail out instead of dereferencing it.
    if (next_node == nullptr) {
      MS_LOG(DEBUG) << "next node is not a cnode";
      return nullptr;
    }
    if (next_node->inputs().size() != 2) {
      MS_LOG(DEBUG) << "next node has more than one input";
      return nullptr;
    }
    // add attr label_for_insert_stream_active for next_node
    AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), next_node);
  }

  return memcpy_cnode->input(1);
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H
|
||||
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class GetnextMemcpyElimination : public PatternProcessPass {
|
||||
public:
|
||||
explicit GetnextMemcpyElimination(bool multigraph = true)
|
||||
: PatternProcessPass("getnext_memcpy_elimination", multigraph) {}
|
||||
~GetnextMemcpyElimination() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/common/ir_fusion/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_IR_FUSION_ALLREDUCE_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_IR_FUSION_ALLREDUCE_FUSION_H_
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
|
||||
#include <vector>
|
||||
|
||||
#include "pre_activate/common/pass.h"
|
||||
|
@ -46,4 +46,4 @@ class AllReduceFusion : public Pass {
|
|||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_IR_FUSION_ALLREDUCE_FUSION_H_
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
|
|
@ -20,7 +20,7 @@
|
|||
#include "device/gpu/gpu_stream_assign.h"
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "pre_activate/common/pass_manager.h"
|
||||
#include "pre_activate/common/ir_fusion/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
#include "device/kernel_runtime_manager.h"
|
||||
#include "predict/predict.h"
|
||||
#include "common/utils.h"
|
||||
|
|
|
@ -148,6 +148,7 @@ constexpr auto kAttrSrcFormat = "src_format";
|
|||
constexpr auto kAttrOutputUsedNum = "output_used_num";
|
||||
constexpr auto kAttrHasBias = "has_bias";
|
||||
constexpr auto kAttrN = "N";
|
||||
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/meta_tensor.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "utils/utils.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestGetNextMemcpyElimination : public BackendCommon {
|
||||
public:
|
||||
TestGetNextMemcpyElimination() : get_py_fun_("gtest_input.pre_activate.getnext_memcpy_elimination_test", true) {}
|
||||
|
||||
public:
|
||||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
// Runs the pass on the "before" graph (memcpy carrying the attribute, single user with
// a single input) and expects the result to equal the "after" graph.
TEST_F(TestGetNextMemcpyElimination, test_getnext_memcpy_elimination) {
  FuncGraphPtr g_before = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination", "before");
  ASSERT_TRUE(g_before != nullptr);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::GetnextMemcpyElimination>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(g_before);
  // Stop here on a null result rather than handing it to the graph comparison.
  ASSERT_TRUE(new_graph != nullptr);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination", "after");
  // Same guard as for g_before: a failed fetch should fail the test cleanly.
  ASSERT_TRUE(g_after != nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
|
||||
|
||||
// Runs the pass on a graph whose memcpy lacks the attribute and expects the result to
// equal the "after" graph, which is built identically to "before".
TEST_F(TestGetNextMemcpyElimination, test_getnext_memcpy_elimination_no_attr) {
  FuncGraphPtr g_before = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination_no_attr", "before");
  ASSERT_TRUE(g_before != nullptr);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::GetnextMemcpyElimination>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(g_before);
  // Stop here on a null result rather than handing it to the graph comparison.
  ASSERT_TRUE(new_graph != nullptr);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination_no_attr", "after");
  // Same guard as for g_before: a failed fetch should fail the test cleanly.
  ASSERT_TRUE(g_after != nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
|
||||
|
||||
// Runs the pass on a graph whose memcpy output feeds two users and expects the result
// to equal the "after" graph, which is built identically to "before".
TEST_F(TestGetNextMemcpyElimination, test_getnext_memcpy_elimination_memcpy_multi_users) {
  FuncGraphPtr g_before = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination_memcpy_multi_users", "before");
  ASSERT_TRUE(g_before != nullptr);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::GetnextMemcpyElimination>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(g_before);
  // Stop here on a null result rather than handing it to the graph comparison.
  ASSERT_TRUE(new_graph != nullptr);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination_memcpy_multi_users", "after");
  // Same guard as for g_before: a failed fetch should fail the test cleanly.
  ASSERT_TRUE(g_after != nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
|
||||
|
||||
// Runs the pass on a graph whose memcpy user takes two inputs and expects the result
// to equal the "after" graph, which is built identically to "before".
TEST_F(TestGetNextMemcpyElimination, test_getnext_memcpy_elimination_next_multi_inputs) {
  FuncGraphPtr g_before = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination_next_multi_inputs", "before");
  ASSERT_TRUE(g_before != nullptr);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::GetnextMemcpyElimination>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(g_before);
  // Stop here on a null result rather than handing it to the graph comparison.
  ASSERT_TRUE(new_graph != nullptr);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_getnext_memcpy_elimination_next_multi_inputs", "after");
  // Same guard as for g_before: a failed fetch should fail the test cleanly.
  ASSERT_TRUE(g_after != nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -20,7 +20,7 @@
|
|||
#include "ir/manager.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "pre_activate/common/ir_fusion/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "pre_activate/common/pass_manager.h"
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import Primitive
|
||||
import mindspore as ms
|
||||
|
||||
# Primitives shared by every graph builder in this module.

# GetNext built with a float32 type list and a [1, 64, 112, 112] shape list.
get_next = P.GetNext([ms.float32], [[1, 64, 112, 112]], 1, "")
cast = P.Cast()
add = P.TensorAdd()

# Plain memcpy_async, without the stream-active label.
memcpy_async = Primitive('memcpy_async')

# memcpy_async carrying the "label_for_insert_stream_active" attribute.
memcpy_async_attr = Primitive('memcpy_async')
memcpy_async_attr.add_prim_attr("label_for_insert_stream_active", True)
|
||||
|
||||
|
||||
class FnDict:
    """Name-indexed registry of functions that can be used as a decorator.

    Calling the instance with a function stores it under ``fn.__name__``;
    indexing the instance with a name returns the stored function.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its own name and hand it back."""
        self.fnDict[fn.__name__] = fn
        # Return the function so that a name decorated with this registry stays
        # bound to the function instead of being rebound to None.
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
|
||||
|
||||
|
||||
def test_getnext_memcpy_elimination(tag):
    """Return the graph function registered under ``tag`` ('before' or 'after').

    'before' chains get_next -> memcpy_async_attr -> cast; 'after' chains
    get_next -> cast with no memcpy in between.
    """
    graphs = FnDict()

    @graphs
    def before(x):
        data = get_next()
        copied = memcpy_async_attr(data)
        out = cast(copied)
        return out

    @graphs
    def after(x):
        data = get_next()
        out = cast(data)
        return out

    return graphs[tag]
|
||||
|
||||
|
||||
def test_getnext_memcpy_elimination_no_attr(tag):
    """Return the graph function registered under ``tag`` ('before' or 'after').

    Both graphs chain get_next -> memcpy_async -> cast, using the memcpy
    primitive that has no extra attribute, so 'before' and 'after' are the same.
    """
    graphs = FnDict()

    @graphs
    def before(x):
        data = get_next()
        copied = memcpy_async(data)
        out = cast(copied)
        return out

    @graphs
    def after(x):
        data = get_next()
        copied = memcpy_async(data)
        out = cast(copied)
        return out

    return graphs[tag]
|
||||
|
||||
|
||||
def test_getnext_memcpy_elimination_memcpy_multi_users(tag):
    """Return the graph function registered under ``tag`` ('before' or 'after').

    In both graphs the memcpy_async_attr output is consumed twice: once by cast
    and once by add. 'before' and 'after' are the same.
    """
    graphs = FnDict()

    @graphs
    def before(x):
        data = get_next()
        copied = memcpy_async_attr(data)
        casted = cast(copied)
        total = add(copied, casted)
        return total

    @graphs
    def after(x):
        data = get_next()
        copied = memcpy_async_attr(data)
        casted = cast(copied)
        total = add(copied, casted)
        return total

    return graphs[tag]
|
||||
|
||||
|
||||
def test_getnext_memcpy_elimination_next_multi_inputs(tag):
    """Return the graph function registered under ``tag`` ('before' or 'after').

    In both graphs the only user of the memcpy_async_attr output is add, which
    also takes the get_next output directly. 'before' and 'after' are the same.
    """
    graphs = FnDict()

    @graphs
    def before(x):
        data = get_next()
        copied = memcpy_async_attr(data)
        total = add(copied, data)
        return total

    @graphs
    def after(x):
        data = get_next()
        copied = memcpy_async_attr(data)
        total = add(copied, data)
        return total

    return graphs[tag]
|
Loading…
Reference in New Issue