forked from mindspore-Ecosystem/mindspore
!15933 add single conv split
From: @zhujingxuan Reviewed-by: @wangchengyuan,@hangangqiang Signed-off-by: @wangchengyuan
This commit is contained in:
commit
55b15079d1
|
@ -59,9 +59,13 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/fusion/tf_gelu_fusion.cc
|
||||
../optimizer/fusion/onnx_gelu_fusion.cc
|
||||
../optimizer/fusion/squeeze_fusion.cc
|
||||
../optimizer/fisson/eliminate_concat_split.cc
|
||||
../optimizer/fisson/fisson_util.cc
|
||||
../optimizer/fisson/iter_node_outputs.cc
|
||||
../optimizer/fisson/node_out_shapes.cc
|
||||
../optimizer/parallel/dynamic_creator.cc
|
||||
../optimizer/parallel/operator_info.cc
|
||||
../optimizer/parallel/parallel_pass.cc
|
||||
../optimizer/graph/conv1d_inout_adjust_pass.cc
|
||||
../optimizer/graph/weight_format_transform_pass.cc
|
||||
../optimizer/graph/weight_format_hardcode_pass.cc
|
||||
|
|
|
@ -645,6 +645,13 @@ bool IsSqueezeNode(const BaseRef &n) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsConcatNode(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimConcat);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CheckIsAllInputsParam(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
|
|
|
@ -94,6 +94,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node);
|
|||
|
||||
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
||||
|
||||
bool IsConcatNode(const BaseRef &n);
|
||||
|
||||
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
|
||||
|
||||
tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node);
|
||||
|
|
|
@ -0,0 +1,153 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "tools/optimizer/fisson/eliminate_concat_split.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/core/ops/split_with_overlap.h"
|
||||
#include "mindspore/core/ops/concat.h"
|
||||
#include "mindspore/core/base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
namespace opt {
|
||||
|
||||
// Pattern: a SplitWithOverlap node whose captured input is a Concat node.
// Matches the Concat -> SplitWithOverlap pair this pass eliminates.
const BaseRef EliminateConcatSplit::DefinePattern() const {
  auto concat_var = std::make_shared<CondVar>(IsConcatNode);
  auto split_prim = std::make_shared<ops::SplitWithOverlap>();

  return VectorRef({split_prim, concat_var});
}
|
||||
|
||||
// Walks backwards through MakeTuple / TupleGetItem wrappers until a real
// computational CNode is reached; returns nullptr when none can be found.
CNodePtr GetRealPrevCNode(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  auto current = node->cast<CNodePtr>();

  if (IsRealCNodeKernel(current)) {
    return current;
  }

  auto prim_input = current->input(0);
  if (IsPrimitive(prim_input, prim::kPrimMakeTuple)) {
    auto wrapped = current->input(1);
    if (wrapped == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return nullptr;
    }
    return GetRealPrevCNode(wrapped);
  }
  if (IsPrimitive(prim_input, prim::kPrimTupleGetItem)) {
    return GetRealPrevCNode(current->input(1));
  }
  return nullptr;
}
|
||||
|
||||
void ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
auto pre_cnode = GetRealPrevCNode(cnode->input(1));
|
||||
CheckIfCNodeIsNull(pre_cnode);
|
||||
|
||||
if (!CheckPrimitiveType(pre_cnode, prim::kPrimConcat)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto finder = g_graph_nodes_output.find(pre_cnode->fullname_with_scope());
|
||||
if (finder == g_graph_nodes_output.end()) {
|
||||
return;
|
||||
}
|
||||
if (finder->second.size() > 1) return;
|
||||
|
||||
size_t pre_inputs_size = pre_cnode->inputs().size();
|
||||
int pre_inputs_node_size = pre_inputs_size - 1;
|
||||
auto pre_prim = GetValueNode<std::shared_ptr<ops::Concat>>(pre_cnode->input(kAnfPrimitiveIndex));
|
||||
auto prim = GetValueNode<std::shared_ptr<ops::SplitWithOverlap>>(cnode->input(kAnfPrimitiveIndex));
|
||||
|
||||
if (prim->get_number_split() != pre_inputs_node_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
// check axis NHWC
|
||||
// only support axis "N" now, other axes will support when having "InferShape"
|
||||
if (pre_prim->get_axis() != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// get inputs node
|
||||
auto it = g_graph_nodes_output.find(cnode->fullname_with_scope());
|
||||
if (it == g_graph_nodes_output.end()) {
|
||||
return;
|
||||
}
|
||||
int out_num = it->second.size();
|
||||
if (out_num != prim->get_number_split()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> inputs_node;
|
||||
for (int i = 0; i < out_num; i++) {
|
||||
auto tmp = it->second[i];
|
||||
auto tmp_cnode = tmp->cast<CNodePtr>();
|
||||
if (CheckIfCNodeIsNull(tmp_cnode) != lite::RET_OK) {
|
||||
return;
|
||||
}
|
||||
if (!CheckPrimitiveType(tmp_cnode, prim::kPrimTupleGetItem)) {
|
||||
return;
|
||||
}
|
||||
auto tmp_it = g_graph_nodes_output.find(tmp_cnode->fullname_with_scope());
|
||||
if (tmp_it == g_graph_nodes_output.end()) {
|
||||
return;
|
||||
}
|
||||
if (tmp_it->second.size() != 1) return;
|
||||
|
||||
auto next = tmp_it->second[0];
|
||||
auto next_cnode = next->cast<CNodePtr>();
|
||||
|
||||
inputs_node.push_back(next_cnode);
|
||||
}
|
||||
// replace inputs
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return;
|
||||
}
|
||||
for (size_t i = 1; i < pre_inputs_size; i++) {
|
||||
(void)manager->Replace((inputs_node[i - 1])->input(1), pre_cnode->input(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point invoked for each pattern match; validates the arguments and
// delegates the actual graph rewrite to ConcatSplitEliminate.
const AnfNodePtr EliminateConcatSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
  MS_LOG(DEBUG) << "Enter EliminateConcatSplit pass process";
  // Bail out on any null argument before touching the graph.
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    return nullptr;
  }
  auto matched_split = node->cast<CNodePtr>();
  if (CheckIfCNodeIsNull(matched_split) != lite::RET_OK) {
    return nullptr;
  }
  // The elimination rewires consumers in place; the matched node itself is
  // always handed back unchanged.
  ConcatSplitEliminate(func_graph, matched_split);

  return node;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Pattern pass that removes a redundant Concat -> SplitWithOverlap pair by
// rewiring the split's consumers directly to the concat's inputs.
class EliminateConcatSplit : public PatternProcessPass {
 public:
  explicit EliminateConcatSplit(bool multigraph = true) : PatternProcessPass("eliminate_concat_split", multigraph) {}
  ~EliminateConcatSplit() override = default;
  // Matches SplitWithOverlap(Concat(...)).
  const BaseRef DefinePattern() const override;
  // Performs the elimination; returns the matched node unchanged.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_
|
|
@ -62,7 +62,6 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
return nullptr;
|
||||
}
|
||||
auto conv_node = pre_node->cast<CNodePtr>();
|
||||
MS_ASSERT(primitive_c);
|
||||
if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) ||
|
||||
CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
|
||||
|
|
|
@ -63,7 +63,6 @@ const AnfNodePtr PoolingActivationFusion::Process(const FuncGraphPtr &func_graph
|
|||
}
|
||||
auto pooling_node = pre_node->cast<CNodePtr>();
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(pooling_node->input(0));
|
||||
MS_ASSERT(primitive_c);
|
||||
|
||||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c));
|
||||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c);
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/parallel/dynamic_creator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// operator register
|
||||
|
||||
// Maps a primitive name to its distributed-operator class name by stripping a
// single leading '_' (if present) and appending the "Info" suffix.
std::string GetDisOpName(const std::string &prim_name) {
  const bool has_underscore_prefix = !prim_name.empty() && prim_name.front() == '_';
  const std::string base = has_underscore_prefix ? prim_name.substr(1) : prim_name;
  return base + "Info";
}
|
||||
|
||||
// create the OperatorInfo instance
|
||||
// create the OperatorInfo instance
// Instantiates the OperatorInfo registered for `type_name`, renaming it to
// `orig_name` (the original cnode name). Returns nullptr when no creator is
// registered for the derived distributed-operator name.
OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name,
                                 const SplitStrategy &strategy) {
  if (type_name.empty()) {
    MS_LOG(EXCEPTION) << "Length of name is zero!";
  }
  std::string distribute_opname = GetDisOpName(type_name);
  // Fix: Create() already returns OperatorInfoPtr; the former C-style cast on
  // the smart pointer was redundant (and C-style casts are disallowed anyway).
  OperatorInfoPtr operator_ = DynCreator::Instance().Create(distribute_opname, strategy);
  if (operator_ == nullptr) {
    MS_LOG(INFO) << "Create " << type_name << " failed";
    return nullptr;
  }
  // Capture the registered name before overwriting it for logging purposes.
  std::string origin_name = operator_->name();
  operator_->set_name(orig_name);
  MS_LOG(INFO) << "Successfully created operator " << origin_name;
  return operator_;
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,80 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "tools/optimizer/parallel/operator_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
#define REGISTER(className) \
|
||||
OperatorInfoPtr objectCreator##className(std::string name, SplitStrategy strategy) { \
|
||||
return std::make_shared<className>(name, strategy); \
|
||||
} \
|
||||
RegisterAction className##Register(#className, (CreatFn)objectCreator##className);
|
||||
|
||||
typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const SplitStrategy &strategy);
|
||||
|
||||
// Registry singleton mapping distributed-operator class names to their
// creator functions; populated at static-init time via RegisterAction.
class DynCreator {
 public:
  ~DynCreator() = default;

  // create static singleton dyn_creator instance
  static DynCreator &Instance() {
    static DynCreator fac = DynCreator();
    return fac;
  }
  // register a creator under `name`; note std::map::insert keeps the first
  // entry, so duplicate registrations are silently ignored
  void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
  // creator: instantiate the operator registered under `name`, or nullptr
  // when the name is unknown
  OperatorInfoPtr Create(const std::string &name, const SplitStrategy &strategy) {
    auto iter = Function_map_.find(name);
    if (iter == Function_map_.end()) {
      MS_LOG(INFO) << name << " is not register yet";
      return nullptr;
    }
    return iter->second(name, strategy);
  }

 private:
  DynCreator() = default;
  std::map<std::string, CreatFn> Function_map_;
};
|
||||
|
||||
// Helper whose construction registers a creator with DynCreator; the REGISTER
// macro instantiates one per operator so registration happens at
// static-initialization time.
class RegisterAction {
 public:
  RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
    DynCreator::Instance().Register(name, creatfn);
  }
  ~RegisterAction() = default;

 private:
  // Retained only for bookkeeping; not read after construction.
  std::string name_;
};
|
||||
|
||||
OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name,
|
||||
const SplitStrategy &strategy);
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_
|
|
@ -0,0 +1,213 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/parallel/operator_info.h"
|
||||
#include <algorithm>
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "mindspore/core/ops/concat.h"
|
||||
#include "mindspore/core/ops/addn.h"
|
||||
#include "mindspore/core/ops/split.h"
|
||||
#include "include/lite_types.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool is_any_none(const std::vector<int64_t> &split) {
|
||||
return std::any_of(split.begin(), split.end(), [](int64_t v) { return v == static_cast<int64_t>(NoSplit); });
|
||||
}
|
||||
|
||||
bool is_any_not_none(const std::vector<int64_t> &split) {
|
||||
return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
|
||||
}
|
||||
|
||||
// Tags each parallel output branch with the device type configured in the
// split strategy so the scheduler can place each branch on its device.
// Returns RET_ERROR on an unrecognized device-type string.
lite::STATUS OperatorInfo::SetCNodeBackend() {
  for (size_t i = 0; i < strategy_.dev_num; ++i) {
    lite::DeviceType dt_type;
    std::string type = strategy_.dev_types[i];
    // assumes parallel_output_nodes_[i] is a wrapper cnode whose input(1) is
    // the device-specific compute cnode — TODO confirm against InferParallelCNodes
    auto cnode = parallel_output_nodes_[i]->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (type == "GPU") {
      dt_type = lite::DeviceType::DT_GPU;
    } else if (type == "CPU") {
      dt_type = lite::DeviceType::DT_CPU;
    } else if (type == "NPU") {
      dt_type = lite::DeviceType::DT_NPU;
    } else {
      MS_LOG(ERROR) << "SetCnodeBackend: unknown device type.";
      return lite::RET_ERROR;
    }
    // Record the placement as an integer attribute on the compute cnode.
    cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(dt_type)));
  }
  return lite::RET_OK;
}
|
||||
|
||||
// Validates the configured split strategy: every per-axis split vector must
// have exactly dev_num entries and be either all-NoSplit or all-split.
lite::STATUS OperatorInfo::CheckStrategyValue() {
  auto strategy_size = strategy_.strategys.size();

  for (size_t index = 0; index < strategy_size; ++index) {
    auto strategy = strategy_.strategys[index];
    for (const auto &s : strategy) {
      // One split ratio is required per target device.
      if (s.size() != IntToSize(strategy_.dev_num)) {
        MS_LOG(ERROR) << "Strategy split number:" << s.size()
                      << " is not equal to device number: " << strategy_.dev_num;
        return lite::RET_ERROR;
      }
      // Mixing split and non-split entries within one axis is ambiguous.
      if (is_any_not_none(s) && is_any_none(s)) {
        MS_LOG(ERROR) << "Strategy split number must be all zero or all non-zero: " << s;
        return lite::RET_ERROR;
      }
    }
  }
  return lite::RET_OK;
}
|
||||
|
||||
// Wraps `node` (which produces `output_num` results) in TupleGetItem nodes,
// appending each wrapper to `outputs`, and rebuilds the node's abstract as a
// tuple. Returns RET_ERROR when `node` is not a CNode or a wrapper cannot be
// created.
lite::STATUS OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
                                                          std::vector<AnfNodePtr> *outputs) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(outputs);
  AbstractBasePtrList ptr_list;
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to get CNode.";
    return lite::RET_ERROR;
  }

  for (size_t i = 0; i < output_num; ++i) {
    // Index value node selecting the i-th output.
    auto idx = NewValueNode(SizeToInt(i));
    MS_ASSERT(idx);
    auto index = std::make_shared<Int32Imm>(SizeToInt(i));
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<lite::TupleGetItem>()), node, idx});
    if (tuple_getitem == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
      return lite::RET_ERROR;
    }
    tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
    outputs->push_back(tuple_getitem);
    ptr_list.push_back(abstract_scalar);
  }
  // NOTE(review): the scalar index abstracts stand in for the real output
  // abstracts here — presumably a later InferShape pass rebuilds them; confirm.
  node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
  return lite::RET_OK;
}
|
||||
|
||||
// Builds a Split cnode over input `input_index` of `orig_node` and populates
// `split_outputs` with one TupleGetItem wrapper per split branch. Returns the
// split cnode, or nullptr on failure. `trans_format` is currently unused.
AnfNodePtr OperatorInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
                                              std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
                                              size_t split_num, const std::vector<int64_t> &splits,
                                              bool trans_format) {
  MS_EXCEPTION_IF_NULL(orig_node);

  auto split_prim = std::make_shared<ops::Split>();
  split_prim->set_output_num(split_num);
  split_prim->set_size_splits(splits);
  split_prim->set_axis(split_dim);
  auto value_node = NewValueNode(split_prim);
  std::vector<AnfNodePtr> split_inputs = {value_node};
  // +1 skips the primitive value-node sitting at input(0).
  split_inputs.push_back(orig_node->input(input_index + 1));
  auto split_cnode = func_graph_->NewCNode(split_inputs);
  if (split_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create split node.";
    return nullptr;
  }
  split_cnode->set_fullname_with_scope("Split_" + name_);
  // Bug fix: the status of output creation was previously ignored, so a
  // partially built split could be returned as if it had succeeded.
  if (CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs) != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << " : Failed to create outputs of split node.";
    return nullptr;
  }

  return split_cnode;
}
|
||||
|
||||
// Builds a Concat cnode joining the parallel branches along `concat_dim`,
// taking input(1) of each node in `input_nodes` as the branch outputs.
// Returns nullptr on failure. `trans_format` is currently unused.
// (Name keeps the historical "Concate" spelling; callers depend on it.)
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                           int32_t concat_dim, size_t input_nodes_num, bool trans_format) {
  MS_EXCEPTION_IF_NULL(orig_node);

  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number.";
    return nullptr;
  }
  auto concat_prim = std::make_shared<ops::Concat>();
  concat_prim->set_axis(concat_dim);
  auto value_node = NewValueNode(concat_prim);
  std::vector<AnfNodePtr> concat_inputs = {value_node};
  // Each parallel branch is a wrapper cnode; its real output sits at input(1).
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto concat_cnode = func_graph_->NewCNode(concat_inputs);
  if (concat_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create concat node.";
    return nullptr;
  }
  concat_cnode->set_fullname_with_scope("Concat_" + name_);
  concat_cnode->set_scope(orig_node->scope());

  return concat_cnode;
}
|
||||
|
||||
// Builds an AddN cnode that element-wise sums input(1) of each node in
// `input_nodes`, reducing the partial results of the parallel branches.
// Returns nullptr on failure. `reduce_dim` and `trans_format` are currently
// unused.
AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                          int32_t reduce_dim, size_t input_nodes_num, bool trans_format) {
  MS_EXCEPTION_IF_NULL(orig_node);

  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number.";
    return nullptr;
  }
  // addup inputs element-wise
  auto addn_prim = std::make_shared<ops::AddN>();
  auto value_node = NewValueNode(addn_prim);
  std::vector<AnfNodePtr> addn_inputs = {value_node};
  // Each parallel branch is a wrapper cnode; its real output sits at input(1).
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto addn_cnode = func_graph_->NewCNode(addn_inputs);
  if (addn_cnode == nullptr) {
    // Bug fix: the message previously said "concat node" in this AddN path,
    // which was copy-pasted from CreateConcateNode and misleads debugging.
    MS_LOG(ERROR) << name_ << " : Failed to create addn node.";
    return nullptr;
  }
  addn_cnode->set_fullname_with_scope("AddN_" + name_);
  addn_cnode->set_scope(orig_node->scope());

  return addn_cnode;
}
|
||||
|
||||
// Template method running the full split pipeline in order: parse operator
// attrs, validate the strategy (generic then operator-specific), build the
// parallel cnodes, tag device backends, and build the replacement op.
// Stops and returns RET_ERROR at the first failing stage.
lite::STATUS OperatorInfo::Init() {
  if (GetAttrs() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Parse attrs failed.";
    return lite::RET_ERROR;
  }
  if (CheckStrategyValue() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
    return lite::RET_ERROR;
  }
  // Operator-specific validation implemented by the subclass.
  if (CheckStrategy(strategy_) != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Check strategys failed.";
    return lite::RET_ERROR;
  }
  if (InferParallelCNodes() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferReplaceGraph failed.";
    return lite::RET_ERROR;
  }
  if (SetCNodeBackend() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": SetCnodeBackend failed.";
    return lite::RET_ERROR;
  }
  if (InferReplaceOp() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferForwardOps failed.";
    return lite::RET_ERROR;
  }

  return lite::RET_OK;
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
/**
|
||||
* Do following steps to make a operator support parallel:
|
||||
*
|
||||
* 1.Add the schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST;
|
||||
* 2.Add a pair of type and string name to ParallelPass::type_string;
|
||||
* 3.Implement a class XXXInfo whose parent is OperatorInfo;
|
||||
* 3.1.Override CheckStrategy(), InferParallelCNodes() and InferReplaceOp()
|
||||
* 4.include header file of XXXInfo in ops_info_head_files.h
|
||||
* 5.REGISTER XXXInfo in dynamic_creator.cc
|
||||
*/
|
||||
using schema::ReduceMode;
|
||||
|
||||
class OperatorInfo;
|
||||
using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
|
||||
|
||||
// Base class for per-operator parallel split logic. Subclasses implement the
// pure-virtual hooks and are instantiated through DynCreator/OperatorInstance.
class OperatorInfo {
 public:
  OperatorInfo(std::string name, SplitStrategy strategy)
      : name_(std::move(name)),
        strategy_(std::move(strategy)),
        replace_op_(nullptr),
        func_graph_(nullptr),
        cnode_(nullptr) {}
  virtual ~OperatorInfo() = default;
  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }
  void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
  void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
  void setFmk(const int32_t FmkType) { FmkType_ = FmkType; }
  // Root of the parallel replacement subgraph built by Init(); may be null.
  AnfNodePtr replace_op() { return replace_op_; }
  // Runs the split pipeline: GetAttrs -> strategy checks -> parallel cnodes ->
  // backend tagging -> replacement op. See the .cc for details.
  lite::STATUS Init();

 protected:
  // Wraps a multi-output node in TupleGetItem nodes collected into *outputs.
  lite::STATUS CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
                                              std::vector<AnfNodePtr> *outputs);
  // Splits one input of orig_node into split_num branches along split_dim.
  AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
                                  size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
                                  bool trans_format);
  // Joins parallel branch outputs with a Concat along concat_dim.
  AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                               int32_t concat_dim, size_t input_nodes_num, bool trans_format);
  // Sums parallel branch outputs element-wise with an AddN.
  AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
                              size_t input_nodes_num, bool trans_format);
  // Subclass hooks, invoked in this order by Init().
  virtual lite::STATUS GetAttrs() = 0;
  virtual lite::STATUS InferReplaceOp() = 0;
  virtual lite::STATUS InferParallelCNodes() = 0;
  virtual lite::STATUS CheckStrategy(const SplitStrategy &strategy) = 0;

  std::string name_;
  SplitStrategy strategy_;
  AnfNodePtr replace_op_;
  // One output node per device branch, indexed by device number.
  std::vector<AnfNodePtr> parallel_output_nodes_;
  FuncGraphPtr func_graph_;
  CNodePtr cnode_;
  int32_t FmkType_{};

 private:
  lite::STATUS SetCNodeBackend();
  lite::STATUS CheckStrategyValue();
};
|
||||
|
||||
bool is_any_none(const std::vector<int64_t> &split);
|
||||
bool is_any_not_none(const std::vector<int64_t> &split);
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/parallel/parallel_pass.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Returns true when `node` is one of the operator types the pass can split.
// Side effect: on a match, stores the matched type's string key in type_name_,
// which Run() reads afterwards.
bool ParallelPass::IsParallelCareNode(const AnfNodePtr &node) {
  return std::any_of(PARALLEL_LIST.begin(), PARALLEL_LIST.end(), [this, &node](auto &prim) {
    if (CheckPrimitiveType(node, prim)) {
      type_name_ = PrimToString(prim);
      return true;
    } else {
      return false;
    }
  });
}
|
||||
|
||||
std::string ParallelPass::PrimToString(const PrimitivePtr &prim) {
|
||||
if (type_string.find(prim->name()) == type_string.end()) {
|
||||
MS_LOG(EXCEPTION) << "String of the type not registered";
|
||||
}
|
||||
return type_string.at(prim->name());
|
||||
}
|
||||
|
||||
// Pass entry point: when `node` is a supported operator with a registered
// split strategy, builds the parallel replacement subgraph and returns its
// root; returns nullptr when the node is left untouched. Raises on creation
// or init failures.
AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    return nullptr;
  }
  if (!utils::isa<CNode>(node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) {
    return nullptr;
  }
  // Side effect: sets type_name_ when the node type is supported.
  if (!IsParallelCareNode(node)) {
    return nullptr;
  }
  std::string cnode_name = cnode->fullname_with_scope();
  // Nodes produced by an earlier run carry the suffix; skip them up front so
  // the same operator is never split twice.
  if (cnode_name.find(PARALLEL_NAME_SUFFIX) != std::string::npos) {
    MS_LOG(DEBUG) << " : Skip splited cnode " << cnode_name;
    return nullptr;
  }
  std::string orig_name = cnode_name;
  // find operator name first, then operator type name.
  std::string name = cnode_name;
  if (split_strategys_.find(name) == split_strategys_.end()) {
    name = type_name_;
  }
  MS_LOG(DEBUG) << " : Reached a parallel care node: " << cnode_name;
  if (split_strategys_.find(name) == split_strategys_.end()) {
    MS_LOG(DEBUG) << name << " : No split strategy for the current CNode.";
    return nullptr;
  }
  // Mark the node as processed before building the replacement.
  cnode->set_fullname_with_scope(cnode_name + PARALLEL_NAME_SUFFIX);
  OperatorInfoPtr operator_ = OperatorInstance(type_name_, orig_name, split_strategys_[name]);
  if (operator_ == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: Create " << name << " OperatorInstance failed";
  }
  operator_->set_cnode(cnode);
  operator_->set_func_graph(func_graph);
  operator_->setFmk(FmkType_);
  // Bug fix: compare against RET_OK so ANY non-OK status is treated as a
  // failure, not only the single RET_ERROR code.
  if (operator_->Init() != lite::RET_OK) {
    MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed";
  }
  return operator_->replace_op();
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "tools/optimizer/parallel/dynamic_creator.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h"
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ParallelPass : public opt::NodePass {
|
||||
public:
|
||||
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> strategys, const int32_t FmkType)
|
||||
: NodePass("parallel_pass"), split_strategys_(strategys), FmkType_(FmkType) {}
|
||||
~ParallelPass() override = default;
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
const std::set<PrimitivePtr> PARALLEL_LIST = {prim::kPrimConv2DFusion};
|
||||
const std::unordered_map<std::string, std::string> type_string = {{prim::kPrimConv2DFusion->name(), "Conv2D"}};
|
||||
|
||||
bool IsParallelCareNode(const AnfNodePtr &node);
|
||||
std::string PrimToString(const PrimitivePtr &prim);
|
||||
|
||||
std::string type_name_;
|
||||
std::unordered_map<std::string, SplitStrategy> split_strategys_;
|
||||
int32_t FmkType_;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
|
|
@ -38,12 +38,6 @@ const std::vector<std::string> kSplitDevTypes = {"CPU", "GPU"};
|
|||
|
||||
using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
|
||||
|
||||
enum Status {
|
||||
SUCCESS = 0,
|
||||
FAILED,
|
||||
INVALID_ARGUMENT,
|
||||
};
|
||||
|
||||
enum SplitMode {
|
||||
SplitN = 0,
|
||||
SplitH = 1,
|
||||
|
|
Loading…
Reference in New Issue