forked from mindspore-Ecosystem/mindspore
!16440 adjsut sigle parallel split node pass
From: @zoloft Reviewed-by: @zhang_xue_tong Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
ea99e19b47
|
@ -134,7 +134,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
|
|||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
// 1. deal with split strategy
|
||||
std::unordered_map<std::string, opt::SplitStrategy> split_strategys = ParserSplitStrategy(opt::SplitH);
|
||||
std::unordered_map<std::string, opt::SplitStrategy> split_strategys = opt::ParserSplitStrategy();
|
||||
if (split_strategys.empty()) {
|
||||
MS_LOG(ERROR) << "parse split_strategy error.";
|
||||
return RET_OK;
|
||||
|
@ -144,9 +144,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
|
|||
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
|
||||
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
|
||||
// 3. multi_conv parallel pass
|
||||
auto strategy = split_strategys.begin()->second;
|
||||
parallel_pm->AddPass(
|
||||
std::make_shared<opt::MultiConvSplitPass>(strategy, schema::PrimitiveType_Conv2DFusion, config->fmk, 3));
|
||||
parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, config->fmk, 3));
|
||||
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
|
||||
// 4. single conv parallel pass
|
||||
parallel_pm->AddPass(std::make_shared<opt::ParallelPass>(split_strategys, config->fmk));
|
||||
|
|
|
@ -19,13 +19,27 @@
|
|||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/core/base/base.h"
|
||||
#include "mindspore/core/ops/fusion/conv2d_fusion.h"
|
||||
#include "base/base.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
std::string MultiConvSplitPass::IsMultiParallelConvNode(const AnfNodePtr &node) const {
|
||||
std::string parallel_name;
|
||||
for (const auto ¶llel_prim : kParallelSet) {
|
||||
if (CheckPrimitiveType(node, parallel_prim)) {
|
||||
if (kParallelOpNames.find(parallel_prim) != kParallelOpNames.end()) {
|
||||
return kParallelOpNames.at(parallel_prim);
|
||||
}
|
||||
}
|
||||
}
|
||||
return parallel_name;
|
||||
}
|
||||
|
||||
const BaseRef MultiConvSplitPass::DefinePattern() const {
|
||||
auto conv1_var = std::make_shared<CondVar>(IsConvNode);
|
||||
auto conv1_other_var = std::make_shared<SeqVar>();
|
||||
|
@ -50,8 +64,12 @@ const AnfNodePtr MultiConvSplitPass::Process(const FuncGraphPtr &func_graph, con
|
|||
if (device_type != kDeviceTypeNone) {
|
||||
return node;
|
||||
}
|
||||
auto parallel_name = IsMultiParallelConvNode(node);
|
||||
if (parallel_name.empty()) {
|
||||
return node;
|
||||
}
|
||||
std::shared_ptr<MultiNodeSplitProxy> multi_node_split_proxy =
|
||||
std::make_shared<MultiNodeSplitProxy>(strategy_, primitive_type_, fmk_type_, num_);
|
||||
std::make_shared<MultiNodeSplitProxy>(strategys_.at(parallel_name), primitive_type_, fmk_type_, num_);
|
||||
return multi_node_split_proxy->DoSplit(func_graph, node);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
|
@ -29,19 +31,18 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class MultiConvSplitPass : public PatternProcessPass {
|
||||
public:
|
||||
explicit MultiConvSplitPass(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
|
||||
explicit MultiConvSplitPass(const std::unordered_map<std::string, SplitStrategy> &strategys, int32_t fmk_type = -1,
|
||||
int32_t num = 3, bool multigraph = true)
|
||||
: PatternProcessPass("multi_conv_split", multigraph),
|
||||
strategy_(strategy),
|
||||
primitive_type_(primitive_type),
|
||||
fmk_type_(fmk_type),
|
||||
num_(num) {}
|
||||
: PatternProcessPass("multi_conv_split", multigraph), strategys_(strategys), fmk_type_(fmk_type), num_(num) {}
|
||||
~MultiConvSplitPass() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
SplitStrategy strategy_{};
|
||||
std::string IsMultiParallelConvNode(const AnfNodePtr &node) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, SplitStrategy> strategys_;
|
||||
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
|
||||
int32_t fmk_type_{-1};
|
||||
int32_t num_{0};
|
||||
|
|
|
@ -31,8 +31,6 @@ using mindspore::schema::PrimitiveType_Conv2DFusion;
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
int Conv2DInfo::GetAttrs() { return lite::RET_OK; }
|
||||
|
||||
int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
|
||||
int split_count = 0;
|
||||
Strategys strategys = strategy.strategys;
|
||||
|
@ -77,7 +75,6 @@ int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
|
|||
MS_LOG(ERROR) << "Strategy ERROR, only support split one dimension.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
|
@ -120,7 +117,6 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t in
|
|||
|
||||
split_cnode->set_fullname_with_scope("Split_" + name_);
|
||||
CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs);
|
||||
|
||||
return split_cnode;
|
||||
}
|
||||
|
||||
|
@ -208,7 +204,6 @@ int Conv2DInfo::InferParallelCNodes() {
|
|||
MS_LOG(DEBUG) << "No Split mode chosen";
|
||||
}
|
||||
name_ = orig_name;
|
||||
|
||||
return ConstructOutputCNodes(conv_prim, feature_split_outputs, kernel_split_outputs, bias_split_outputs);
|
||||
}
|
||||
|
||||
|
@ -218,7 +213,6 @@ int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &
|
|||
const std::vector<AnfNodePtr> &bias_split_outputs) {
|
||||
Strategys strategys = strategy_.strategys;
|
||||
size_t dev_num = strategy_.dev_num;
|
||||
|
||||
int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0);
|
||||
int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0);
|
||||
std::string conv_cnode_name = cnode_->fullname_with_scope();
|
||||
|
@ -243,7 +237,6 @@ int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &
|
|||
prim->set_pad_list(conv_prim->get_pad_list());
|
||||
prim->set_stride(conv_prim->get_stride());
|
||||
prim->set_activation_type(conv_prim->get_activation_type());
|
||||
|
||||
switch (split_mode_) {
|
||||
case SplitH: {
|
||||
if (i != 0) {
|
||||
|
@ -350,7 +343,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
|
|||
size_t dev_num = strategy_.dev_num;
|
||||
std::vector<AnfNodePtr> feature_split_outputs;
|
||||
std::string orig_name = name_;
|
||||
|
||||
switch (split_mode_) {
|
||||
case SplitCIN: {
|
||||
MS_LOG(ERROR) << "DepthwiseConv2DInfo doesn't support split Cin.";
|
||||
|
@ -367,7 +359,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
|
|||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
parallel_output_nodes_.clear();
|
||||
std::string conv_cnode_name = cnode_->fullname_with_scope();
|
||||
auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
|
||||
|
@ -385,7 +376,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
|
|||
}
|
||||
}
|
||||
name_ = orig_name;
|
||||
|
||||
// construct parallel Conv2D nodes
|
||||
for (size_t i = 0; i < dev_num; ++i) {
|
||||
std::vector<AnfNodePtr> tmp_outputs;
|
||||
|
@ -403,7 +393,6 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
|
|||
prim->set_pad_list(conv_prim->get_pad_list());
|
||||
prim->set_stride(conv_prim->get_stride());
|
||||
prim->set_activation_type(conv_prim->get_activation_type());
|
||||
|
||||
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
|
||||
conv_inputs.push_back(feature_split_outputs[i]);
|
||||
conv_inputs.push_back(cnode_->input(2));
|
||||
|
|
|
@ -35,7 +35,6 @@ class Conv2DInfo : public OperatorInfo {
|
|||
|
||||
protected:
|
||||
int CheckStrategy(const SplitStrategy &strategy) override;
|
||||
int GetAttrs() override;
|
||||
int InferReplaceOp() override;
|
||||
int InferParallelCNodes() override;
|
||||
int ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &conv_prim,
|
||||
|
|
|
@ -77,7 +77,7 @@ int MultiConvSplit::GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfN
|
|||
auto curr_node = conv_nodes_[idx];
|
||||
auto curr_cnode = conv_nodes_[idx]->cast<CNodePtr>();
|
||||
auto tmp_node = curr_cnode->input(1);
|
||||
if (IsConv2D(tmp_node)) {
|
||||
if (!IsConv2D(tmp_node)) {
|
||||
break;
|
||||
}
|
||||
auto name = tmp_node->fullname_with_scope();
|
||||
|
|
|
@ -20,6 +20,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
|
||||
int MultiNodeSplitProxy::InitResource() {
|
||||
split_mode_ = strategy_.split_mode_;
|
||||
switch (split_mode_) {
|
||||
case SplitN:
|
||||
multi_node_split_ = std::make_shared<MultiConvSplitN>(strategy_, primitive_type_, fmk_type_, num_);
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <utility>
|
||||
#include <memory>
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "base/base.h"
|
||||
|
||||
|
|
|
@ -36,12 +36,17 @@ bool is_any_not_none(const std::vector<int64_t> &split) {
|
|||
return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
|
||||
}
|
||||
|
||||
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() {
|
||||
auto type_ptr = TypeIdToType(operator_type_id_);
|
||||
std::vector<int64_t> shape_vector;
|
||||
return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
}
|
||||
|
||||
int OperatorInfo::SetCNodeBackend() {
|
||||
for (size_t i = 0; i < strategy_.dev_num; ++i) {
|
||||
lite::DeviceType dt_type;
|
||||
std::string type = strategy_.dev_types[i];
|
||||
auto cnode = parallel_output_nodes_[i]->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (type == "GPU") {
|
||||
dt_type = lite::DeviceType::DT_GPU;
|
||||
} else if (type == "CPU") {
|
||||
|
@ -90,7 +95,6 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
|
|||
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto idx = NewValueNode(SizeToInt(i));
|
||||
MS_ASSERT(idx);
|
||||
auto index = std::make_shared<Int32Imm>(SizeToInt(i));
|
||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
|
||||
idx->set_abstract(abstract_scalar);
|
||||
|
@ -101,38 +105,13 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
|
|||
}
|
||||
tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
|
||||
outputs->push_back(tuple_getitem);
|
||||
auto type_id = static_cast<TypeId>(operator_type_id_);
|
||||
auto type_ptr = TypeIdToType(type_id);
|
||||
std::vector<int64_t> shape_vector;
|
||||
ptr_list.push_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
||||
auto abstract_tensor = CreateFakeAbstractTensor();
|
||||
ptr_list.push_back(abstract_tensor);
|
||||
}
|
||||
node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
AnfNodePtr OperatorInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
|
||||
size_t split_num, const std::vector<int64_t> &splits, bool trans_format) {
|
||||
MS_EXCEPTION_IF_NULL(orig_node);
|
||||
|
||||
auto split_prim = std::make_shared<ops::Split>();
|
||||
split_prim->set_output_num(split_num);
|
||||
split_prim->set_size_splits(splits);
|
||||
split_prim->set_axis(split_dim);
|
||||
auto value_node = NewValueNode(split_prim);
|
||||
std::vector<AnfNodePtr> split_inputs = {value_node};
|
||||
split_inputs.push_back(orig_node->input(input_index + 1));
|
||||
auto split_cnode = func_graph_->NewCNode(split_inputs);
|
||||
if (split_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << " : Failed to create split node.";
|
||||
return nullptr;
|
||||
}
|
||||
split_cnode->set_fullname_with_scope("Split_" + name_);
|
||||
CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs);
|
||||
|
||||
return split_cnode;
|
||||
}
|
||||
|
||||
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
|
||||
int32_t concat_dim, size_t input_nodes_num, bool trans_format) {
|
||||
MS_EXCEPTION_IF_NULL(orig_node);
|
||||
|
@ -185,10 +164,6 @@ AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::
|
|||
}
|
||||
|
||||
int OperatorInfo::Init() {
|
||||
if (GetAttrs() != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << name_ << ": Parse attrs failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (CheckStrategyValue() != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
|
||||
return lite::RET_ERROR;
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
|
@ -42,10 +42,8 @@ namespace opt {
|
|||
* 5.REGISTER XXXInfo in dynamic_creator.cc
|
||||
*/
|
||||
using schema::ReduceMode;
|
||||
|
||||
class OperatorInfo;
|
||||
using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
|
||||
|
||||
class OperatorInfo {
|
||||
public:
|
||||
OperatorInfo(const std::string &name, const SplitStrategy &strategy)
|
||||
|
@ -59,20 +57,23 @@ class OperatorInfo {
|
|||
void set_name(const std::string &name) { name_ = name; }
|
||||
void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
|
||||
void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
|
||||
void setFmk(const int32_t fmk_type) { fmk_type_ = fmk_type; }
|
||||
void set_fmk(const int32_t fmk_type) { fmk_type_ = fmk_type; }
|
||||
AnfNodePtr replace_op() const { return replace_op_; }
|
||||
int Init();
|
||||
|
||||
protected:
|
||||
int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
|
||||
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
|
||||
const std::vector<int64_t> &splits, bool trans_format);
|
||||
|
||||
AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
|
||||
int32_t concat_dim, size_t input_nodes_num, bool trans_format);
|
||||
AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
|
||||
size_t input_nodes_num, bool trans_format);
|
||||
virtual int GetAttrs() = 0;
|
||||
|
||||
std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor();
|
||||
|
||||
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
|
||||
const std::vector<int64_t> &splits, bool trans_format) = 0;
|
||||
virtual int InferReplaceOp() = 0;
|
||||
virtual int InferParallelCNodes() = 0;
|
||||
virtual int CheckStrategy(const SplitStrategy &strategy) = 0;
|
||||
|
|
|
@ -22,72 +22,108 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool ParallelPass::IsParallelCareNode(const AnfNodePtr &node) {
|
||||
return std::any_of(PARALLEL_LIST.begin(), PARALLEL_LIST.end(), [this, &node](auto &prim) {
|
||||
type_name_.clear();
|
||||
return std::any_of(kParallelSet.begin(), kParallelSet.end(), [this, &node](auto &prim) {
|
||||
if (CheckPrimitiveType(node, prim)) {
|
||||
type_name_ = PrimToString(prim);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return !type_name_.empty();
|
||||
});
|
||||
}
|
||||
|
||||
std::string ParallelPass::PrimToString(const PrimitivePtr &prim) {
|
||||
if (type_string.find(prim->name()) == type_string.end()) {
|
||||
MS_LOG(EXCEPTION) << "String of the type not registered";
|
||||
std::string parallel_name;
|
||||
if (kParallelOpNames.find(prim) == kParallelOpNames.end()) {
|
||||
MS_LOG(ERROR) << "String of the type not registered";
|
||||
return parallel_name;
|
||||
}
|
||||
return type_string.at(prim->name());
|
||||
return kParallelOpNames.at(prim);
|
||||
}
|
||||
|
||||
AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
bool ParallelPass::SetParallelOpName(const AnfNodePtr &node, std::string *parallel_name) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsParallelCareNode(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string cnode_name = cnode->fullname_with_scope();
|
||||
std::string name = cnode_name;
|
||||
// find operator name first, then operator type name.
|
||||
if (split_strategys_.find(name) == split_strategys_.end()) {
|
||||
name = type_name_;
|
||||
}
|
||||
if (cnode_name.find(PARALLEL_NAME_SUFFIX) != std::string::npos) {
|
||||
MS_LOG(DEBUG) << " : Skip splited cnode " << cnode_name;
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
// find operator name first, then operator type name.
|
||||
if (split_strategys_.find(*parallel_name) == split_strategys_.end()) {
|
||||
*parallel_name = type_name_;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << " : Reached a parallel care node: " << cnode_name;
|
||||
if (split_strategys_.find(name) == split_strategys_.end()) {
|
||||
MS_LOG(DEBUG) << name << " : No split strategy for the current CNode.";
|
||||
return nullptr;
|
||||
if (split_strategys_.find(*parallel_name) == split_strategys_.end()) {
|
||||
MS_LOG(DEBUG) << *parallel_name << " : No split strategy for the current CNode.";
|
||||
return false;
|
||||
}
|
||||
cnode->set_fullname_with_scope(cnode_name + PARALLEL_NAME_SUFFIX);
|
||||
return true;
|
||||
}
|
||||
|
||||
cnode_name = cnode->fullname_with_scope();
|
||||
OperatorInfoPtr ParallelPass::CreateParallelOperator(const AnfNodePtr &node, const std::string &scope_name,
|
||||
const std::string ¶llel_op_name) {
|
||||
// foreach kernel_list && data_type
|
||||
SplitOpKey op_key = SplitOpKey(schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto node_prim = cnode->input(kParallelPrimitiveIndex);
|
||||
auto prim = GetValueNode<PrimitivePtr>(node_prim);
|
||||
auto split_key_pair = kParallelSchemaId.find(prim);
|
||||
|
||||
if (split_key_pair == kParallelSchemaId.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto split_schema_id = split_key_pair->second.first;
|
||||
auto split_type_id = split_key_pair->second.second;
|
||||
SplitOpKey op_key = SplitOpKey(split_schema_id, split_type_id);
|
||||
|
||||
auto op_create_func = OperatorInfoFactory::GeInstance()->FindOperatorInfo(op_key);
|
||||
if (op_create_func == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
OperatorInfoPtr operator_ = op_create_func(cnode_name, split_strategys_[name]);
|
||||
if (operator_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure: Create " << name << " OperatorInstance failed";
|
||||
OperatorInfoPtr op = op_create_func(scope_name, split_strategys_[parallel_op_name]);
|
||||
return op;
|
||||
}
|
||||
|
||||
AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return node;
|
||||
}
|
||||
operator_->set_cnode(cnode);
|
||||
operator_->set_func_graph(func_graph);
|
||||
operator_->setFmk(fmk_type_);
|
||||
if (operator_->Init() == RET_ERROR) {
|
||||
MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed";
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
return node;
|
||||
}
|
||||
return operator_->replace_op();
|
||||
|
||||
if (!IsParallelCareNode(node)) {
|
||||
return node;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
std::string parallel_op_name = cnode->fullname_with_scope();
|
||||
if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!SetParallelOpName(node, ¶llel_op_name)) {
|
||||
return node;
|
||||
}
|
||||
|
||||
std::string cnode_name = cnode->fullname_with_scope();
|
||||
OperatorInfoPtr parallel_operator = CreateParallelOperator(node, cnode_name, parallel_op_name);
|
||||
if (parallel_operator == nullptr) {
|
||||
MS_LOG(ERROR) << "Failure: Create " << parallel_op_name << " OperatorInstance failed";
|
||||
return node;
|
||||
}
|
||||
parallel_operator->set_cnode(cnode);
|
||||
parallel_operator->set_func_graph(func_graph);
|
||||
parallel_operator->set_fmk(fmk_type_);
|
||||
if (parallel_operator->Init() == RET_ERROR) {
|
||||
MS_LOG(ERROR) << "Failure: operator " << parallel_op_name << " init failed";
|
||||
return node;
|
||||
}
|
||||
return parallel_operator->replace_op();
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "backend/optimizer/common/node_pass.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "tools/optimizer/parallel/operator_info.h"
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_
|
||||
|
@ -37,12 +38,19 @@ class ParallelPass : public opt::NodePass {
|
|||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
const std::set<PrimitivePtr> PARALLEL_LIST = {prim::kPrimConv2DFusion};
|
||||
const std::unordered_map<std::string, std::string> type_string = {{prim::kPrimConv2DFusion->name(), "Conv2D"}};
|
||||
|
||||
// to check this node whether support to parallel && split
|
||||
bool IsParallelCareNode(const AnfNodePtr &node);
|
||||
// mapping primitive to a parallel_op_name
|
||||
std::string PrimToString(const PrimitivePtr &prim);
|
||||
|
||||
// set curr_node a new op_name with parallel symbol
|
||||
bool SetParallelOpName(const AnfNodePtr &node, std::string *parallel_name);
|
||||
|
||||
// create a parallel operator from different scope_name
|
||||
OperatorInfoPtr CreateParallelOperator(const AnfNodePtr &node, const std::string &scope_name,
|
||||
const std::string ¶llel_op_name);
|
||||
|
||||
private:
|
||||
std::string type_name_;
|
||||
std::unordered_map<std::string, SplitStrategy> split_strategys_;
|
||||
int32_t fmk_type_;
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMode parallel_mode) {
|
||||
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy() {
|
||||
std::unordered_map<std::string, opt::SplitStrategy> split_strategys;
|
||||
if (kSplitRatio.empty() || kSplitDefaultRatio.empty() || kSplitDevTypes.empty()) {
|
||||
return split_strategys;
|
||||
|
@ -31,7 +31,7 @@ std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMod
|
|||
}
|
||||
std::vector<std::vector<int64_t>> split_feature_map;
|
||||
std::vector<std::vector<int64_t>> split_weight;
|
||||
switch (parallel_mode) {
|
||||
switch (kParallelMode) {
|
||||
case SplitN:
|
||||
split_feature_map = {kSplitRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
|
||||
split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
|
||||
|
@ -52,7 +52,11 @@ std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMod
|
|||
return split_strategys;
|
||||
}
|
||||
opt::Strategys strategys = {split_feature_map, split_weight};
|
||||
split_strategys[opt::kSplitOp] = {strategys, kSplitDevTypes, kSplitDevTypes.size()};
|
||||
|
||||
for (const auto &supported_parallel_op : kParallelOpNames) {
|
||||
split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), kParallelMode};
|
||||
}
|
||||
|
||||
return split_strategys;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -16,19 +16,19 @@
|
|||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "schema/ops_generated.h"
|
||||
#include "base/core_ops.h"
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr auto OP = "op";
|
||||
constexpr auto STRATEGY = "strategy";
|
||||
constexpr auto DEV_TYPE = "dev_type";
|
||||
|
||||
constexpr auto PARALLEL_NAME_SUFFIX = "_parallel";
|
||||
|
||||
constexpr auto kSplitOp = "Conv2D";
|
||||
constexpr auto kParallelPrimitiveIndex = 0;
|
||||
|
||||
const std::vector<int64_t> kSplitRatio = {1, 1};
|
||||
|
||||
|
@ -62,13 +62,27 @@ enum SplitMode {
|
|||
SplitCOUT = 4,
|
||||
};
|
||||
|
||||
constexpr auto kParallelMode = SplitH;
|
||||
|
||||
struct SplitStrategy {
|
||||
Strategys strategys;
|
||||
std::vector<std::string> dev_types;
|
||||
size_t dev_num;
|
||||
SplitMode split_mode_;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMode parallel_mode);
|
||||
// this is a set to add mindspore supported ops
|
||||
const std::set<PrimitivePtr> kParallelSet = {prim::kPrimConv2DFusion, prim::kPrimConv2D};
|
||||
|
||||
// this is a map for key: primitive value: parallel_op_name
|
||||
const std::unordered_map<PrimitivePtr, std::string> kParallelOpNames = {{prim::kPrimConv2D, "Conv2D"},
|
||||
{prim::kPrimConv2DFusion, "Conv2D"}};
|
||||
|
||||
// this is a map for key: primitive value: schema_primitive_id
|
||||
const std::unordered_map<PrimitivePtr, std::pair<schema::PrimitiveType, TypeId>> kParallelSchemaId = {
|
||||
{prim::kPrimConv2D, {schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32}}};
|
||||
|
||||
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy();
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue