diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 27b99840dfb..62721c2ec8e 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -81,7 +81,7 @@ #include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" -#include "pre_activate/ascend/enhancer/add_memcpy_async.h" +#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" @@ -227,7 +227,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); } - ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); if (context_ptr->ir_fusion_flag()) { AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); @@ -238,6 +237,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); } + ir_fusion_pm->AddPass(std::make_shared()); optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index ee0d837cee4..e47270d8478 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -22,6 +22,8 @@ #include "device/ascend/kernel_select_ascend.h" #include "kernel/kernel_query.h" #include "kernel/tbe/tbe_kernel_select.h" +#include "kernel/oplib/oplib.h" +#include "session/anf_runtime_algorithm.h" namespace mindspore { namespace opt { @@ -56,6 +58,17 @@ class KernelQuery { std::vector> *kernel_info_list) { kernel::KernelQuery(kernel_node, kernel_info_list); } + virtual bool IsTbeRef(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); + if (op_info != nullptr) { + return op_info->is_ref(); + } + return false; + } }; using KernelQueryPtr = std::shared_ptr; void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc deleted file mode 100644 index 51f6732c66d..00000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/enhancer/add_memcpy_async.h" -#include -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -namespace { -bool InputIsParameterOrValueNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - return kernel_with_index.first->isa() || kernel_with_index.first->isa(); -} - -const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - const std::vector &inputs = node->inputs(); - bool replace = false; - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "node[" + AnfAlgo::GetCNodeName(node) + "]'s inputs is empty"; - } - std::vector new_inputs = {inputs[0]}; - for (size_t i = 1; i < inputs.size(); ++i) { - auto input = node->input(i); - if (manager->node_users().find(input) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - // when input is used by others or is a parameter or is a value node, insert a memcpy_async - if (manager->node_users()[input].size() > 1 || InputIsParameterOrValueNode(input)) { - replace = true; - new_inputs.push_back(CreateMemcpyAsyncOp(graph, input)); - } else { - new_inputs.push_back(input); - } - } - - CNodePtr new_node = std::make_shared(*node); - new_node->set_inputs(new_inputs); - return replace ? new_node : nullptr; -} -} // namespace - -const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (func_graph == nullptr || node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - if (!AnfAlgo::IsCommunicationOp(node)) { - return nullptr; - } - return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc new file mode 100644 index 00000000000..cae036b2bd1 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -0,0 +1,135 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include +#include +#include +#include "utils/utils.h" +#include "session/anf_runtime_algorithm.h" +#include "optimizer/opt.h" +#include "pre_activate/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +namespace { +// insert memcpy for some cnode even if not a Ref cnode +const std::set kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName, + kLambUpdateWithLROpName}; + +bool IsParameterOrValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + return kernel_with_index.first->isa() || kernel_with_index.first->isa(); +} + +void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(hccl_node); + MS_EXCEPTION_IF_NULL(memcpy_async); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(hccl_node); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + // find hccl_node's output which is a control depend + for (const auto &node_index : iter->second) { + AnfNodePtr output = node_index.first; + int output_index = node_index.second; + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { + CNodePtr control_depend = output->cast(); + MS_EXCEPTION_IF_NULL(control_depend); + std::vector new_inputs; + for (size_t i = 0; i < control_depend->size(); ++i) { + if (i == IntToSize(output_index)) { + new_inputs.push_back(memcpy_async); + } else { + new_inputs.push_back(control_depend->input(i)); + } + } + control_depend->set_inputs(new_inputs); + } + } +} +} // namespace + +bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input); + // when input is a parameter or is a value node + if (IsParameterOrValueNode(input)) { + return true; + } + + // when input is a Ref or some special cnodes + if (kernel_query_->IsTbeRef(input) || + kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { + return true; + } + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(input); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + // when input is used by others + if (iter->second.size() > 1) { + return true; + } + return false; +} + +void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(hccl_node); + if (hccl_node->size() != 2) { + MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2"; + return; + } + + auto input = hccl_node->input(1); + if (NeedInsertMemcpy(graph, input)) { + auto memcpy_async = CreateMemcpyAsyncOp(graph, input); + CNodePtr new_hccl_node = std::make_shared(*hccl_node); + new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async}); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; + (void)manager->Replace(hccl_node, new_hccl_node); + MS_LOG(DEBUG) << "end replace"; + + // transer hccl op's control to the memcpy_async + TransferControl(new_hccl_node, memcpy_async, graph); + } +} + +const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (func_graph == nullptr || node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + if (!AnfAlgo::IsCommunicationOp(node)) { + return nullptr; + } + InsertMemcpyAsync(func_graph, cnode); + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h similarity index 50% rename from mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.h rename to mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h index 900b0fb46a1..e2f3b781ed3 100644 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.h +++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h @@ -13,19 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ #include #include "pre_activate/common/optimizer.h" +#include "pre_activate/ascend/ascend_helper.h" + namespace mindspore { namespace opt { -class AddMemcpyAsync : public PatternProcessPass { +class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { public: - explicit AddMemcpyAsync(bool multigraph = true) : PatternProcessPass("add_memcpy_async", multigraph) {} - ~AddMemcpyAsync() override = default; + explicit InsertMemcpyAsyncForHcclOp(bool multigraph = true) + : PatternProcessPass("insert_memcpy_async_for_hccl_op", multigraph), + kernel_query_(std::make_shared()) {} + ~InsertMemcpyAsyncForHcclOp() override = default; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; + bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; + KernelQueryPtr kernel_query_; }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc index 5f01f2fab25..f373594f4af 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc @@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s bn_outputs->push_back(output); output_num++; } - return output_num > kBatchNormLeastOutputNum; + return output_num >= kBatchNormLeastOutputNum; } AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/add_memcpy_async_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/add_memcpy_async_test.cc deleted file mode 100644 index 50b76df8647..00000000000 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/add_memcpy_async_test.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/backend_common_test.h" -#include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "ir/tensor.h" -#include "debug/anf_ir_dump.h" -#include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/enhancer/add_memcpy_async.h" - -namespace mindspore { -namespace opt { -class TestHWAddMemcpyAsync : public BackendCommon { - public: - TestHWAddMemcpyAsync() : get_py_fun_("gtest_input.pre_activate.add_memcpy_async", true) {} - - public: - UT::PyFuncGraphFetcher get_py_fun_; -}; - -TEST_F(TestHWAddMemcpyAsync, test_add_memcpy_async) { - get_py_fun_.SetDoResolve(true); - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_add_memcpy_async", "before"); - ASSERT_TRUE(g != nullptr); - std::vector shp_x{1, 64, 112, 112}; - auto x_abstract = std::make_shared(kFloat32, shp_x); - AbstractBasePtrList args_spec_list{x_abstract}; - auto func_graph = GetKernelGraph(g, args_spec_list); - EXPECT_NE(func_graph, nullptr); - - auto optimizer = std::make_shared(); - auto pm = std::make_shared(); - auto pass = std::make_shared(); - pm->AddPass(pass); - optimizer->AddPassManager(pm); - auto new_graph = optimizer->Optimize(func_graph); - - FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_add_memcpy_async", "after"); - EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); -} -} // namespace opt -} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc new file mode 100644 index 00000000000..22cf70ded3f --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc @@ -0,0 +1,165 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "session/anf_runtime_algorithm.h" +#include "operator/ops.h" +#include "ir/tensor.h" +#include "debug/anf_ir_dump.h" +#include "utils/utils.h" +#include "kernel/kernel_build_info.h" +#include "pre_activate/common/optimizer.h" +#define private public +#define protected public +#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#undef private +#undef protected +namespace mindspore { +namespace opt { +class TestHWInsertMemcpyForHccl : public BackendCommon { + public: + TestHWInsertMemcpyForHccl() : get_py_fun_("gtest_input.pre_activate.insert_memcpy_async_for_hccl_op", true) {} + ~TestHWInsertMemcpyForHccl() override = default; + + public: + UT::PyFuncGraphFetcher get_py_fun_; +}; + +class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { + public: + MockInsertMemcpyForHcclKernelQuery() = default; + ~MockInsertMemcpyForHcclKernelQuery() override = default; + bool IsTbeRef(const AnfNodePtr &node) override { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr) { + return false; + } + auto name = AnfAlgo::GetCNodeName(cnode); + return name == "ApplyMomentum"; + } +}; + +TEST_F(TestHWInsertMemcpyForHccl, test_cond1) { + get_py_fun_.SetDoResolve(true); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pass->kernel_query_ = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) { + get_py_fun_.SetDoResolve(true); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + auto origin_graph = std::make_shared(*kg); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kg); + + EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); +} + +TEST_F(TestHWInsertMemcpyForHccl, test_cond2) { + get_py_fun_.SetDoResolve(true); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond2", "before"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pass->kernel_query_ = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond2", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWInsertMemcpyForHccl, test_cond3) { + get_py_fun_.SetDoResolve(true); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "before"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pass->kernel_query_ = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { + get_py_fun_.SetDoResolve(true); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "before"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pass->kernel_query_ = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/add_memcpy_async.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/add_memcpy_async.py deleted file mode 100644 index e087530acd8..00000000000 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/add_memcpy_async.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -from mindspore.ops import Primitive -from mindspore.ops import operations as P - -all_reduce = P.AllReduce() -memcpy_async = Primitive('memcpy_async') -make_tuple = Primitive('make_tuple') -tuple_getitem = Primitive('tuple_getitem') - - -class FnDict: - def __init__(self): - self.fnDict = {} - - def __call__(self, fn): - self.fnDict[fn.__name__] = fn - - def __getitem__(self, name): - return self.fnDict[name] - - -def test_add_memcpy_async(tag): - fns = FnDict() - - @fns - def before(x): - res = all_reduce(x) - return make_tuple(x, res) - - @fns - def after(x): - res = memcpy_async(x) - res = all_reduce(res) - return make_tuple(make_tuple(x, res)) - - return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py new file mode 100644 index 00000000000..7ffcfd0578f --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py @@ -0,0 +1,120 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import Primitive +from mindspore.ops import operations as P + +all_reduce = P.AllReduce() +memcpy_async = Primitive('memcpy_async') +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') +apply_momentun = P.ApplyMomentum() +control_depend = P.ControlDepend() +relu = P.ReLU() + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_insert_memcpy_async_for_hccl_op_cond1(tag): + fns = FnDict() + + @fns + def before1(x): + res1 = relu(x) + res2 = all_reduce(res1) + return make_tuple(res1, res2) + + @fns + def before2(x): + res1 = relu(x) + res2 = all_reduce(res1) + return res2 + + @fns + def after(x): + res1 = relu(x) + res2 = memcpy_async(res1) + res2 = all_reduce(res2) + return make_tuple(make_tuple(res1, res2)) + + return fns[tag] + + +def test_insert_memcpy_async_for_hccl_op_cond2(tag): + fns = FnDict() + + @fns + def before(x): + res = all_reduce(x) + return res + + @fns + def after(x): + res = memcpy_async(x) + res = all_reduce(res) + return make_tuple(res) + + return fns[tag] + + +def test_insert_memcpy_async_for_hccl_op_cond3(tag): + fns = FnDict() + + @fns + def before(a, b, c, d, e): + res = apply_momentun(a, b, c, d, e) + res = all_reduce(res) + return res + + @fns + def after(a, b, c, d, e): + res = apply_momentun(a, b, c, d, e) + res = memcpy_async(res) + res = all_reduce(res) + return make_tuple(res) + + return fns[tag] + + +def test_insert_memcpy_async_for_hccl_op_cond4(tag): + fns = FnDict() + + @fns + def before(a, b, c, d, e): + res1 = apply_momentun(a, b, c, d, e) + res2 = all_reduce(a) + res = control_depend(res1, res2) + res = make_tuple(res, res2) + return res + + @fns + def after(a, b, c, d, e): + res1 = apply_momentun(a, b, c, d, e) + res2 = memcpy_async(a) + res3 = all_reduce(res2) + res = control_depend(res1, res2) + res = make_tuple(res, res3) + return make_tuple(res) + + return fns[tag]