diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 2636def192a..46ef1d7d533 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -97,6 +97,8 @@ #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" #include "backend/optimizer/ascend/format_type/remove_internal_output.h" +#include "backend/optimizer/ascend/ir_fission/concat_fission.h" +#include "backend/optimizer/ascend/ir_fission/pack_fission.h" #include "utils/context/ms_context.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" @@ -153,6 +155,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); } } // namespace diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc new file mode 100644 index 00000000000..c61cc03e963 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/concat_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index, + size_t offset) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(origin_concat_cnode); + std::vector new_concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + for (size_t i = begin_index; i < begin_index + offset; ++i) { + new_concat_inputs.push_back(origin_concat_cnode->input(i)); + } + CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs); + MS_EXCEPTION_IF_NULL(new_concat); + new_concat->set_scope(origin_concat_cnode->scope()); + // Set attrs + AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat); + AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(offset)), new_concat); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_concat); + // infer shape + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, 0); + auto axis = AnfAlgo::GetNodeAttr(origin_concat_cnode, kAttrAxis); + if (axis < 0) { + axis += input_shape.size(); + } + auto output_shape = AnfAlgo::GetOutputInferShape(origin_concat_cnode, 0); + if (axis < 0 || axis >= SizeToInt(output_shape.size()) || axis >= SizeToInt(input_shape.size())) { + MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"; + } + output_shape[axis] = input_shape[axis] * offset; + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape}, + 
new_concat.get()); + return new_concat; +} +} // namespace + +const BaseRef ConcatFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimConcat, Xs}); +} + +const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // The real input begins with index 1. + size_t origin_input_size = cnode->inputs().size() - 1; + if (origin_input_size <= inputs_divisor_) { + return nullptr; + } + CNodePtr new_cnode = cnode; + while (origin_input_size > inputs_divisor_) { + MS_EXCEPTION_IF_NULL(new_cnode); + std::vector base_concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + size_t cur_input_index = 1; + // Divide the inputs of concat by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { + base_concat_inputs.push_back(CreateNewConcat(func_graph, new_cnode, cur_input_index, inputs_divisor_)); + cur_input_index += inputs_divisor_; + } + for (size_t i = cur_input_index; i <= origin_input_size; i++) { + base_concat_inputs.push_back(new_cnode->input(i)); + } + CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs); + MS_EXCEPTION_IF_NULL(base_concat); + base_concat->set_scope(new_cnode->scope()); + base_concat->set_abstract(new_cnode->abstract()); + // Set attrs + AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat); + AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + std::vector dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat); + + new_cnode = base_concat; + origin_input_size = 
base_concat->inputs().size() - 1; + } + + return new_cnode; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.h new file mode 100644 index 00000000000..a2a8d413b5c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr size_t kConcatInputsDivisor = 63; +class ConcatFission : public PatternProcessPass { + public: + explicit ConcatFission(bool multigraph = true) + : PatternProcessPass("concat_fission", multigraph), inputs_divisor_(kConcatInputsDivisor) {} + ~ConcatFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + size_t inputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc new file mode 100644 index 00000000000..9196451aea6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fission/pack_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index, + size_t offset) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(origin_pack_cnode); + std::vector new_pack_inputs{NewValueNode(std::make_shared(prim::kPrimPack->name()))}; + for (size_t i = begin_index; i < begin_index + offset; ++i) { + new_pack_inputs.push_back(origin_pack_cnode->input(i)); + } + CNodePtr new_pack = func_graph->NewCNode(new_pack_inputs); + MS_EXCEPTION_IF_NULL(new_pack); + new_pack->set_scope(origin_pack_cnode->scope()); + new_pack->set_abstract(origin_pack_cnode->abstract()); + AnfAlgo::CopyNodeAttr(kAttrAxis, origin_pack_cnode, new_pack); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_pack); + AnfAlgo::SetNodeAttr(kAttrNum, MakeValue(SizeToInt(offset)), new_pack); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack); + // infer shape + auto output_shape = AnfAlgo ::GetOutputInferShape(origin_pack_cnode, 0); + auto axis = AnfAlgo::GetNodeAttr(new_pack, kAttrAxis); + if (axis < 0) { + axis += output_shape.size(); + } + if (axis < 0) { + MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"; + } + std::vector new_shape; + for (size_t i = 0; i < output_shape.size() + 1; ++i) { + if (i < IntToSize(axis)) { + new_shape.push_back(output_shape[i]); + } else if (i == IntToSize(axis)) { + new_shape.push_back(offset); + } else { + new_shape.push_back(output_shape[i - 1]); + } + } + new_shape.erase(new_shape.begin() + axis + 1); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {output_shape}, + new_pack.get()); + return new_pack; +} +} // namespace + +const BaseRef 
PackFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimPack, Xs}); +} + +const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // The real input begins with index 1. + size_t origin_input_size = cnode->inputs().size() - 1; + if (origin_input_size <= inputs_divisor_) { + return nullptr; + } + std::vector base_concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + size_t cur_input_index = 1; + // Divide the inputs of pack by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { + base_concat_inputs.push_back(CreateNewPack(func_graph, cnode, cur_input_index, inputs_divisor_)); + cur_input_index += inputs_divisor_; + } + if (cur_input_index <= origin_input_size) { + base_concat_inputs.push_back( + CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1)); + } + + CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs); + MS_EXCEPTION_IF_NULL(base_concat); + base_concat->set_scope(cnode->scope()); + base_concat->set_abstract(cnode->abstract()); + AnfAlgo::CopyNodeAttr(kAttrAxis, cnode, base_concat); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + std::vector dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat); + + return base_concat; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.h new file mode 100644 index 00000000000..85504621cf9 --- /dev/null 
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr size_t kPackInputsDivisor = 63; +class PackFission : public PatternProcessPass { + public: + explicit PackFission(bool multigraph = true) + : PatternProcessPass("pack_fission", multigraph), inputs_divisor_(kPackInputsDivisor) {} + ~PackFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + size_t inputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index e437ce85345..82bfe7a891b 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -241,6 +241,9 @@ constexpr auto kAttrOffset = "offset"; constexpr auto kAttrPsKey = "ps_key"; constexpr auto kAttrOptimizerType = "optim_type"; constexpr auto kAttrChildGraph = "child_graph"; +constexpr auto kAttrInputNums = "inputNums"; +constexpr auto kAttrT = "T"; 
+constexpr auto kAttrNum = "num"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/concat_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/concat_fission_test.cc new file mode 100644 index 00000000000..0198a99a592 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/concat_fission_test.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#define private public +#define protected public +#include "backend/optimizer/ascend/ir_fission/concat_fission.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWConcatFission : public BackendCommon { + public: + TestHWConcatFission() : get_py_fun_("gtest_input.pre_activate.concat_fission_test", true) {} + ~TestHWConcatFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 2; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_2"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = 
std::make_shared(); + concat_fission->inputs_divisor_ = 3; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_3"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 4; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_4"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 8; + pm->AddPass(concat_fission); + 
optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_8"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_9) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 9; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_9"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/pack_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/pack_fission_test.cc new file mode 100644 index 00000000000..d22e55c927f --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/pack_fission_test.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#define private public +#define protected public +#include "backend/optimizer/ascend/ir_fission/pack_fission.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWPackFission : public BackendCommon { + public: + TestHWPackFission() : get_py_fun_("gtest_input.pre_activate.pack_fission_test", true) {} + ~TestHWPackFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWPackFission, test_pack_fission_divided_by_3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pack_fission = std::make_shared(); + pack_fission->inputs_divisor_ = 3; + pm->AddPass(pack_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_3"); + EXPECT_NE(g_after, nullptr); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWPackFission, test_pack_fission_divided_by_4) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector 
shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pack_fission = std::make_shared(); + pack_fission->inputs_divisor_ = 4; + pm->AddPass(pack_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_4"); + EXPECT_NE(g_after, nullptr); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/concat_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/concat_fission_test.py new file mode 100644 index 00000000000..f1f01999eb3 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/concat_fission_test.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
 +# ============================================================================ + +from mindspore.ops import operations as P + +concat = P.Concat() + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_concat_fission(tag): + """ test_concat_fission """ + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, input5, input6, input7, input8): + return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8)) + + @fns + def after_divided_by_2(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1)) + b = concat((input2, input3)) + c = concat((input4, input5)) + d = concat((input6, input7)) + f = concat((a, b)) + g = concat((c, d)) + i = concat((f, g)) + return concat((i, input8)) + + @fns + def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1, input2)) + b = concat((input3, input4, input5)) + c = concat((input6, input7, input8)) + return concat((a, b, c)) + + @fns + def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1, input2, input3)) + b = concat((input4, input5, input6, input7)) + return concat((a, b, input8)) + + @fns + def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1, input2, input3, input4, input5, input6, input7)) + return concat((a, input8)) + + @fns + def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8): + return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8)) + + return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/pack_fission_test.py 
b/tests/ut/cpp/python_input/gtest_input/pre_activate/pack_fission_test.py new file mode 100644 index 00000000000..8678c6273c9 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/pack_fission_test.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import operations as P +from mindspore.ops import Primitive + +pack = P.Pack() +concat = P.Concat() +make_tuple = Primitive('make_tuple') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_pack_fission(tag): + """ test_pack_fission """ + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, input5, input6, input7, input8): + return pack((input0, input1, input2, input3, input4, input5, input6, input7, input8)) + + @fns + def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): + pack1 = pack(input0, input1, input2) + pack2 = pack(input3, input4, input5) + pack3 = pack(input6, input7, input8) + return make_tuple(concat(pack1, pack2, pack3)) + + @fns + def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): + pack1 = pack(input0, input1, input2, input3) + pack2 = pack(input4, input5, 
input6, input7) + pack3 = pack(input8) + return make_tuple(concat(pack1, pack2, pack3)) + + return fns[tag]