!15933 add single conv split

From: @zhujingxuan
Reviewed-by: @wangchengyuan,@hangangqiang
Signed-off-by: @wangchengyuan
This commit is contained in:
mindspore-ci-bot 2021-04-30 16:02:58 +08:00 committed by Gitee
commit 55b15079d1
14 changed files with 785 additions and 8 deletions

View File

@ -59,9 +59,13 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tf_gelu_fusion.cc
../optimizer/fusion/onnx_gelu_fusion.cc
../optimizer/fusion/squeeze_fusion.cc
../optimizer/fisson/eliminate_concat_split.cc
../optimizer/fisson/fisson_util.cc
../optimizer/fisson/iter_node_outputs.cc
../optimizer/fisson/node_out_shapes.cc
../optimizer/parallel/dynamic_creator.cc
../optimizer/parallel/operator_info.cc
../optimizer/parallel/parallel_pass.cc
../optimizer/graph/conv1d_inout_adjust_pass.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc

View File

@ -645,6 +645,13 @@ bool IsSqueezeNode(const BaseRef &n) {
return false;
}
// Returns true when `n` is an AnfNode whose primitive is Concat.
bool IsConcatNode(const BaseRef &n) {
  if (!utils::isa<AnfNodePtr>(n)) {
    return false;
  }
  return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimConcat);
}
bool CheckIsAllInputsParam(const AnfNodePtr &node) {
if (node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);

View File

@ -94,6 +94,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node);
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
bool IsConcatNode(const BaseRef &n);
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node);

View File

@ -0,0 +1,153 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/optimizer/fisson/eliminate_concat_split.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ops/split_with_overlap.h"
#include "mindspore/core/ops/concat.h"
#include "mindspore/core/base/core_ops.h"
namespace mindspore {
namespace opt {
// Matches a SplitWithOverlap node whose first input is produced by Concat.
const BaseRef EliminateConcatSplit::DefinePattern() const {
  auto concat_input = std::make_shared<CondVar>(IsConcatNode);
  auto split_with_overlap = std::make_shared<ops::SplitWithOverlap>();
  return VectorRef({split_with_overlap, concat_input});
}
// Walks backwards through MakeTuple / TupleGetItem wrappers starting at
// `node` until a real kernel cnode is found; returns nullptr if the chain
// ends anywhere else or a null node is encountered.
CNodePtr GetRealPrevCNode(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (IsRealCNodeKernel(cnode)) {
    return cnode;
  }
  auto first_input = cnode->input(0);
  if (IsPrimitive(first_input, prim::kPrimTupleGetItem)) {
    // Skip over the TupleGetItem and keep walking from its data input.
    return GetRealPrevCNode(cnode->input(1));
  }
  if (!IsPrimitive(first_input, prim::kPrimMakeTuple)) {
    return nullptr;
  }
  auto tuple_element = cnode->input(1);
  if (tuple_element == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  return GetRealPrevCNode(tuple_element);
}
// Removes a Concat immediately followed by a SplitWithOverlap when the split
// exactly undoes the concat (same count, axis 0, single consumer): each
// consumer of a split output is rewired to the matching concat input.
void ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  auto pre_cnode = GetRealPrevCNode(cnode->input(1));
  // Bug fix: the null-check result was ignored, so a null pre_cnode was
  // dereferenced by CheckPrimitiveType below.
  if (CheckIfCNodeIsNull(pre_cnode) != lite::RET_OK) {
    return;
  }
  if (!CheckPrimitiveType(pre_cnode, prim::kPrimConcat)) {
    return;
  }
  auto finder = g_graph_nodes_output.find(pre_cnode->fullname_with_scope());
  if (finder == g_graph_nodes_output.end()) {
    return;
  }
  // The concat result must feed only the split, otherwise it is still needed.
  if (finder->second.size() > 1) {
    return;
  }
  size_t pre_inputs_size = pre_cnode->inputs().size();
  int pre_inputs_node_size = static_cast<int>(pre_inputs_size) - 1;
  auto pre_prim = GetValueNode<std::shared_ptr<ops::Concat>>(pre_cnode->input(kAnfPrimitiveIndex));
  auto prim = GetValueNode<std::shared_ptr<ops::SplitWithOverlap>>(cnode->input(kAnfPrimitiveIndex));
  if (prim->get_number_split() != pre_inputs_node_size) {
    return;
  }
  // check axis NHWC
  // only support axis "N" now; other axes will be supported once "InferShape" is available
  if (pre_prim->get_axis() != 0) {
    return;
  }
  // Collect the TupleGetItem consumers of the split, one per split output.
  auto it = g_graph_nodes_output.find(cnode->fullname_with_scope());
  if (it == g_graph_nodes_output.end()) {
    return;
  }
  int out_num = static_cast<int>(it->second.size());
  if (out_num != prim->get_number_split()) {
    return;
  }
  std::vector<CNodePtr> inputs_node;
  for (int i = 0; i < out_num; i++) {
    auto tmp = it->second[i];
    auto tmp_cnode = tmp->cast<CNodePtr>();
    if (CheckIfCNodeIsNull(tmp_cnode) != lite::RET_OK) {
      return;
    }
    if (!CheckPrimitiveType(tmp_cnode, prim::kPrimTupleGetItem)) {
      return;
    }
    auto tmp_it = g_graph_nodes_output.find(tmp_cnode->fullname_with_scope());
    if (tmp_it == g_graph_nodes_output.end()) {
      return;
    }
    // Each split output must have exactly one consumer to rewire.
    if (tmp_it->second.size() != 1) {
      return;
    }
    auto next = tmp_it->second[0];
    auto next_cnode = next->cast<CNodePtr>();
    // Bug fix: a non-CNode consumer would previously be pushed as nullptr
    // and dereferenced in the replacement loop below.
    if (CheckIfCNodeIsNull(next_cnode) != lite::RET_OK) {
      return;
    }
    inputs_node.push_back(next_cnode);
  }
  // Rewire: the i-th consumer's input becomes the i-th concat input.
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return;
  }
  for (size_t i = 1; i < pre_inputs_size; i++) {
    (void)manager->Replace((inputs_node[i - 1])->input(1), pre_cnode->input(i));
  }
}
// Pass entry point: validates the matched node and delegates the actual
// concat/split elimination to ConcatSplitEliminate. Always returns the
// original node (the rewrite happens through the graph manager).
const AnfNodePtr EliminateConcatSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
  MS_LOG(DEBUG) << "Enter EliminateConcatSplit pass process";
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    return nullptr;
  }
  auto split_cnode = node->cast<CNodePtr>();
  if (CheckIfCNodeIsNull(split_cnode) != lite::RET_OK) {
    return nullptr;
  }
  ConcatSplitEliminate(func_graph, split_cnode);
  return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_
#include "backend/optimizer/common/optimizer.h"
#include "tools/optimizer/fisson/fisson_util.h"
namespace mindspore {
namespace opt {
// Pattern pass that removes a Concat feeding a SplitWithOverlap when the
// split exactly reverses the concat, rewiring consumers straight to the
// original concat inputs (see eliminate_concat_split.cc).
class EliminateConcatSplit : public PatternProcessPass {
 public:
  explicit EliminateConcatSplit(bool multigraph = true) : PatternProcessPass("eliminate_concat_split", multigraph) {}
  ~EliminateConcatSplit() override = default;
  // Matches SplitWithOverlap(Concat(...)).
  const BaseRef DefinePattern() const override;
  // Performs the elimination; returns the matched node unchanged.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_

View File

@ -62,7 +62,6 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
return nullptr;
}
auto conv_node = pre_node->cast<CNodePtr>();
MS_ASSERT(primitive_c);
if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) ||
CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));

View File

@ -63,7 +63,6 @@ const AnfNodePtr PoolingActivationFusion::Process(const FuncGraphPtr &func_graph
}
auto pooling_node = pre_node->cast<CNodePtr>();
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(pooling_node->input(0));
MS_ASSERT(primitive_c);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c);

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/parallel/dynamic_creator.h"
namespace mindspore {
namespace opt {
// operator register
// Maps a primitive name to its registered OperatorInfo class name:
// strips a single leading underscore and appends the "Info" suffix.
std::string GetDisOpName(const std::string &prim_name) {
  const bool has_leading_underscore = !prim_name.empty() && prim_name.front() == '_';
  return (has_leading_underscore ? prim_name.substr(1) : prim_name) + "Info";
}
// create the OperatorInfo instance
// Creates the OperatorInfo instance registered for `type_name` and renames
// it to `orig_name` (the original cnode name). Returns nullptr when no
// creator is registered for the type.
OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name,
                                 const SplitStrategy &strategy) {
  if (type_name.empty()) {
    MS_LOG(EXCEPTION) << "Length of name is zero!";
  }
  std::string distribute_opname = GetDisOpName(type_name);
  // DynCreator::Create already returns an OperatorInfoPtr; the previous
  // C-style cast on a shared_ptr result was redundant (and a named-cast
  // violation).
  OperatorInfoPtr op_info = DynCreator::Instance().Create(distribute_opname, strategy);
  if (op_info == nullptr) {
    MS_LOG(INFO) << "Create " << type_name << " failed";
    return nullptr;
  }
  std::string origin_name = op_info->name();
  op_info->set_name(orig_name);
  MS_LOG(INFO) << "Successfully created operator " << origin_name;
  return op_info;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "tools/optimizer/parallel/operator_info.h"
namespace mindspore {
namespace opt {
#define REGISTER(className) \
OperatorInfoPtr objectCreator##className(std::string name, SplitStrategy strategy) { \
return std::make_shared<className>(name, strategy); \
} \
RegisterAction className##Register(#className, (CreatFn)objectCreator##className);
typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const SplitStrategy &strategy);
class DynCreator {
public:
~DynCreator() = default;
// create static singleton dyn_creator instance
static DynCreator &Instance() {
static DynCreator fac = DynCreator();
return fac;
}
// register
void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
// creator
OperatorInfoPtr Create(const std::string &name, const SplitStrategy &strategy) {
auto iter = Function_map_.find(name);
if (iter == Function_map_.end()) {
MS_LOG(INFO) << name << " is not register yet";
return nullptr;
}
return iter->second(name, strategy);
}
private:
DynCreator() = default;
std::map<std::string, CreatFn> Function_map_;
};
// Performs registration as a side effect of construction; the REGISTER macro
// instantiates one of these at namespace scope so every operator creator is
// registered with DynCreator before main() runs.
class RegisterAction {
 public:
  RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
    DynCreator::Instance().Register(name, creatfn);
  }
  ~RegisterAction() = default;

 private:
  // Kept only for bookkeeping; never read after construction.
  std::string name_;
};
OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name,
const SplitStrategy &strategy);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_

View File

@ -0,0 +1,213 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/parallel/operator_info.h"
#include <algorithm>
#include "tools/converter/ops/ops_def.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "mindspore/core/ops/concat.h"
#include "mindspore/core/ops/addn.h"
#include "mindspore/core/ops/split.h"
#include "include/lite_types.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "base/core_ops.h"
#include "include/errorcode.h"
namespace mindspore {
namespace opt {
// True when at least one entry of `split` is the NoSplit marker.
bool is_any_none(const std::vector<int64_t> &split) {
  auto is_no_split = [](int64_t val) { return val == static_cast<int64_t>(NoSplit); };
  return std::any_of(split.begin(), split.end(), is_no_split);
}
// True when at least one entry of `split` is a real (non-NoSplit) value.
bool is_any_not_none(const std::vector<int64_t> &split) {
  auto is_real_split = [](int64_t val) { return val != static_cast<int64_t>(NoSplit); };
  return std::any_of(split.begin(), split.end(), is_real_split);
}
// Tags each parallel output branch with the device type (CPU/GPU/NPU) from
// the split strategy, by writing a kDeviceType attr on the cnode feeding the
// i-th parallel output node.
lite::STATUS OperatorInfo::SetCNodeBackend() {
  for (size_t i = 0; i < strategy_.dev_num; ++i) {
    lite::DeviceType dt_type;
    std::string type = strategy_.dev_types[i];
    // Bug fix: the original chained ->cast<CNodePtr>()->input(1) and only
    // null-checked the final result, dereferencing a possibly-null pointer
    // when a parallel output node was not a CNode. Validate each step.
    auto output_cnode = parallel_output_nodes_[i]->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(output_cnode);
    auto cnode = output_cnode->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (type == "GPU") {
      dt_type = lite::DeviceType::DT_GPU;
    } else if (type == "CPU") {
      dt_type = lite::DeviceType::DT_CPU;
    } else if (type == "NPU") {
      dt_type = lite::DeviceType::DT_NPU;
    } else {
      MS_LOG(ERROR) << "SetCnodeBackend: unknown device type.";
      return lite::RET_ERROR;
    }
    cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(dt_type)));
  }
  return lite::RET_OK;
}
// Validates the configured split strategy: every per-input strategy must
// list exactly one entry per device, and its entries must be uniformly zero
// (no split) or uniformly non-zero.
lite::STATUS OperatorInfo::CheckStrategyValue() {
  for (const auto &strategy : strategy_.strategys) {
    for (const auto &s : strategy) {
      if (s.size() != IntToSize(strategy_.dev_num)) {
        MS_LOG(ERROR) << "Strategy split number:" << s.size()
                      << " is not equal to device number: " << strategy_.dev_num;
        return lite::RET_ERROR;
      }
      // Mixing split and no-split entries within one dimension is invalid.
      if (is_any_not_none(s) && is_any_none(s)) {
        MS_LOG(ERROR) << "Strategy split number must be all zero or all non-zero: " << s;
        return lite::RET_ERROR;
      }
    }
  }
  return lite::RET_OK;
}
// Creates `output_num` TupleGetItem cnodes, each selecting one output of
// `node`, appends them to `outputs`, and installs an AbstractTuple on `node`.
// NOTE(review): the abstract installed on `node` is built from the index
// scalars, not from the real output tensors — presumably refined later by
// shape inference; confirm against the callers.
lite::STATUS OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
                                                          std::vector<AnfNodePtr> *outputs) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(outputs);
  AbstractBasePtrList ptr_list;
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to get CNode.";
    return lite::RET_ERROR;
  }
  for (size_t i = 0; i < output_num; ++i) {
    // The output index is carried both as the value-node input of
    // TupleGetItem and in its abstract.
    auto idx = NewValueNode(SizeToInt(i));
    MS_ASSERT(idx);
    auto index = std::make_shared<Int32Imm>(SizeToInt(i));
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<lite::TupleGetItem>()), node, idx});
    if (tuple_getitem == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
      return lite::RET_ERROR;
    }
    tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
    outputs->push_back(tuple_getitem);
    ptr_list.push_back(abstract_scalar);
  }
  node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
  return lite::RET_OK;
}
// Builds a Split cnode over input `input_index` of `orig_node` (splitting
// along `split_dim` into `split_num` pieces sized by `splits`) and creates
// one TupleGetItem per piece in `split_outputs`. Returns the split cnode, or
// nullptr on failure.
// NOTE(review): trans_format is accepted but never read here — presumably
// reserved for layout handling; confirm before relying on it.
AnfNodePtr OperatorInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
                                              std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
                                              size_t split_num, const std::vector<int64_t> &splits,
                                              bool trans_format) {
  MS_EXCEPTION_IF_NULL(orig_node);
  auto split_prim = std::make_shared<ops::Split>();
  split_prim->set_output_num(split_num);
  split_prim->set_size_splits(splits);
  split_prim->set_axis(split_dim);
  auto value_node = NewValueNode(split_prim);
  std::vector<AnfNodePtr> split_inputs = {value_node};
  // +1 skips the primitive input of the original cnode.
  split_inputs.push_back(orig_node->input(input_index + 1));
  auto split_cnode = func_graph_->NewCNode(split_inputs);
  if (split_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create split node.";
    return nullptr;
  }
  split_cnode->set_fullname_with_scope("Split_" + name_);
  // Bug fix: the status was previously ignored, so a failure left
  // split_outputs half-filled while the split node was returned as a success.
  if (CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs) != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << " : Failed to create outputs of split node.";
    return nullptr;
  }
  return split_cnode;
}
// Builds a Concat cnode over the first real input (input(1)) of each node in
// `input_nodes`, concatenating along `concat_dim`. Returns nullptr on any
// mismatch or creation failure.
// NOTE(review): trans_format is accepted but never read here — presumably
// reserved for layout handling; confirm before relying on it.
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                           int32_t concat_dim, size_t input_nodes_num, bool trans_format) {
  MS_EXCEPTION_IF_NULL(orig_node);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number.";
    return nullptr;
  }
  auto concat_prim = std::make_shared<ops::Concat>();
  concat_prim->set_axis(concat_dim);
  auto value_node = NewValueNode(concat_prim);
  std::vector<AnfNodePtr> concat_inputs = {value_node};
  // Assumes every entry is a CNode with at least two inputs; a non-CNode
  // entry would make cast<CNodePtr>() return null and crash here — callers
  // must guarantee this.
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto concat_cnode = func_graph_->NewCNode(concat_inputs);
  if (concat_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create concat node.";
    return nullptr;
  }
  concat_cnode->set_fullname_with_scope("Concat_" + name_);
  concat_cnode->set_scope(orig_node->scope());
  return concat_cnode;
}
// Builds an AddN cnode that sums the first real input (input(1)) of each
// node in `input_nodes` element-wise. Returns nullptr on any mismatch or
// creation failure.
// NOTE(review): reduce_dim and trans_format are accepted but never read here;
// confirm before relying on them.
AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                          int32_t reduce_dim, size_t input_nodes_num, bool trans_format) {
  MS_EXCEPTION_IF_NULL(orig_node);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number.";
    return nullptr;
  }
  // addup inputs element-wise
  auto addn_prim = std::make_shared<ops::AddN>();
  auto value_node = NewValueNode(addn_prim);
  std::vector<AnfNodePtr> addn_inputs = {value_node};
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto addn_cnode = func_graph_->NewCNode(addn_inputs);
  if (addn_cnode == nullptr) {
    // Bug fix: the message was copy-pasted from CreateConcateNode and named
    // the wrong node kind ("concat" instead of "addn").
    MS_LOG(ERROR) << name_ << " : Failed to create addn node.";
    return nullptr;
  }
  addn_cnode->set_fullname_with_scope("AddN_" + name_);
  addn_cnode->set_scope(orig_node->scope());
  return addn_cnode;
}
// Template method driving a parallel split: parse attrs, validate the
// strategy, build the parallel cnodes, assign device backends, then build
// the replacement op. Stops at the first failing step.
lite::STATUS OperatorInfo::Init() {
  if (GetAttrs() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Parse attrs failed.";
    return lite::RET_ERROR;
  }
  if (CheckStrategyValue() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
    return lite::RET_ERROR;
  }
  if (CheckStrategy(strategy_) != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Check strategys failed.";
    return lite::RET_ERROR;
  }
  // Bug fix: the next two error messages named methods that do not exist
  // ("InferReplaceGraph", "InferForwardOps"), making failures hard to trace.
  if (InferParallelCNodes() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferParallelCNodes failed.";
    return lite::RET_ERROR;
  }
  if (SetCNodeBackend() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": SetCnodeBackend failed.";
    return lite::RET_ERROR;
  }
  if (InferReplaceOp() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferReplaceOp failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include "tools/optimizer/parallel/split_strategy.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "schema/inner/model_generated.h"
#include "include/context.h"
#include "include/errorcode.h"
namespace mindspore {
namespace opt {
/**
* Do following steps to make a operator support parallel:
*
* 1.Add the schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST;
* 2.Add a pair of type and string name to ParallelPass::type_string;
* 3.Implement a class XXXInfo whose parent is OperatorInfo;
* 3.1.Override CheckStrategy(), InferParallelCNodes() and InferReplaceOp()
* 4.include header file of XXXInfo in ops_info_head_files.h
* 5.REGISTER XXXInfo in dynamic_creator.cc
*/
using schema::ReduceMode;
class OperatorInfo;
using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
// Base class for per-operator parallel split logic. Subclasses implement the
// four pure-virtual hooks; Init() drives them in a fixed order (see the
// steps documented at the top of this file for adding a new operator).
class OperatorInfo {
 public:
  // name: operator identifier; strategy: how to split across devices.
  OperatorInfo(std::string name, SplitStrategy strategy)
      : name_(std::move(name)),
        strategy_(std::move(strategy)),
        replace_op_(nullptr),
        func_graph_(nullptr),
        cnode_(nullptr) {}
  virtual ~OperatorInfo() = default;
  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }
  // Graph and node this operator instance will rewrite; must be set before Init().
  void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
  void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
  void setFmk(const int32_t FmkType) { FmkType_ = FmkType; }
  // The node that replaces the original cnode; null until Init() succeeds.
  AnfNodePtr replace_op() { return replace_op_; }
  // Runs GetAttrs -> CheckStrategyValue -> CheckStrategy ->
  // InferParallelCNodes -> SetCNodeBackend -> InferReplaceOp.
  lite::STATUS Init();

 protected:
  // Creates output_num TupleGetItem nodes selecting each output of `node`.
  lite::STATUS CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
                                              std::vector<AnfNodePtr> *outputs);
  // Builds a Split cnode over one input of orig_node plus its TupleGetItems.
  AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
                                  size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
                                  bool trans_format);
  // Builds a Concat cnode over the first real input of each input node.
  AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                               int32_t concat_dim, size_t input_nodes_num, bool trans_format);
  // Builds an AddN cnode summing the first real input of each input node.
  AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
                              size_t input_nodes_num, bool trans_format);
  // Subclass hooks, called in the order listed by Init().
  virtual lite::STATUS GetAttrs() = 0;
  virtual lite::STATUS InferReplaceOp() = 0;
  virtual lite::STATUS InferParallelCNodes() = 0;
  virtual lite::STATUS CheckStrategy(const SplitStrategy &strategy) = 0;
  std::string name_;
  SplitStrategy strategy_;
  AnfNodePtr replace_op_;
  // One output node per device branch; consumed by SetCNodeBackend().
  std::vector<AnfNodePtr> parallel_output_nodes_;
  FuncGraphPtr func_graph_;
  CNodePtr cnode_;
  int32_t FmkType_{};

 private:
  lite::STATUS SetCNodeBackend();
  lite::STATUS CheckStrategyValue();
};
bool is_any_none(const std::vector<int64_t> &split);
bool is_any_not_none(const std::vector<int64_t> &split);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_

View File

@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/parallel/parallel_pass.h"
#include "include/errorcode.h"
#include "ir/tensor.h"
namespace mindspore {
namespace opt {
bool ParallelPass::IsParallelCareNode(const AnfNodePtr &node) {
return std::any_of(PARALLEL_LIST.begin(), PARALLEL_LIST.end(), [this, &node](auto &prim) {
if (CheckPrimitiveType(node, prim)) {
type_name_ = PrimToString(prim);
return true;
} else {
return false;
}
});
}
// Maps a primitive to its registered type string (e.g. Conv2DFusion ->
// "Conv2D"); throws via MS_LOG(EXCEPTION) for unregistered primitives.
std::string ParallelPass::PrimToString(const PrimitivePtr &prim) {
  // Single lookup instead of the previous find()-then-at() double lookup.
  auto iter = type_string.find(prim->name());
  if (iter == type_string.end()) {
    MS_LOG(EXCEPTION) << "String of the type not registered";
  }
  return iter->second;
}
// Pass entry point: for a cnode that matches PARALLEL_LIST and has a split
// strategy (keyed by full node name, falling back to the type name), builds
// the corresponding OperatorInfo and returns its replacement node. Returns
// nullptr when the node is skipped.
AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    return nullptr;
  }
  if (!utils::isa<CNode>(node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) {
    return nullptr;
  }
  if (!IsParallelCareNode(node)) {
    return nullptr;
  }
  std::string cnode_name = cnode->fullname_with_scope();
  std::string orig_name = cnode_name;
  // find operator name first, then operator type name.
  // (Single find() per key instead of the previous repeated map lookups.)
  std::string name = cnode_name;
  auto strategy_iter = split_strategys_.find(name);
  if (strategy_iter == split_strategys_.end()) {
    name = type_name_;
    strategy_iter = split_strategys_.find(name);
  }
  if (cnode_name.find(PARALLEL_NAME_SUFFIX) != std::string::npos) {
    MS_LOG(DEBUG) << " : Skip splited cnode " << cnode_name;
    return nullptr;
  }
  MS_LOG(DEBUG) << " : Reached a parallel care node: " << cnode_name;
  if (strategy_iter == split_strategys_.end()) {
    MS_LOG(DEBUG) << name << " : No split strategy for the current CNode.";
    return nullptr;
  }
  cnode->set_fullname_with_scope(cnode_name + PARALLEL_NAME_SUFFIX);
  OperatorInfoPtr parallel_operator = OperatorInstance(type_name_, orig_name, strategy_iter->second);
  if (parallel_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: Create " << name << " OperatorInstance failed";
  }
  parallel_operator->set_cnode(cnode);
  parallel_operator->set_func_graph(func_graph);
  parallel_operator->setFmk(FmkType_);
  // Bug fix: the success check was `== RET_ERROR`, which would treat any
  // other non-RET_OK status as success; compare against RET_OK instead.
  if (parallel_operator->Init() != lite::RET_OK) {
    MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed";
  }
  return parallel_operator->replace_op();
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <utility>
#include <set>
#include <string>
#include <unordered_map>
#include "ir/anf.h"
#include "tools/optimizer/parallel/dynamic_creator.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h"
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
namespace mindspore {
namespace opt {
// Node pass that splits supported operators (currently Conv2DFusion) across
// devices according to the provided split strategies.
class ParallelPass : public opt::NodePass {
 public:
  // strategys maps an operator name (or type name) to its split strategy.
  // Taken by value and moved: the previous `const` value parameter forced a
  // second copy into the member and could never be moved from.
  explicit ParallelPass(std::unordered_map<std::string, SplitStrategy> strategys, const int32_t FmkType)
      : NodePass("parallel_pass"), split_strategys_(std::move(strategys)), FmkType_(FmkType) {}
  ~ParallelPass() override = default;
  AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;

 private:
  // Primitives this pass knows how to split.
  const std::set<PrimitivePtr> PARALLEL_LIST = {prim::kPrimConv2DFusion};
  // Primitive name -> strategy-key type name.
  const std::unordered_map<std::string, std::string> type_string = {{prim::kPrimConv2DFusion->name(), "Conv2D"}};
  // Matches node against PARALLEL_LIST and records type_name_ on success.
  bool IsParallelCareNode(const AnfNodePtr &node);
  std::string PrimToString(const PrimitivePtr &prim);
  std::string type_name_;
  std::unordered_map<std::string, SplitStrategy> split_strategys_;
  int32_t FmkType_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_

View File

@ -38,12 +38,6 @@ const std::vector<std::string> kSplitDevTypes = {"CPU", "GPU"};
using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
enum Status {
SUCCESS = 0,
FAILED,
INVALID_ARGUMENT,
};
enum SplitMode {
SplitN = 0,
SplitH = 1,