forked from mindspore-Ecosystem/mindspore
Add Split fission pass
This commit is contained in:
parent
f3f95b255b
commit
7ad877a948
|
@ -90,6 +90,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"},
|
||||
{"lamb_next_mv", "lamb_next_m_v"},
|
||||
{"split", "split_d"},
|
||||
{"split_v", "split_v_d"},
|
||||
{"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"},
|
||||
{"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"},
|
||||
{"pad", "pad_d"},
|
||||
|
|
|
@ -87,6 +87,7 @@
|
|||
#include "pre_activate/ascend/ir_fission/addn_fission.h"
|
||||
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
|
||||
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
|
||||
#include "pre_activate/ascend/ir_fission/split_fission.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
@ -141,6 +142,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -0,0 +1,191 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/ascend/ir_fission/split_fission.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
|
||||
CNodePtr splitv = func_graph->NewCNode(splitv_inputs);
|
||||
MS_EXCEPTION_IF_NULL(splitv);
|
||||
splitv->set_scope(input_node->scope());
|
||||
return splitv;
|
||||
}
|
||||
|
||||
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(origin_cnode);
|
||||
if (origin_cnode->inputs().size() < kSplitInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
|
||||
<< kSplitInputNum - 1;
|
||||
}
|
||||
return CreateSplitVNode(func_graph, origin_cnode->input(1));
|
||||
}
|
||||
|
||||
void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int> &size_splits, int split_dim, int num_split) {
|
||||
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv);
|
||||
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv);
|
||||
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv);
|
||||
}
|
||||
|
||||
size_t GetSmallSplitSize(const AnfNodePtr &split_node, int split_dim, int num_split) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0);
|
||||
if (split_dim < 0) {
|
||||
split_dim += input_shape.size();
|
||||
}
|
||||
if (IntToSize(split_dim) >= input_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
|
||||
}
|
||||
return input_shape[split_dim] / num_split;
|
||||
}
|
||||
|
||||
void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int outputs_num,
|
||||
std::vector<AnfNodePtr> *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(inputs);
|
||||
std::vector<AnfNodePtr> new_splitv_output;
|
||||
CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, outputs_num, &new_splitv_output);
|
||||
inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end());
|
||||
}
|
||||
|
||||
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto idx = NewValueNode(SizeToInt(index));
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
auto imm = std::make_shared<Int32Imm>(SizeToInt(index));
|
||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
|
||||
idx->set_abstract(abstract_scalar);
|
||||
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
|
||||
return tuple_getitem;
|
||||
}
|
||||
|
||||
void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int split_size, int num_split,
|
||||
std::vector<TypeId> *new_type_ids,
|
||||
std::vector<std::vector<size_t>> *new_output_shapes) {
|
||||
MS_EXCEPTION_IF_NULL(new_type_ids);
|
||||
MS_EXCEPTION_IF_NULL(new_output_shapes);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
||||
output_shape[split_dim] = split_size;
|
||||
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
|
||||
for (int i = 0; i < num_split; ++i) {
|
||||
new_type_ids->emplace_back(type_id);
|
||||
new_output_shapes->emplace_back(output_shape);
|
||||
}
|
||||
}
|
||||
|
||||
void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
|
||||
const std::vector<int> &size_splits_base, int split_dim, int num_split) {
|
||||
SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
|
||||
std::vector<TypeId> base_type_ids;
|
||||
std::vector<std::vector<size_t>> base_output_shapes_base;
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
||||
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
|
||||
for (int i = 0; i < num_split; ++i) {
|
||||
output_shape[split_dim] = size_splits_base[i];
|
||||
base_output_shapes_base.emplace_back(output_shape);
|
||||
base_type_ids.emplace_back(type_id);
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
|
||||
}
|
||||
|
||||
AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int num_split, int divisor) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto split_dim = AnfAlgo::GetNodeAttr<int>(cnode, kAttrAxis);
|
||||
CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
|
||||
|
||||
// Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
|
||||
auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split));
|
||||
std::vector<int> size_splits_new;
|
||||
for (int i = 0; i < divisor; ++i) {
|
||||
size_splits_new.emplace_back(small_split_size);
|
||||
}
|
||||
// Create new output shape and new output type id for each new Splitv node which has full inputs.
|
||||
std::vector<TypeId> new_type_ids;
|
||||
std::vector<std::vector<size_t>> new_output_shapes;
|
||||
CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes);
|
||||
|
||||
// Create make_tuple input to create a make_tuple for replacing the old Split node.
|
||||
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
// Start to divide the outputs of Split.
|
||||
std::vector<int> size_splits_base;
|
||||
const auto base_split_size = divisor * small_split_size;
|
||||
int nodes_num = 0;
|
||||
int cur_output_index = 0;
|
||||
while (num_split - cur_output_index > divisor) {
|
||||
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
|
||||
SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
|
||||
AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
|
||||
AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
|
||||
cur_output_index += divisor;
|
||||
size_splits_base.emplace_back(base_split_size);
|
||||
nodes_num++;
|
||||
}
|
||||
if (cur_output_index < num_split) {
|
||||
auto last_node_num_split = num_split - cur_output_index;
|
||||
if (last_node_num_split > 1) {
|
||||
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
|
||||
std::vector<int> size_splits_new_last;
|
||||
for (int i = 0; i < last_node_num_split; ++i) {
|
||||
size_splits_new_last.emplace_back(small_split_size);
|
||||
}
|
||||
SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
|
||||
// Create new output shape and new output type id for the last Splitv node
|
||||
std::vector<TypeId> last_new_type_ids;
|
||||
std::vector<std::vector<size_t>> last_new_output_shapes;
|
||||
CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids,
|
||||
&last_new_output_shapes);
|
||||
AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get());
|
||||
AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
|
||||
size_splits_base.emplace_back(last_node_num_split * small_split_size);
|
||||
} else {
|
||||
make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num));
|
||||
size_splits_base.emplace_back(small_split_size);
|
||||
}
|
||||
nodes_num++;
|
||||
}
|
||||
// Set Attr and abstract for the base splitv
|
||||
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num);
|
||||
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
|
||||
return make_tuple;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef SplitFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto split_prim = std::make_shared<Primitive>(kSplitOpName);
|
||||
return VectorRef({split_prim, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check output num
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrOutputNum, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto num_split = AnfAlgo::GetNodeAttr<int>(cnode, kAttrOutputNum);
|
||||
if (num_split <= outputs_divisor_) {
|
||||
return nullptr;
|
||||
}
|
||||
return DoFission(func_graph, cnode, num_split, outputs_divisor_);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
|
||||
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr int kSplitOutputsDivisor = 63;
|
||||
class SplitFission : public PatternProcessPass {
|
||||
public:
|
||||
explicit SplitFission(bool multigraph = true)
|
||||
: PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
|
||||
~SplitFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
int outputs_divisor_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
|
|
@ -97,6 +97,7 @@ constexpr size_t kBiasAddInputNum = 3;
|
|||
constexpr size_t kTopkInputNum = 3;
|
||||
constexpr size_t kLarsV2InputNum = 5;
|
||||
constexpr size_t kFusedMulApplyMomentumOutputNum = 2;
|
||||
constexpr size_t kSplitInputNum = 2;
|
||||
|
||||
enum FusedBatchNormInput {
|
||||
kX = 1,
|
||||
|
|
|
@ -72,6 +72,7 @@ constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
|
|||
constexpr auto kFlattenGradOpName = "FlattenGrad";
|
||||
constexpr auto kExpandDimsOpName = "ExpandDims";
|
||||
constexpr auto kSplitOpName = "Split";
|
||||
constexpr auto kSplitVOpName = "SplitV";
|
||||
constexpr auto kSparseApplyAdagradOpName = "SparseApplyAdagrad";
|
||||
constexpr auto kMomentumOpName = "Momentum";
|
||||
constexpr auto kApplyMomentumOpName = "ApplyMomentum";
|
||||
|
@ -211,6 +212,10 @@ constexpr auto kAttrWaitEvent = "wait_event";
|
|||
constexpr auto kAttrRecordEventStream = "record_event_stream";
|
||||
constexpr auto kAttrWaitEventStream = "wait_event_stream";
|
||||
constexpr auto kAttrIndex = "index";
|
||||
constexpr auto kAttrSplitDim = "split_dim";
|
||||
constexpr auto kAttrNumSplit = "num_split";
|
||||
constexpr auto kAttrOutputNum = "output_num";
|
||||
constexpr auto kAttrSizeSplits = "size_splits";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "pre_activate/ascend/ir_fission/split_fission.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestHWSplitFission : public BackendCommon {
|
||||
public:
|
||||
TestHWSplitFission() : get_py_fun_("gtest_input.pre_activate.split_fission_test", true) {}
|
||||
~TestHWSplitFission() override = default;
|
||||
|
||||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
TEST_F(TestHWSplitFission, test_split_fission_divided_by_3) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_split_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{512, 3, 1};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
args_spec_list.push_back(x_abstract);
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto split_fission = std::make_shared<opt::SplitFission>();
|
||||
split_fission->outputs_divisor_ = 3;
|
||||
pm->AddPass(split_fission);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_split_fission", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
split = P.Split(0, 8)
|
||||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
splitv = Primitive('SplitV')
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fnDict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fnDict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fnDict[name]
|
||||
|
||||
|
||||
def test_split_fission(tag):
|
||||
""" test_adam_apply_one_with_decay_rule """
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(x):
|
||||
return split(x)
|
||||
|
||||
@fns
|
||||
def after(x):
|
||||
splitv0 = splitv(x)
|
||||
splitv1 = splitv(tuple_getitem(splitv0, 0))
|
||||
splitv2 = splitv(tuple_getitem(splitv0, 1))
|
||||
splitv3 = splitv(tuple_getitem(splitv0, 2))
|
||||
make_tuple0 = make_tuple(tuple_getitem(splitv1, 0), tuple_getitem(splitv1, 1), tuple_getitem(splitv1, 2),
|
||||
tuple_getitem(splitv2, 0), tuple_getitem(splitv2, 1), tuple_getitem(splitv2, 2),
|
||||
tuple_getitem(splitv3, 0), tuple_getitem(splitv3, 1))
|
||||
return make_tuple(
|
||||
make_tuple(tuple_getitem(make_tuple0, 0), tuple_getitem(make_tuple0, 1), tuple_getitem(make_tuple0, 2),
|
||||
tuple_getitem(make_tuple0, 3), tuple_getitem(make_tuple0, 4), tuple_getitem(make_tuple0, 5),
|
||||
tuple_getitem(make_tuple0, 6), tuple_getitem(make_tuple0, 7)))
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in New Issue