forked from mindspore-Ecosystem/mindspore
!16141 fix parallel pass
From: @zoloft Reviewed-by: @hangangqiang,@wangchengyuan Signed-off-by: @hangangqiang
This commit is contained in:
commit
a334118229
|
@ -75,7 +75,7 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size,
|
|||
borders[i + 1] = cur_border;
|
||||
}
|
||||
}
|
||||
borders[number_split - 1] = split_dim_size;
|
||||
borders[number_split] = split_dim_size;
|
||||
|
||||
for (int i = 0; i < number_split; ++i) {
|
||||
int output_shape[MAX_SHAPE_SIZE];
|
||||
|
|
|
@ -35,7 +35,7 @@ std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const AnfNodePtr &node) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "export prim: " << prim->name();
|
||||
MS_LOG(DEBUG) << "export prim: " << prim->name();
|
||||
auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(prim->name());
|
||||
if (creator != nullptr) {
|
||||
return creator(node);
|
||||
|
@ -766,6 +766,11 @@ std::unique_ptr<schema::PrimitiveT> CumSumPrimitiveCreator(const AnfNodePtr &nod
|
|||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> SplitWithOverlapPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SplitWithOverlap>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
|
||||
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
|
||||
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
|
||||
|
@ -982,6 +987,7 @@ RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator)
|
|||
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
|
||||
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
|
||||
RegistryMSOps g_CumSumPrimitiveCreatorRegistry("CumSum", CumSumPrimitiveCreator);
|
||||
RegistryMSOps g_SplitWithOverlapCreatorRegistry("SplitWithOverlap", SplitWithOverlapPrimitiveCreator);
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> CustomPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
|
||||
|
|
|
@ -293,6 +293,7 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/operator_info_register.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/conv2d_info.cc
|
||||
${LITE_DIR}/tools/common/graph_util.cc
|
||||
${LITE_DIR}/tools/common/tensor_util.cc
|
||||
${LITE_DIR}/tools/common/node_util.cc
|
||||
|
|
|
@ -127,10 +127,10 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
|
||||
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
MS_LOG(DEBUG) << "Run ParallelPass start";
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
if (config->trainModel || !config->parallelMode) {
|
||||
if (config->trainModel || static_cast<opt::SplitMode>(config->parallelMode) == opt::NoSplit) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
// 1. deal with split strategy
|
||||
std::unordered_map<std::string, opt::SplitStrategy> split_strategys =
|
||||
ParserSplitStrategy(static_cast<opt::SplitMode>(config->parallelMode));
|
||||
|
|
|
@ -45,9 +45,9 @@ constexpr auto kPadDown = 1;
|
|||
constexpr auto kPadLeft = 2;
|
||||
constexpr auto kPadRight = 3;
|
||||
|
||||
lite::STATUS Conv2DInfo::GetAttrs() { return lite::RET_OK; }
|
||||
int Conv2DInfo::GetAttrs() { return lite::RET_OK; }
|
||||
|
||||
lite::STATUS Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
|
||||
int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
|
||||
int split_count = 0;
|
||||
Strategys strategys = strategy.strategys;
|
||||
|
||||
|
@ -95,7 +95,7 @@ lite::STATUS Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_index,
|
||||
AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
|
||||
const std::vector<int64_t> &splits, bool trans_format) {
|
||||
if (orig_node == nullptr) {
|
||||
|
@ -123,7 +123,7 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_ind
|
|||
split_prim->set_stride(0);
|
||||
split_prim->set_pad_top(0);
|
||||
}
|
||||
std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<ops::SplitWithOverlap>())};
|
||||
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
|
||||
split_inputs.push_back(orig_node->input(input_index + 1));
|
||||
auto split_cnode = func_graph_->NewCNode(split_inputs);
|
||||
if (split_cnode == nullptr) {
|
||||
|
@ -138,13 +138,23 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_ind
|
|||
return split_cnode;
|
||||
}
|
||||
|
||||
lite::STATUS Conv2DInfo::InferParallelCNodes() {
|
||||
int Conv2DInfo::CheckConv2DPrimitiveType() {
|
||||
if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (CheckIfAnfNodeIsNull(cnode_) != lite::RET_OK) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!CheckPrimitiveType(cnode_, prim::kPrimConv2D) && !CheckPrimitiveType(cnode_, prim::kPrimConv2DFusion)) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Conv2DInfo::InferParallelCNodes() {
|
||||
if (!CheckConv2DPrimitiveType()) {
|
||||
return RET_OK;
|
||||
}
|
||||
Strategys strategys = strategy_.strategys;
|
||||
size_t dev_num = strategy_.dev_num;
|
||||
std::vector<AnfNodePtr> feature_split_outputs;
|
||||
|
@ -213,15 +223,16 @@ lite::STATUS Conv2DInfo::InferParallelCNodes() {
|
|||
}
|
||||
name_ = orig_name;
|
||||
|
||||
return ConstructOutputCNodes(conv_prim);
|
||||
return ConstructOutputCNodes(conv_prim, feature_split_outputs, kernel_split_outputs, bias_split_outputs);
|
||||
}
|
||||
|
||||
lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion> conv_prim) {
|
||||
int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &conv_prim,
|
||||
const std::vector<AnfNodePtr> &feature_split_outputs,
|
||||
const std::vector<AnfNodePtr> &kernel_split_outputs,
|
||||
const std::vector<AnfNodePtr> &bias_split_outputs) {
|
||||
Strategys strategys = strategy_.strategys;
|
||||
size_t dev_num = strategy_.dev_num;
|
||||
std::vector<AnfNodePtr> feature_split_outputs;
|
||||
std::vector<AnfNodePtr> kernel_split_outputs;
|
||||
std::vector<AnfNodePtr> bias_split_outputs;
|
||||
|
||||
int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0);
|
||||
int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0);
|
||||
std::string conv_cnode_name = cnode_->fullname_with_scope();
|
||||
|
@ -235,6 +246,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion
|
|||
}
|
||||
// copy attr
|
||||
auto prim = std::make_shared<ops::Conv2DFusion>();
|
||||
prim->set_pad(conv_prim->get_pad());
|
||||
prim->set_in_channel(conv_prim->get_in_channel());
|
||||
prim->set_out_channel(conv_prim->get_out_channel());
|
||||
prim->set_dilation(conv_prim->get_dilation());
|
||||
|
@ -278,7 +290,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion
|
|||
default:
|
||||
break;
|
||||
}
|
||||
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(std::make_shared<ops::Conv2DFusion>())};
|
||||
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
|
||||
// if split Cout, feature will not be splited
|
||||
if (splitMode_ == SplitCOUT) {
|
||||
conv_inputs.push_back(cnode_->input(1));
|
||||
|
@ -310,7 +322,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
lite::STATUS Conv2DInfo::InferReplaceOp() {
|
||||
int Conv2DInfo::InferReplaceOp() {
|
||||
size_t dev_num = strategy_.dev_num;
|
||||
if (splitMode_ == SplitCIN) {
|
||||
MS_LOG(DEBUG) << name_ << " : Split Cin, infer Forward op.";
|
||||
|
@ -334,7 +346,14 @@ lite::STATUS Conv2DInfo::InferReplaceOp() {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
|
||||
AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
|
||||
size_t split_num, const std::vector<int64_t> &splits,
|
||||
bool trans_format) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int DepthwiseConv2DInfo::InferParallelCNodes() {
|
||||
if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
@ -387,6 +406,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
|
|||
// copy attr
|
||||
auto prim = std::make_shared<ops::Conv2DFusion>();
|
||||
prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
|
||||
prim->set_pad(conv_prim->get_pad());
|
||||
prim->set_in_channel(conv_prim->get_in_channel());
|
||||
prim->set_out_channel(conv_prim->get_out_channel());
|
||||
prim->set_dilation(conv_prim->get_dilation());
|
||||
|
@ -398,7 +418,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
|
|||
prim->set_stride(conv_prim->get_stride());
|
||||
prim->set_activation_type(conv_prim->get_activation_type());
|
||||
|
||||
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(std::make_shared<ops::Conv2DFusion>())};
|
||||
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
|
||||
conv_inputs.push_back(feature_split_outputs[i]);
|
||||
conv_inputs.push_back(cnode_->input(2));
|
||||
if (cnode_->size() >= 4) {
|
||||
|
@ -417,7 +437,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
lite::STATUS DepthwiseConv2DInfo::InferReplaceOp() {
|
||||
int DepthwiseConv2DInfo::InferReplaceOp() {
|
||||
size_t dev_num = strategy_.dev_num;
|
||||
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, kAxisN, dev_num, true);
|
||||
if (replace_op_ == nullptr) {
|
||||
|
|
|
@ -34,14 +34,18 @@ class Conv2DInfo : public OperatorInfo {
|
|||
~Conv2DInfo() override = default;
|
||||
|
||||
protected:
|
||||
lite::STATUS CheckStrategy(const SplitStrategy &strategy) override;
|
||||
lite::STATUS GetAttrs() override;
|
||||
lite::STATUS InferReplaceOp() override;
|
||||
lite::STATUS InferParallelCNodes() override;
|
||||
lite::STATUS ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion> conv_prim);
|
||||
AnfNodePtr CreateOutputsOfSplit(CNodePtr orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
|
||||
int CheckStrategy(const SplitStrategy &strategy) override;
|
||||
int GetAttrs() override;
|
||||
int InferReplaceOp() override;
|
||||
int InferParallelCNodes() override;
|
||||
int ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &conv_prim,
|
||||
const std::vector<AnfNodePtr> &feature_split_outputs,
|
||||
const std::vector<AnfNodePtr> &kernel_split_outputs,
|
||||
const std::vector<AnfNodePtr> &bias_split_outputs);
|
||||
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
|
||||
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
|
||||
bool trans_format);
|
||||
bool trans_format) override;
|
||||
int CheckConv2DPrimitiveType();
|
||||
|
||||
SplitMode splitMode_ = NoSplit;
|
||||
bool format_NCHW_ = false;
|
||||
|
@ -54,8 +58,11 @@ class DepthwiseConv2DInfo : public Conv2DInfo {
|
|||
~DepthwiseConv2DInfo() override = default;
|
||||
|
||||
protected:
|
||||
lite::STATUS InferReplaceOp() override;
|
||||
lite::STATUS InferParallelCNodes() override;
|
||||
int InferReplaceOp() override;
|
||||
int InferParallelCNodes() override;
|
||||
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
|
||||
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
|
||||
bool trans_format) override;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
|
|
|
@ -15,10 +15,13 @@
|
|||
*/
|
||||
|
||||
#include "tools/optimizer/parallel/dynamic_creator.h"
|
||||
#include "tools/optimizer/parallel/conv2d_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// operator register
|
||||
REGISTER(Conv2DInfo);
|
||||
REGISTER(DepthwiseConv2DInfo);
|
||||
|
||||
std::string GetDisOpName(const std::string &prim_name) {
|
||||
std::string op_name = prim_name;
|
||||
|
@ -31,7 +34,7 @@ std::string GetDisOpName(const std::string &prim_name) {
|
|||
// create the OperatorInfo instance
|
||||
OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name,
|
||||
const SplitStrategy &strategy) {
|
||||
if (type_name.length() == 0) {
|
||||
if (type_name.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Length of name is zero!";
|
||||
}
|
||||
std::string distribute_opname = GetDisOpName(type_name);
|
||||
|
|
|
@ -101,7 +101,10 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
|
|||
}
|
||||
tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
|
||||
outputs->push_back(tuple_getitem);
|
||||
ptr_list.push_back(abstract_scalar);
|
||||
auto type_id = static_cast<TypeId>(operator_type_id_);
|
||||
auto type_ptr = TypeIdToType(type_id);
|
||||
std::vector<int64_t> shape_vector;
|
||||
ptr_list.push_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
||||
}
|
||||
node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
|
||||
return lite::RET_OK;
|
||||
|
@ -151,7 +154,8 @@ AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std:
|
|||
}
|
||||
concat_cnode->set_fullname_with_scope("Concat_" + name_);
|
||||
concat_cnode->set_scope(orig_node->scope());
|
||||
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
CreateMultipleOutputsOfAnfNode(concat_cnode, 1, &outputs);
|
||||
return concat_cnode;
|
||||
}
|
||||
|
||||
|
|
|
@ -60,14 +60,14 @@ class OperatorInfo {
|
|||
void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
|
||||
void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
|
||||
void setFmk(const int32_t fmk_type) { fmk_type_ = fmk_type; }
|
||||
AnfNodePtr replace_op() { return replace_op_; }
|
||||
AnfNodePtr replace_op() const { return replace_op_; }
|
||||
int Init();
|
||||
|
||||
protected:
|
||||
int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
|
||||
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
|
||||
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
|
||||
bool trans_format);
|
||||
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
|
||||
const std::vector<int64_t> &splits, bool trans_format);
|
||||
AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
|
||||
int32_t concat_dim, size_t input_nodes_num, bool trans_format);
|
||||
AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
|
||||
|
@ -84,6 +84,7 @@ class OperatorInfo {
|
|||
FuncGraphPtr func_graph_{nullptr};
|
||||
CNodePtr cnode_{nullptr};
|
||||
int32_t fmk_type_{};
|
||||
TypeId operator_type_id_ = kNumberTypeFloat32;
|
||||
|
||||
private:
|
||||
int SetCNodeBackend();
|
||||
|
|
|
@ -75,7 +75,7 @@ AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &n
|
|||
}
|
||||
operator_->set_cnode(cnode);
|
||||
operator_->set_func_graph(func_graph);
|
||||
operator_->setFmk(FmkType_);
|
||||
operator_->setFmk(fmk_type_);
|
||||
if (operator_->Init() == RET_ERROR) {
|
||||
MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed";
|
||||
}
|
||||
|
|
|
@ -32,8 +32,8 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ParallelPass : public opt::NodePass {
|
||||
public:
|
||||
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> strategys, const int32_t FmkType)
|
||||
: NodePass("parallel_pass"), split_strategys_(strategys), FmkType_(FmkType) {}
|
||||
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> &strategys, const int32_t fmk_type)
|
||||
: NodePass("parallel_pass"), split_strategys_(strategys), fmk_type_(fmk_type) {}
|
||||
~ParallelPass() override = default;
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
|
||||
|
@ -46,7 +46,7 @@ class ParallelPass : public opt::NodePass {
|
|||
|
||||
std::string type_name_;
|
||||
std::unordered_map<std::string, SplitStrategy> split_strategys_;
|
||||
int32_t FmkType_;
|
||||
int32_t fmk_type_;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
|
|
|
@ -39,11 +39,11 @@ const std::vector<std::string> kSplitDevTypes = {"CPU", "GPU"};
|
|||
using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
|
||||
|
||||
enum SplitMode {
|
||||
SplitN = 0,
|
||||
SplitH = 1,
|
||||
SplitCIN = 2,
|
||||
SplitCOUT = 3,
|
||||
NoSplit = 4,
|
||||
NoSplit = 0,
|
||||
SplitN = 1,
|
||||
SplitH = 2,
|
||||
SplitCIN = 3,
|
||||
SplitCOUT = 4,
|
||||
};
|
||||
|
||||
struct SplitStrategy {
|
||||
|
|
Loading…
Reference in New Issue