forked from mindspore-Ecosystem/mindspore
!3194 Add concat and pack fission pass
Merge pull request !3194 from YuJianfeng/concat
This commit is contained in:
commit
419022f2a5
|
@ -97,6 +97,8 @@
|
|||
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
|
||||
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
|
||||
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
@ -153,6 +155,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Creates a sub-Concat node that concatenates `offset` consecutive real inputs of
// the original Concat node, starting at input index `begin_index` (real inputs
// start at index 1). The new node inherits scope and attrs from the original, and
// its output shape is the first input's shape scaled by `offset` along the axis.
AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index,
                           size_t offset) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_concat_cnode);
  std::vector<AnfNodePtr> new_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
  for (size_t i = begin_index; i < begin_index + offset; ++i) {
    new_concat_inputs.push_back(origin_concat_cnode->input(i));
  }
  CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs);
  MS_EXCEPTION_IF_NULL(new_concat);
  new_concat->set_scope(origin_concat_cnode->scope());
  // Set attrs
  AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat);
  AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat);
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_concat);
  AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(offset)), new_concat);
  std::vector<int> dyn_input_sizes{SizeToInt(offset)};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_concat);
  // infer shape
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, 0);
  auto axis = AnfAlgo::GetNodeAttr<int>(origin_concat_cnode, kAttrAxis);
  // A negative axis counts from the end of the input shape.
  if (axis < 0) {
    axis += SizeToInt(input_shape.size());
  }
  auto output_shape = AnfAlgo::GetOutputInferShape(origin_concat_cnode, 0);
  if (axis < 0 || axis >= SizeToInt(output_shape.size()) || axis >= SizeToInt(input_shape.size())) {
    MS_LOG(EXCEPTION) << "The concat_dim value " << axis << " is out of range";
  }
  // All real inputs of a Concat share the non-axis dimensions, so the sub-concat's
  // output is `offset` times the first input's extent along the concat axis.
  output_shape[axis] = input_shape[axis] * offset;
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape},
                                      new_concat.get());
  return new_concat;
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef ConcatFission::DefinePattern() const {
  // Match a Concat primitive followed by any sequence of inputs.
  const VarPtr inputs_seq = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimConcat, inputs_seq});
}
|
||||
|
||||
// Splits a Concat node whose real input count exceeds inputs_divisor_ into a tree
// of smaller Concat nodes, each within the limit. Returns nullptr when no split is
// needed, otherwise the new root Concat node.
const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // The real input begins with index 1.
  size_t origin_input_size = cnode->inputs().size() - 1;
  if (origin_input_size <= inputs_divisor_) {
    return nullptr;
  }
  CNodePtr new_cnode = cnode;
  // Each pass of this loop replaces every full group of inputs_divisor_ inputs with
  // one sub-Concat node, yielding a new (smaller) root Concat; repeat until the
  // root itself has at most inputs_divisor_ inputs.
  while (origin_input_size > inputs_divisor_) {
    MS_EXCEPTION_IF_NULL(new_cnode);
    std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    size_t cur_input_index = 1;
    // Divide the inputs of concat by inputs_divisor_.
    while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
      base_concat_inputs.push_back(CreateNewConcat(func_graph, new_cnode, cur_input_index, inputs_divisor_));
      cur_input_index += inputs_divisor_;
    }
    // Leftover inputs (fewer than inputs_divisor_) are kept on the root unchanged.
    for (size_t i = cur_input_index; i <= origin_input_size; i++) {
      base_concat_inputs.push_back(new_cnode->input(i));
    }
    CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
    MS_EXCEPTION_IF_NULL(base_concat);
    // The root keeps the original scope and abstract: its overall output is unchanged.
    base_concat->set_scope(new_cnode->scope());
    base_concat->set_abstract(new_cnode->abstract());
    // Set attrs
    AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat);
    AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat);
    AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
    std::vector<int> dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat);

    new_cnode = base_concat;
    origin_input_size = base_concat->inputs().size() - 1;
  }

  return new_cnode;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Maximum number of real inputs a single Concat node may keep; larger Concat nodes
// are split. (63 presumably matches a backend operator input limit — TODO confirm.)
constexpr size_t kConcatInputsDivisor = 63;
// Fission pass that splits an oversized Concat node into a tree of smaller Concat
// nodes, each with at most inputs_divisor_ inputs (see concat_fission.cc).
class ConcatFission : public PatternProcessPass {
 public:
  explicit ConcatFission(bool multigraph = true)
      : PatternProcessPass("concat_fission", multigraph), inputs_divisor_(kConcatInputsDivisor) {}
  ~ConcatFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Group size / input-count threshold; overridden in unit tests via `#define private public`.
  size_t inputs_divisor_;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Creates a sub-Pack node that packs `offset` consecutive real inputs of the
// original Pack node, starting at input index `begin_index` (real inputs start at
// index 1). The new node's output shape is the origin Pack's output shape with the
// pack-axis extent replaced by `offset`.
AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index,
                         size_t offset) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_pack_cnode);
  std::vector<AnfNodePtr> new_pack_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimPack->name()))};
  for (size_t i = begin_index; i < begin_index + offset; ++i) {
    new_pack_inputs.push_back(origin_pack_cnode->input(i));
  }
  CNodePtr new_pack = func_graph->NewCNode(new_pack_inputs);
  MS_EXCEPTION_IF_NULL(new_pack);
  new_pack->set_scope(origin_pack_cnode->scope());
  new_pack->set_abstract(origin_pack_cnode->abstract());
  AnfAlgo::CopyNodeAttr(kAttrAxis, origin_pack_cnode, new_pack);
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_pack);
  AnfAlgo::SetNodeAttr(kAttrNum, MakeValue(SizeToInt(offset)), new_pack);
  std::vector<int> dyn_input_sizes{SizeToInt(offset)};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack);
  // infer shape: the origin Pack's output already contains the packed dimension
  // (extent == origin input count) at `axis`; the sub-pack's output replaces that
  // extent with `offset`.
  auto output_shape = AnfAlgo::GetOutputInferShape(origin_pack_cnode, 0);
  auto axis = AnfAlgo::GetNodeAttr<int>(new_pack, kAttrAxis);
  // A negative axis counts from the end of the output shape.
  if (axis < 0) {
    axis += SizeToInt(output_shape.size());
  }
  if (axis < 0 || axis >= SizeToInt(output_shape.size())) {
    MS_LOG(EXCEPTION) << "The pack_dim value " << axis << " is out of range";
  }
  std::vector<size_t> new_shape = output_shape;
  new_shape[IntToSize(axis)] = offset;
  // Note: the inferred shape must be new_shape (axis extent == offset); passing the
  // origin output_shape here would give every sub-pack the full packed extent.
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
                                      new_pack.get());
  return new_pack;
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef PackFission::DefinePattern() const {
  // Match a Pack primitive followed by any sequence of inputs.
  const VarPtr inputs_seq = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimPack, inputs_seq});
}
|
||||
|
||||
// Splits a Pack node whose real input count exceeds inputs_divisor_ into several
// smaller Pack nodes whose outputs are joined by a Concat node. Returns nullptr
// when no split is needed, otherwise the new Concat node.
const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // The real input begins with index 1.
  size_t origin_input_size = cnode->inputs().size() - 1;
  if (origin_input_size <= inputs_divisor_) {
    return nullptr;
  }
  std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
  size_t cur_input_index = 1;
  // Divide the inputs of pack by inputs_divisor_.
  while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
    base_concat_inputs.push_back(CreateNewPack(func_graph, cnode, cur_input_index, inputs_divisor_));
    cur_input_index += inputs_divisor_;
  }
  // The tail group (fewer than inputs_divisor_ inputs) is also packed, so every
  // original input ends up inside exactly one sub-Pack.
  if (cur_input_index <= origin_input_size) {
    base_concat_inputs.push_back(
      CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1));
  }

  // NOTE(review): the joining Concat may itself exceed the input limit; presumably
  // the ConcatFission pass registered after this one splits it further — confirm
  // the pass ordering in ascend_backend_optimization.
  CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
  MS_EXCEPTION_IF_NULL(base_concat);
  // The Concat takes over the original Pack's scope and abstract (same output).
  base_concat->set_scope(cnode->scope());
  base_concat->set_abstract(cnode->abstract());
  AnfAlgo::CopyNodeAttr(kAttrAxis, cnode, base_concat);
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
  AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
  std::vector<int> dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat);

  return base_concat;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Maximum number of real inputs a single Pack node may keep; larger Pack nodes are
// split. (63 presumably matches a backend operator input limit — TODO confirm.)
constexpr size_t kPackInputsDivisor = 63;
// Fission pass that splits an oversized Pack node into several smaller Pack nodes
// joined by a Concat node (see pack_fission.cc).
class PackFission : public PatternProcessPass {
 public:
  explicit PackFission(bool multigraph = true)
      : PatternProcessPass("pack_fission", multigraph), inputs_divisor_(kPackInputsDivisor) {}
  ~PackFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Group size / input-count threshold; overridden in unit tests via `#define private public`.
  size_t inputs_divisor_;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_
|
|
@ -241,6 +241,9 @@ constexpr auto kAttrOffset = "offset";
|
|||
constexpr auto kAttrPsKey = "ps_key";
|
||||
constexpr auto kAttrOptimizerType = "optim_type";
|
||||
constexpr auto kAttrChildGraph = "child_graph";
|
||||
// Attribute keys used by the concat/pack fission passes.
constexpr auto kAttrInputNums = "inputNums";
constexpr auto kAttrT = "T";
constexpr auto kAttrNum = "num";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Unit-test fixture for the concat fission pass; fetches graph fixtures from the
// Python module gtest_input.pre_activate.concat_fission_test.
class TestHWConcatFission : public BackendCommon {
 public:
  TestHWConcatFission() : get_py_fun_("gtest_input.pre_activate.concat_fission_test", true) {}
  ~TestHWConcatFission() override = default;

  // Fetches the "before"/"after_divided_by_N" graphs defined in the fixture module.
  UT::PyFuncGraphFetcher get_py_fun_;
};
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 2;
|
||||
pm->AddPass(concat_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_2");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 3;
|
||||
pm->AddPass(concat_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_3");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 4;
|
||||
pm->AddPass(concat_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_4");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 8;
|
||||
pm->AddPass(concat_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_8");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_9) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 9;
|
||||
pm->AddPass(concat_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_9");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Unit-test fixture for the pack fission pass; fetches graph fixtures from the
// Python module gtest_input.pre_activate.pack_fission_test.
class TestHWPackFission : public BackendCommon {
 public:
  TestHWPackFission() : get_py_fun_("gtest_input.pre_activate.pack_fission_test", true) {}
  ~TestHWPackFission() override = default;

  // Fetches the "before"/"after_divided_by_N" graphs defined in the fixture module.
  UT::PyFuncGraphFetcher get_py_fun_;
};
|
||||
|
||||
TEST_F(TestHWPackFission, test_pack_fission_divided_by_3) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto pack_fission = std::make_shared<opt::PackFission>();
|
||||
pack_fission->inputs_divisor_ = 3;
|
||||
pm->AddPass(pack_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_3");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWPackFission, test_pack_fission_divided_by_4) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto pack_fission = std::make_shared<opt::PackFission>();
|
||||
pack_fission->inputs_divisor_ = 4;
|
||||
pm->AddPass(pack_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_4");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
concat = P.Concat()
|
||||
|
||||
|
||||
class FnDict:
    """Decorator-style registry mapping a function's name to the function."""

    def __init__(self):
        self.fnDict = dict()

    def __call__(self, fn):
        """Register `fn` under its own __name__ (used as `@fns`)."""
        name = fn.__name__
        self.fnDict[name] = fn

    def __getitem__(self, name):
        """Look up a previously registered function by name."""
        return self.fnDict[name]
|
||||
|
||||
|
||||
def test_concat_fission(tag):
    """Graph fixtures for the concat fission pass.

    `before` builds a 9-input Concat; each `after_divided_by_N` builds the
    Concat tree expected after running ConcatFission with inputs_divisor_ = N.
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8))

    @fns
    def after_divided_by_2(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # 9 inputs -> four pairs plus the leftover input8; the tree is then
        # reduced pairwise until the root Concat has at most 2 inputs.
        a = concat((input0, input1))
        b = concat((input2, input3))
        c = concat((input4, input5))
        d = concat((input6, input7))
        f = concat((a, b))
        g = concat((c, d))
        i = concat((f, g))
        return concat((i, input8))

    @fns
    def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        a = concat((input0, input1, input2))
        b = concat((input3, input4, input5))
        c = concat((input6, input7, input8))
        return concat((a, b, c))

    @fns
    def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Two full groups of 4; input8 stays directly on the root Concat.
        a = concat((input0, input1, input2, input3))
        b = concat((input4, input5, input6, input7))
        return concat((a, b, input8))

    @fns
    def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        a = concat((input0, input1, input2, input3, input4, input5, input6, input7))
        return concat((a, input8))

    @fns
    def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # divisor == input count: the pass leaves the Concat unchanged.
        return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8))

    return fns[tag]
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import Primitive
|
||||
|
||||
pack = P.Pack()
|
||||
concat = P.Concat()
|
||||
make_tuple = Primitive('make_tuple')
|
||||
|
||||
|
||||
class FnDict:
    """Decorator-style registry mapping a function's name to the function."""

    def __init__(self):
        self.fnDict = dict()

    def __call__(self, fn):
        """Register `fn` under its own __name__ (used as `@fns`)."""
        name = fn.__name__
        self.fnDict[name] = fn

    def __getitem__(self, name):
        """Look up a previously registered function by name."""
        return self.fnDict[name]
|
||||
|
||||
|
||||
def test_pack_fission(tag):
    """Graph fixtures for the pack fission pass.

    `before` builds a 9-input Pack; each `after_divided_by_N` builds the graph
    expected after running PackFission with inputs_divisor_ = N: groups of N
    inputs packed separately, joined by a Concat.
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        return pack((input0, input1, input2, input3, input4, input5, input6, input7, input8))

    @fns
    def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # NOTE(review): the after-graphs call pack/concat with unfolded arguments
        # (no tuple), unlike `before` — presumably matching the post-optimization
        # graph form that CheckEqualGraph compares against; confirm with the C++
        # test, which compares this graph directly without a kernel-graph build.
        pack1 = pack(input0, input1, input2)
        pack2 = pack(input3, input4, input5)
        pack3 = pack(input6, input7, input8)
        return make_tuple(concat(pack1, pack2, pack3))

    @fns
    def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # 9 inputs with divisor 4 -> groups of 4, 4 and a tail group of 1.
        pack1 = pack(input0, input1, input2, input3)
        pack2 = pack(input4, input5, input6, input7)
        pack3 = pack(input8)
        return make_tuple(concat(pack1, pack2, pack3))

    return fns[tag]
|
Loading…
Reference in New Issue