!16440 adjust single parallel split node pass

From: @zoloft
Reviewed-by: @zhang_xue_tong
Signed-off-by: @zhang_xue_tong
mindspore-ci-bot 2021-05-17 10:08:02 +08:00 committed by Gitee
commit ea99e19b47
14 changed files with 163 additions and 119 deletions

View File

@@ -134,7 +134,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
// 1. deal with split strategy
std::unordered_map<std::string, opt::SplitStrategy> split_strategys = ParserSplitStrategy(opt::SplitH);
std::unordered_map<std::string, opt::SplitStrategy> split_strategys = opt::ParserSplitStrategy();
if (split_strategys.empty()) {
MS_LOG(ERROR) << "parse split_strategy error.";
return RET_OK;
@@ -144,9 +144,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
// 3. multi_conv parallel pass
auto strategy = split_strategys.begin()->second;
parallel_pm->AddPass(
std::make_shared<opt::MultiConvSplitPass>(strategy, schema::PrimitiveType_Conv2DFusion, config->fmk, 3));
parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, config->fmk, 3));
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
// 4. single conv parallel pass
parallel_pm->AddPass(std::make_shared<opt::ParallelPass>(split_strategys, config->fmk));
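Taken together, the hunks above leave RunParallelPass with the following shape (a consolidated sketch of the post-commit code; parallel_pm, config, and the surrounding error handling come from parts of the function this diff does not show):

std::unordered_map<std::string, opt::SplitStrategy> split_strategys = opt::ParserSplitStrategy();
if (split_strategys.empty()) {
  MS_LOG(ERROR) << "parse split_strategy error.";
  return RET_OK;
}
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
// The multi-conv pass now takes the whole strategy map; the per-call
// SplitStrategy and hard-coded PrimitiveType_Conv2DFusion are gone.
parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, config->fmk, 3));
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
parallel_pm->AddPass(std::make_shared<opt::ParallelPass>(split_strategys, config->fmk));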

View File

@@ -19,13 +19,27 @@
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/base/base.h"
#include "mindspore/core/ops/fusion/conv2d_fusion.h"
#include "base/base.h"
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/parallel/split_strategy.h"
using mindspore::lite::converter::FmkType;
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
namespace mindspore {
namespace opt {
std::string MultiConvSplitPass::IsMultiParallelConvNode(const AnfNodePtr &node) const {
std::string parallel_name;
for (const auto &parallel_prim : kParallelSet) {
if (CheckPrimitiveType(node, parallel_prim)) {
if (kParallelOpNames.find(parallel_prim) != kParallelOpNames.end()) {
return kParallelOpNames.at(parallel_prim);
}
}
}
return parallel_name;
}
const BaseRef MultiConvSplitPass::DefinePattern() const {
auto conv1_var = std::make_shared<CondVar>(IsConvNode);
auto conv1_other_var = std::make_shared<SeqVar>();
@@ -50,8 +64,12 @@ const AnfNodePtr MultiConvSplitPass::Process(const FuncGraphPtr &func_graph, con
if (device_type != kDeviceTypeNone) {
return node;
}
auto parallel_name = IsMultiParallelConvNode(node);
if (parallel_name.empty()) {
return node;
}
std::shared_ptr<MultiNodeSplitProxy> multi_node_split_proxy =
std::make_shared<MultiNodeSplitProxy>(strategy_, primitive_type_, fmk_type_, num_);
std::make_shared<MultiNodeSplitProxy>(strategys_.at(parallel_name), primitive_type_, fmk_type_, num_);
return multi_node_split_proxy->DoSplit(func_graph, node);
}
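Net effect of this file's changes: Process() now maps the matched node to a registered parallel op name and picks that name's strategy out of the map, instead of carrying one fixed strategy and primitive type. An illustration of the lookup chain for a supported node (the tables referenced here are defined in split_strategy.h at the end of this commit):

// For a Conv2DFusion cnode:
//   IsMultiParallelConvNode(node)  -> "Conv2D"      (via kParallelSet + kParallelOpNames)
//   strategys_.at("Conv2D")        -> SplitStrategy (built by opt::ParserSplitStrategy)
//   MultiNodeSplitProxy(strategy, primitive_type_, fmk_type_, num_)->DoSplit(func_graph, node)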

View File

@@ -18,6 +18,8 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_
#include <vector>
#include <string>
#include <unordered_map>
#include "backend/optimizer/common/optimizer.h"
#include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h"
@@ -29,19 +31,18 @@ namespace mindspore {
namespace opt {
class MultiConvSplitPass : public PatternProcessPass {
public:
explicit MultiConvSplitPass(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
explicit MultiConvSplitPass(const std::unordered_map<std::string, SplitStrategy> &strategys, int32_t fmk_type = -1,
int32_t num = 3, bool multigraph = true)
: PatternProcessPass("multi_conv_split", multigraph),
strategy_(strategy),
primitive_type_(primitive_type),
fmk_type_(fmk_type),
num_(num) {}
: PatternProcessPass("multi_conv_split", multigraph), strategys_(strategys), fmk_type_(fmk_type), num_(num) {}
~MultiConvSplitPass() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
SplitStrategy strategy_{};
std::string IsMultiParallelConvNode(const AnfNodePtr &node) const;
private:
std::unordered_map<std::string, SplitStrategy> strategys_;
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
int32_t fmk_type_{-1};
int32_t num_{0};
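With the reworked constructor above, the pass is built from the full strategy map; a minimal construction sketch (config->fmk is the converter framework type used in anf_transform.cc, and the trailing 3 is the same conv-chain depth passed there):

auto strategys = opt::ParserSplitStrategy();
auto multi_conv_split = std::make_shared<opt::MultiConvSplitPass>(strategys, config->fmk, 3);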

View File

@@ -31,8 +31,6 @@ using mindspore::schema::PrimitiveType_Conv2DFusion;
namespace mindspore {
namespace opt {
int Conv2DInfo::GetAttrs() { return lite::RET_OK; }
int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
int split_count = 0;
Strategys strategys = strategy.strategys;
@@ -77,7 +75,6 @@ int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
MS_LOG(ERROR) << "Strategy ERROR, only support split one dimension.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
@@ -120,7 +117,6 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t in
split_cnode->set_fullname_with_scope("Split_" + name_);
CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs);
return split_cnode;
}
@@ -208,7 +204,6 @@ int Conv2DInfo::InferParallelCNodes() {
MS_LOG(DEBUG) << "No Split mode chosen";
}
name_ = orig_name;
return ConstructOutputCNodes(conv_prim, feature_split_outputs, kernel_split_outputs, bias_split_outputs);
}
@@ -218,7 +213,6 @@ int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &
const std::vector<AnfNodePtr> &bias_split_outputs) {
Strategys strategys = strategy_.strategys;
size_t dev_num = strategy_.dev_num;
int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0);
int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0);
std::string conv_cnode_name = cnode_->fullname_with_scope();
@@ -243,7 +237,6 @@ int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &
prim->set_pad_list(conv_prim->get_pad_list());
prim->set_stride(conv_prim->get_stride());
prim->set_activation_type(conv_prim->get_activation_type());
switch (split_mode_) {
case SplitH: {
if (i != 0) {
@@ -350,7 +343,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
size_t dev_num = strategy_.dev_num;
std::vector<AnfNodePtr> feature_split_outputs;
std::string orig_name = name_;
switch (split_mode_) {
case SplitCIN: {
MS_LOG(ERROR) << "DepthwiseConv2DInfo doesn't support split Cin.";
@@ -367,7 +359,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
default:
break;
}
parallel_output_nodes_.clear();
std::string conv_cnode_name = cnode_->fullname_with_scope();
auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
@@ -385,7 +376,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
}
}
name_ = orig_name;
// construct parallel Conv2D nodes
for (size_t i = 0; i < dev_num; ++i) {
std::vector<AnfNodePtr> tmp_outputs;
@@ -403,7 +393,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
prim->set_pad_list(conv_prim->get_pad_list());
prim->set_stride(conv_prim->get_stride());
prim->set_activation_type(conv_prim->get_activation_type());
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
conv_inputs.push_back(feature_split_outputs[i]);
conv_inputs.push_back(cnode_->input(2));

View File

@@ -35,7 +35,6 @@ class Conv2DInfo : public OperatorInfo {
protected:
int CheckStrategy(const SplitStrategy &strategy) override;
int GetAttrs() override;
int InferReplaceOp() override;
int InferParallelCNodes() override;
int ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &conv_prim,

View File

@@ -77,7 +77,7 @@ int MultiConvSplit::GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfN
auto curr_node = conv_nodes_[idx];
auto curr_cnode = conv_nodes_[idx]->cast<CNodePtr>();
auto tmp_node = curr_cnode->input(1);
if (IsConv2D(tmp_node)) {
if (!IsConv2D(tmp_node)) {
break;
}
auto name = tmp_node->fullname_with_scope();

View File

@@ -20,6 +20,7 @@ namespace mindspore {
namespace opt {
int MultiNodeSplitProxy::InitResource() {
split_mode_ = strategy_.split_mode_;
switch (split_mode_) {
case SplitN:
multi_node_split_ = std::make_shared<MultiConvSplitN>(strategy_, primitive_type_, fmk_type_, num_);

View File

@@ -18,7 +18,7 @@
#include <utility>
#include <memory>
#include "tools/optimizer/parallel/split_strategy.h"
#include "schema/inner/model_generated.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "base/base.h"

View File

@@ -36,12 +36,17 @@ bool is_any_not_none(const std::vector<int64_t> &split) {
return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
}
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() {
auto type_ptr = TypeIdToType(operator_type_id_);
std::vector<int64_t> shape_vector;
return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
}
int OperatorInfo::SetCNodeBackend() {
for (size_t i = 0; i < strategy_.dev_num; ++i) {
lite::DeviceType dt_type;
std::string type = strategy_.dev_types[i];
auto cnode = parallel_output_nodes_[i]->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (type == "GPU") {
dt_type = lite::DeviceType::DT_GPU;
} else if (type == "CPU") {
@@ -90,7 +95,6 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
for (size_t i = 0; i < output_num; ++i) {
auto idx = NewValueNode(SizeToInt(i));
MS_ASSERT(idx);
auto index = std::make_shared<Int32Imm>(SizeToInt(i));
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
idx->set_abstract(abstract_scalar);
@@ -101,38 +105,13 @@
}
tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
outputs->push_back(tuple_getitem);
auto type_id = static_cast<TypeId>(operator_type_id_);
auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector;
ptr_list.push_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto abstract_tensor = CreateFakeAbstractTensor();
ptr_list.push_back(abstract_tensor);
}
node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
return lite::RET_OK;
}
AnfNodePtr OperatorInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
size_t split_num, const std::vector<int64_t> &splits, bool trans_format) {
MS_EXCEPTION_IF_NULL(orig_node);
auto split_prim = std::make_shared<ops::Split>();
split_prim->set_output_num(split_num);
split_prim->set_size_splits(splits);
split_prim->set_axis(split_dim);
auto value_node = NewValueNode(split_prim);
std::vector<AnfNodePtr> split_inputs = {value_node};
split_inputs.push_back(orig_node->input(input_index + 1));
auto split_cnode = func_graph_->NewCNode(split_inputs);
if (split_cnode == nullptr) {
MS_LOG(ERROR) << name_ << " : Failed to create split node.";
return nullptr;
}
split_cnode->set_fullname_with_scope("Split_" + name_);
CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs);
return split_cnode;
}
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t concat_dim, size_t input_nodes_num, bool trans_format) {
MS_EXCEPTION_IF_NULL(orig_node);
@@ -185,10 +164,6 @@ AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::
}
int OperatorInfo::Init() {
if (GetAttrs() != lite::RET_OK) {
MS_LOG(ERROR) << name_ << ": Parse attrs failed.";
return lite::RET_ERROR;
}
if (CheckStrategyValue() != lite::RET_OK) {
MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
return lite::RET_ERROR;

View File

@@ -25,7 +25,7 @@
#include "tools/optimizer/parallel/split_strategy.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "schema/inner/model_generated.h"
#include "schema/model_generated.h"
#include "include/context.h"
#include "include/errorcode.h"
@@ -42,10 +42,8 @@ namespace opt {
* 5.REGISTER XXXInfo in dynamic_creator.cc
*/
using schema::ReduceMode;
class OperatorInfo;
using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
class OperatorInfo {
public:
OperatorInfo(const std::string &name, const SplitStrategy &strategy)
@@ -59,20 +57,23 @@ class OperatorInfo {
void set_name(const std::string &name) { name_ = name; }
void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
void setFmk(const int32_t fmk_type) { fmk_type_ = fmk_type; }
void set_fmk(const int32_t fmk_type) { fmk_type_ = fmk_type; }
AnfNodePtr replace_op() const { return replace_op_; }
int Init();
protected:
int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
const std::vector<int64_t> &splits, bool trans_format);
AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t concat_dim, size_t input_nodes_num, bool trans_format);
AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
size_t input_nodes_num, bool trans_format);
virtual int GetAttrs() = 0;
std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor();
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
const std::vector<int64_t> &splits, bool trans_format) = 0;
virtual int InferReplaceOp() = 0;
virtual int InferParallelCNodes() = 0;
virtual int CheckStrategy(const SplitStrategy &strategy) = 0;
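Since CreateOutputsOfSplit, InferReplaceOp, InferParallelCNodes, and CheckStrategy are now pure virtual, every concrete operator must override them; a skeletal subclass sketch (the class name is hypothetical, the signatures mirror the header above):

class MyOpInfo : public OperatorInfo {  // hypothetical example, not part of this commit
 public:
  MyOpInfo(const std::string &name, const SplitStrategy &strategy) : OperatorInfo(name, strategy) {}
  ~MyOpInfo() override = default;

 protected:
  int GetAttrs() override { return lite::RET_OK; }
  int InferReplaceOp() override;
  int InferParallelCNodes() override;
  int CheckStrategy(const SplitStrategy &strategy) override;
  AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
                                  std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
                                  size_t split_num, const std::vector<int64_t> &splits,
                                  bool trans_format) override;
};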

View File

@@ -22,72 +22,108 @@
namespace mindspore {
namespace opt {
bool ParallelPass::IsParallelCareNode(const AnfNodePtr &node) {
return std::any_of(PARALLEL_LIST.begin(), PARALLEL_LIST.end(), [this, &node](auto &prim) {
type_name_.clear();
return std::any_of(kParallelSet.begin(), kParallelSet.end(), [this, &node](auto &prim) {
if (CheckPrimitiveType(node, prim)) {
type_name_ = PrimToString(prim);
return true;
} else {
return false;
}
return !type_name_.empty();
});
}
std::string ParallelPass::PrimToString(const PrimitivePtr &prim) {
if (type_string.find(prim->name()) == type_string.end()) {
MS_LOG(EXCEPTION) << "String of the type not registered";
std::string parallel_name;
if (kParallelOpNames.find(prim) == kParallelOpNames.end()) {
MS_LOG(ERROR) << "String of the type not registered";
return parallel_name;
}
return type_string.at(prim->name());
return kParallelOpNames.at(prim);
}
AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return nullptr;
}
bool ParallelPass::SetParallelOpName(const AnfNodePtr &node, std::string *parallel_name) {
if (!utils::isa<CNode>(node)) {
return nullptr;
return false;
}
auto cnode = node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) {
return nullptr;
}
if (!IsParallelCareNode(node)) {
return nullptr;
}
std::string cnode_name = cnode->fullname_with_scope();
std::string name = cnode_name;
// find operator name first, then operator type name.
if (split_strategys_.find(name) == split_strategys_.end()) {
name = type_name_;
}
if (cnode_name.find(PARALLEL_NAME_SUFFIX) != std::string::npos) {
MS_LOG(DEBUG) << " : Skip splited cnode " << cnode_name;
return nullptr;
return false;
}
// find operator name first, then operator type name.
if (split_strategys_.find(*parallel_name) == split_strategys_.end()) {
*parallel_name = type_name_;
}
MS_LOG(DEBUG) << " : Reached a parallel care node: " << cnode_name;
if (split_strategys_.find(name) == split_strategys_.end()) {
MS_LOG(DEBUG) << name << " : No split strategy for the current CNode.";
return nullptr;
if (split_strategys_.find(*parallel_name) == split_strategys_.end()) {
MS_LOG(DEBUG) << *parallel_name << " : No split strategy for the current CNode.";
return false;
}
cnode->set_fullname_with_scope(cnode_name + PARALLEL_NAME_SUFFIX);
return true;
}
cnode_name = cnode->fullname_with_scope();
OperatorInfoPtr ParallelPass::CreateParallelOperator(const AnfNodePtr &node, const std::string &scope_name,
const std::string &parallel_op_name) {
// foreach kernel_list && data_type
SplitOpKey op_key = SplitOpKey(schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32);
auto cnode = node->cast<CNodePtr>();
auto node_prim = cnode->input(kParallelPrimitiveIndex);
auto prim = GetValueNode<PrimitivePtr>(node_prim);
auto split_key_pair = kParallelSchemaId.find(prim);
if (split_key_pair == kParallelSchemaId.end()) {
return nullptr;
}
auto split_schema_id = split_key_pair->second.first;
auto split_type_id = split_key_pair->second.second;
SplitOpKey op_key = SplitOpKey(split_schema_id, split_type_id);
auto op_create_func = OperatorInfoFactory::GeInstance()->FindOperatorInfo(op_key);
if (op_create_func == nullptr) {
return nullptr;
}
OperatorInfoPtr operator_ = op_create_func(cnode_name, split_strategys_[name]);
if (operator_ == nullptr) {
MS_LOG(EXCEPTION) << "Failure: Create " << name << " OperatorInstance failed";
OperatorInfoPtr op = op_create_func(scope_name, split_strategys_[parallel_op_name]);
return op;
}
AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return node;
}
operator_->set_cnode(cnode);
operator_->set_func_graph(func_graph);
operator_->setFmk(fmk_type_);
if (operator_->Init() == RET_ERROR) {
MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed";
if (!utils::isa<CNode>(node)) {
return node;
}
return operator_->replace_op();
if (!IsParallelCareNode(node)) {
return node;
}
auto cnode = node->cast<CNodePtr>();
std::string parallel_op_name = cnode->fullname_with_scope();
if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) {
return nullptr;
}
if (!SetParallelOpName(node, &parallel_op_name)) {
return node;
}
std::string cnode_name = cnode->fullname_with_scope();
OperatorInfoPtr parallel_operator = CreateParallelOperator(node, cnode_name, parallel_op_name);
if (parallel_operator == nullptr) {
MS_LOG(ERROR) << "Failure: Create " << parallel_op_name << " OperatorInstance failed";
return node;
}
parallel_operator->set_cnode(cnode);
parallel_operator->set_func_graph(func_graph);
parallel_operator->set_fmk(fmk_type_);
if (parallel_operator->Init() == RET_ERROR) {
MS_LOG(ERROR) << "Failure: operator " << parallel_op_name << " init failed";
return node;
}
return parallel_operator->replace_op();
}
} // namespace opt
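After this refactor, Run() reads as a short pipeline; a summary of the hunks above (note that every failure path now returns the incoming node instead of raising MS_LOG(EXCEPTION)):

// ParallelPass::Run(func_graph, node):
//   1. null checks on func_graph / node / cnode     -> bail out early
//   2. IsParallelCareNode(node)                     -> caches type_name_ via PrimToString
//   3. SetParallelOpName(node, &parallel_op_name)   -> skips cnodes already carrying
//      PARALLEL_NAME_SUFFIX, falls back to type_name_ as the strategy key, and
//      renames the cnode with the suffix
//   4. CreateParallelOperator(node, scope, name)    -> maps the primitive through
//      kParallelSchemaId to a SplitOpKey and asks OperatorInfoFactory for a creator
//   5. set_cnode / set_func_graph / set_fmk, Init(), then return replace_op()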

View File

@@ -23,6 +23,7 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "backend/optimizer/common/node_pass.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/operator_info.h"
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
@@ -37,12 +38,19 @@ class ParallelPass : public opt::NodePass {
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
private:
const std::set<PrimitivePtr> PARALLEL_LIST = {prim::kPrimConv2DFusion};
const std::unordered_map<std::string, std::string> type_string = {{prim::kPrimConv2DFusion->name(), "Conv2D"}};
// to check this node whether support to parallel && split
bool IsParallelCareNode(const AnfNodePtr &node);
// mapping primitive to a parallel_op_name
std::string PrimToString(const PrimitivePtr &prim);
// set curr_node a new op_name with parallel symbol
bool SetParallelOpName(const AnfNodePtr &node, std::string *parallel_name);
// create a parallel operator from different scope_name
OperatorInfoPtr CreateParallelOperator(const AnfNodePtr &node, const std::string &scope_name,
const std::string &parallel_op_name);
private:
std::string type_name_;
std::unordered_map<std::string, SplitStrategy> split_strategys_;
int32_t fmk_type_;

View File

@@ -21,7 +21,7 @@
namespace mindspore {
namespace opt {
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMode parallel_mode) {
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy() {
std::unordered_map<std::string, opt::SplitStrategy> split_strategys;
if (kSplitRatio.empty() || kSplitDefaultRatio.empty() || kSplitDevTypes.empty()) {
return split_strategys;
@@ -31,7 +31,7 @@ std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMod
}
std::vector<std::vector<int64_t>> split_feature_map;
std::vector<std::vector<int64_t>> split_weight;
switch (parallel_mode) {
switch (kParallelMode) {
case SplitN:
split_feature_map = {kSplitRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
@@ -52,7 +52,11 @@
return split_strategys;
}
opt::Strategys strategys = {split_feature_map, split_weight};
split_strategys[opt::kSplitOp] = {strategys, kSplitDevTypes, kSplitDevTypes.size()};
for (const auto &supported_parallel_op : kParallelOpNames) {
split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), kParallelMode};
}
return split_strategys;
}
} // namespace opt
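Because the new loop registers the same strategy under every name in kParallelOpNames, callers fetch it by parallel op name; a minimal usage sketch (both current table entries alias the key "Conv2D"):

auto split_strategys = opt::ParserSplitStrategy();
if (!split_strategys.empty()) {
  const auto &conv_strategy = split_strategys.at("Conv2D");
  // The split mode now travels inside the strategy instead of being a parameter:
  MS_ASSERT(conv_strategy.split_mode_ == opt::kParallelMode);
}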

View File

@@ -16,19 +16,19 @@
#include <vector>
#include <string>
#include <set>
#include <utility>
#include <unordered_map>
#include "schema/ops_generated.h"
#include "base/core_ops.h"
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
namespace mindspore {
namespace opt {
constexpr auto OP = "op";
constexpr auto STRATEGY = "strategy";
constexpr auto DEV_TYPE = "dev_type";
constexpr auto PARALLEL_NAME_SUFFIX = "_parallel";
constexpr auto kSplitOp = "Conv2D";
constexpr auto kParallelPrimitiveIndex = 0;
const std::vector<int64_t> kSplitRatio = {1, 1};
@@ -62,13 +62,27 @@ enum SplitMode {
SplitCOUT = 4,
};
constexpr auto kParallelMode = SplitH;
struct SplitStrategy {
Strategys strategys;
std::vector<std::string> dev_types;
size_t dev_num;
SplitMode split_mode_;
};
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMode parallel_mode);
// this is a set to add mindspore supported ops
const std::set<PrimitivePtr> kParallelSet = {prim::kPrimConv2DFusion, prim::kPrimConv2D};
// this is a map for key: primitive value: parallel_op_name
const std::unordered_map<PrimitivePtr, std::string> kParallelOpNames = {{prim::kPrimConv2D, "Conv2D"},
{prim::kPrimConv2DFusion, "Conv2D"}};
// this is a map for key: primitive value: schema_primitive_id
const std::unordered_map<PrimitivePtr, std::pair<schema::PrimitiveType, TypeId>> kParallelSchemaId = {
{prim::kPrimConv2D, {schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32}}};
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy();
} // namespace opt
} // namespace mindspore
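The three tables above are now the single registration point for parallel-capable ops; an illustrative extension for another primitive would touch all three (hypothetical, not part of this commit):

// kParallelSet      : add prim::kPrimFullConnection
// kParallelOpNames  : add {prim::kPrimFullConnection, "FullConnection"}
// kParallelSchemaId : add {prim::kPrimFullConnection,
//                          {schema::PrimitiveType_FullConnection, kNumberTypeFloat32}}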