Add Split fission pass

This commit is contained in:
yujianfeng 2020-06-19 11:46:38 +08:00
parent f3f95b255b
commit 7ad877a948
8 changed files with 352 additions and 0 deletions

View File

@ -90,6 +90,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"},
{"lamb_next_mv", "lamb_next_m_v"},
{"split", "split_d"},
{"split_v", "split_v_d"},
{"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"},
{"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"},
{"pad", "pad_d"},

View File

@ -87,6 +87,7 @@
#include "pre_activate/ascend/ir_fission/addn_fission.h"
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include "pre_activate/ascend/ir_fission/split_fission.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
@ -141,6 +142,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
}
} // namespace

View File

@ -0,0 +1,191 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fission/split_fission.h"
#include <memory>
#include <vector>
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
// Creates a new SplitV CNode in `func_graph` that takes `input_node` as its only data input.
// The new node is given the same scope as `input_node`. Throws (via MS_EXCEPTION_IF_NULL) when either
// argument, or the freshly created node, is null.
CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  // Input 0 of a CNode is the primitive, input 1 is the tensor to be split.
  std::vector<AnfNodePtr> node_inputs;
  node_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(kSplitVOpName)));
  node_inputs.emplace_back(input_node);
  CNodePtr new_node = func_graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_scope(input_node->scope());
  return new_node;
}
// Creates the first-level ("base") SplitV node, fed by input 1 of the original Split node `origin_cnode`.
// Raises an exception when `origin_cnode` has fewer than kSplitInputNum inputs (primitive included).
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
  MS_EXCEPTION_IF_NULL(origin_cnode);
  const bool has_enough_inputs = origin_cnode->inputs().size() >= kSplitInputNum;
  if (!has_enough_inputs) {
    MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
                      << kSplitInputNum - 1;
  }
  const auto &data_input = origin_cnode->input(1);
  return CreateSplitVNode(func_graph, data_input);
}
// Attaches the three SplitV attributes to `splitv`:
//   kAttrSizeSplits <- size_splits, kAttrSplitDim <- split_dim, kAttrNumSplit <- num_split.
void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int> &size_splits, int split_dim, int num_split) {
  auto size_splits_value = MakeValue(size_splits);
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, size_splits_value, splitv);

  auto split_dim_value = MakeValue(split_dim);
  AnfAlgo::SetNodeAttr(kAttrSplitDim, split_dim_value, splitv);

  auto num_split_value = MakeValue(num_split);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, num_split_value, splitv);
}
size_t GetSmallSplitSize(const AnfNodePtr &split_node, int split_dim, int num_split) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0);
if (split_dim < 0) {
split_dim += input_shape.size();
}
if (IntToSize(split_dim) >= input_shape.size()) {
MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
}
return input_shape[split_dim] / num_split;
}
void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int outputs_num,
std::vector<AnfNodePtr> *inputs) {
MS_EXCEPTION_IF_NULL(inputs);
std::vector<AnfNodePtr> new_splitv_output;
CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, outputs_num, &new_splitv_output);
inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end());
}
// Builds a `tuple_getitem(input, index)` CNode in `func_graph`.
// The index is stored as an int value node whose abstract is an Int32 scalar holding the same value.
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const int index_as_int = SizeToInt(index);
  auto index_node = NewValueNode(index_as_int);
  MS_EXCEPTION_IF_NULL(index_node);
  auto index_imm = std::make_shared<Int32Imm>(index_as_int);
  index_node->set_abstract(std::make_shared<abstract::AbstractScalar>(index_imm));
  return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, index_node});
}
void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int split_size, int num_split,
std::vector<TypeId> *new_type_ids,
std::vector<std::vector<size_t>> *new_output_shapes) {
MS_EXCEPTION_IF_NULL(new_type_ids);
MS_EXCEPTION_IF_NULL(new_output_shapes);
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
output_shape[split_dim] = split_size;
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
for (int i = 0; i < num_split; ++i) {
new_type_ids->emplace_back(type_id);
new_output_shapes->emplace_back(output_shape);
}
}
void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
const std::vector<int> &size_splits_base, int split_dim, int num_split) {
SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
std::vector<TypeId> base_type_ids;
std::vector<std::vector<size_t>> base_output_shapes_base;
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
for (int i = 0; i < num_split; ++i) {
output_shape[split_dim] = size_splits_base[i];
base_output_shapes_base.emplace_back(output_shape);
base_type_ids.emplace_back(type_id);
}
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
}
// Rewrites the Split node `cnode` as a two-level tree of SplitV nodes and returns a MakeTuple node that collects
// the final pieces in their original order.
//
// func_graph: graph in which the new nodes are created; must not be null.
// cnode:      the original Split node; its kAttrAxis attribute is used as the split axis.
// num_split:  number of outputs of the original Split (the caller in this file only invokes this function when
//             num_split > divisor).
// divisor:    maximum number of outputs given to one second-level SplitV node.
//
// Layout of the result: a "base" SplitV cuts input 1 of `cnode` into chunks; every full chunk covers `divisor`
// original outputs and is cut again by its own SplitV. The remaining outputs (between 1 and `divisor`) are
// either cut by one more SplitV, or, when exactly one output remains, taken directly from the base SplitV.
// NOTE(review): with divisor == 0 the while loop below never advances cur_output_index — confirm callers never
// pass a non-positive divisor.
AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int num_split, int divisor) {
MS_EXCEPTION_IF_NULL(func_graph);
auto split_dim = AnfAlgo::GetNodeAttr<int>(cnode, kAttrAxis);
CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
// Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
// small_split_size is the extent along the split axis of one original Split output.
auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split));
std::vector<int> size_splits_new;
for (int i = 0; i < divisor; ++i) {
size_splits_new.emplace_back(small_split_size);
}
// Create new output shape and new output type id for each new Splitv node which has full inputs.
std::vector<TypeId> new_type_ids;
std::vector<std::vector<size_t>> new_output_shapes;
CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes);
// Create make_tuple input to create a make_tuple for replacing the old Split node.
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
// Start to divide the outputs of Split.
// size_splits_base collects, per base SplitV output, the extent of that chunk along the split axis.
std::vector<int> size_splits_base;
const auto base_split_size = divisor * small_split_size;
// nodes_num counts the outputs of the base SplitV created so far; cur_output_index counts the original
// Split outputs already covered.
int nodes_num = 0;
int cur_output_index = 0;
// Strictly greater: when exactly `divisor` outputs remain they are handled by the tail branch below.
while (num_split - cur_output_index > divisor) {
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
cur_output_index += divisor;
size_splits_base.emplace_back(base_split_size);
nodes_num++;
}
// Tail: the outputs that did not fill a chunk inside the loop.
if (cur_output_index < num_split) {
auto last_node_num_split = num_split - cur_output_index;
if (last_node_num_split > 1) {
// More than one output left: split the last base chunk with a dedicated SplitV.
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
std::vector<int> size_splits_new_last;
for (int i = 0; i < last_node_num_split; ++i) {
size_splits_new_last.emplace_back(small_split_size);
}
SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
// Create new output shape and new output type id for the last Splitv node
std::vector<TypeId> last_new_type_ids;
std::vector<std::vector<size_t>> last_new_output_shapes;
CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids,
&last_new_output_shapes);
AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get());
AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
size_splits_base.emplace_back(last_node_num_split * small_split_size);
} else {
// Exactly one output left: the last base chunk already is that output, so no second-level SplitV is needed.
make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num));
size_splits_base.emplace_back(small_split_size);
}
nodes_num++;
}
// Set Attr and abstract for the base splitv
// This is done last because size_splits_base and nodes_num are only complete at this point.
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num);
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple;
}
} // namespace
// Pattern handled by this pass: a node whose primitive is kSplitOpName, followed by a SeqVar placeholder for
// its inputs.
const BaseRef SplitFission::DefinePattern() const {
  VarPtr split_inputs = std::make_shared<SeqVar>();
  auto split_primitive = std::make_shared<Primitive>(kSplitOpName);
  VectorRef pattern({split_primitive, split_inputs});
  return pattern;
}
// Applies the fission to a matched node.
// Returns nullptr (no replacement) when the node has no kAttrOutputNum attribute or when its output number does
// not exceed outputs_divisor_; otherwise returns the replacement node built by DoFission.
const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto split_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(split_cnode);
  // Without the output number there is nothing to decide on.
  const bool has_output_num = AnfAlgo::HasNodeAttr(kAttrOutputNum, split_cnode);
  if (!has_output_num) {
    return nullptr;
  }
  const auto output_num = AnfAlgo::GetNodeAttr<int>(split_cnode, kAttrOutputNum);
  const bool needs_fission = output_num > outputs_divisor_;
  if (needs_fission) {
    return DoFission(func_graph, split_cnode, output_num, outputs_divisor_);
  }
  return nullptr;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
constexpr int kSplitOutputsDivisor = 63;
class SplitFission : public PatternProcessPass {
public:
explicit SplitFission(bool multigraph = true)
: PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
~SplitFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
int outputs_divisor_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_

View File

@ -97,6 +97,7 @@ constexpr size_t kBiasAddInputNum = 3;
constexpr size_t kTopkInputNum = 3;
constexpr size_t kLarsV2InputNum = 5;
constexpr size_t kFusedMulApplyMomentumOutputNum = 2;
constexpr size_t kSplitInputNum = 2;
enum FusedBatchNormInput {
kX = 1,

View File

@ -72,6 +72,7 @@ constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
constexpr auto kFlattenGradOpName = "FlattenGrad";
constexpr auto kExpandDimsOpName = "ExpandDims";
constexpr auto kSplitOpName = "Split";
constexpr auto kSplitVOpName = "SplitV";
constexpr auto kSparseApplyAdagradOpName = "SparseApplyAdagrad";
constexpr auto kMomentumOpName = "Momentum";
constexpr auto kApplyMomentumOpName = "ApplyMomentum";
@ -211,6 +212,10 @@ constexpr auto kAttrWaitEvent = "wait_event";
constexpr auto kAttrRecordEventStream = "record_event_stream";
constexpr auto kAttrWaitEventStream = "wait_event_stream";
constexpr auto kAttrIndex = "index";
constexpr auto kAttrSplitDim = "split_dim";
constexpr auto kAttrNumSplit = "num_split";
constexpr auto kAttrOutputNum = "output_num";
constexpr auto kAttrSizeSplits = "size_splits";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#define private public
#define protected public
#include "pre_activate/ascend/ir_fission/split_fission.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
// Test fixture for the SplitFission pass, built on the BackendCommon base fixture.
class TestHWSplitFission : public BackendCommon {
public:
// Constructs the fetcher with the Python module path of the graph fixtures.
// NOTE(review): the meaning of the second constructor argument (`true`) is not visible here — confirm against
// UT::PyFuncGraphFetcher.
TestHWSplitFission() : get_py_fun_("gtest_input.pre_activate.split_fission_test", true) {}
~TestHWSplitFission() override = default;
// Used by the tests to obtain function graphs via CallAndParseRet(function name, tag).
UT::PyFuncGraphFetcher get_py_fun_;
};
// Runs the SplitFission pass with outputs_divisor_ lowered to 3 on the "before" graph of test_split_fission and
// expects the optimized graph to equal the "after" graph.
TEST_F(TestHWSplitFission, test_split_fission_divided_by_3) {
  FuncGraphPtr before_graph = get_py_fun_.CallAndParseRet("test_split_fission", "before");
  EXPECT_NE(before_graph, nullptr);

  // Single float32 input of shape (512, 3, 1).
  std::vector<int> input_shape{512, 3, 1};
  auto input_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, input_shape);
  AbstractBasePtrList input_abstracts;
  input_abstracts.push_back(input_abstract);
  auto kernel_graph = GetKernelGraph(before_graph, input_abstracts);

  // Override the private threshold (made accessible by the `#define private public` above the include).
  auto fission_pass = std::make_shared<opt::SplitFission>();
  fission_pass->outputs_divisor_ = 3;

  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(fission_pass);
  graph_optimizer->AddPassManager(pass_manager);
  FuncGraphPtr optimized_graph = graph_optimizer->Optimize(kernel_graph);

  FuncGraphPtr expected_graph = get_py_fun_.CallAndParseRet("test_split_fission", "after");
  EXPECT_TRUE(CheckEqualGraph(expected_graph, optimized_graph));
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
split = P.Split(0, 8)
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
splitv = Primitive('SplitV')
class FnDict:
    """Registry that maps function names to functions and can be used as a decorator."""

    def __init__(self):
        self.fnDict = dict()

    def __call__(self, fn):
        # Register under the function's own name. Nothing is returned, so a name
        # decorated with an instance is rebound to None; use indexing to fetch it.
        self.fnDict.update({fn.__name__: fn})

    def __getitem__(self, name):
        registered_fn = self.fnDict[name]
        return registered_fn
def test_split_fission(tag):
    """Return the graph function registered under `tag` ('before' or 'after') for the split fission test.

    'before' applies the module-level `split` primitive to the input; 'after' is the expected graph in which
    one base SplitV feeds three second-level SplitV nodes producing 3, 3 and 2 outputs.
    """
    graph_fns = FnDict()

    @graph_fns
    def before(x):
        return split(x)

    @graph_fns
    def after(x):
        base_splitv = splitv(x)
        first_splitv = splitv(tuple_getitem(base_splitv, 0))
        second_splitv = splitv(tuple_getitem(base_splitv, 1))
        last_splitv = splitv(tuple_getitem(base_splitv, 2))
        all_outputs = make_tuple(
            tuple_getitem(first_splitv, 0), tuple_getitem(first_splitv, 1), tuple_getitem(first_splitv, 2),
            tuple_getitem(second_splitv, 0), tuple_getitem(second_splitv, 1), tuple_getitem(second_splitv, 2),
            tuple_getitem(last_splitv, 0), tuple_getitem(last_splitv, 1))
        return make_tuple(
            make_tuple(
                tuple_getitem(all_outputs, 0), tuple_getitem(all_outputs, 1), tuple_getitem(all_outputs, 2),
                tuple_getitem(all_outputs, 3), tuple_getitem(all_outputs, 4), tuple_getitem(all_outputs, 5),
                tuple_getitem(all_outputs, 6), tuple_getitem(all_outputs, 7)))

    return graph_fns[tag]