!16141 fix parallel pass

From: @zoloft
Reviewed-by: @hangangqiang,@wangchengyuan
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-05-10 19:50:34 +08:00 committed by Gitee
commit a334118229
12 changed files with 86 additions and 44 deletions

View File

@ -75,7 +75,7 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size,
borders[i + 1] = cur_border;
}
}
borders[number_split - 1] = split_dim_size;
borders[number_split] = split_dim_size;
for (int i = 0; i < number_split; ++i) {
int output_shape[MAX_SHAPE_SIZE];

View File

@ -35,7 +35,7 @@ std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const AnfNodePtr &node) {
return nullptr;
}
MS_LOG(INFO) << "export prim: " << prim->name();
MS_LOG(DEBUG) << "export prim: " << prim->name();
auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(prim->name());
if (creator != nullptr) {
return creator(node);
@ -766,6 +766,11 @@ std::unique_ptr<schema::PrimitiveT> CumSumPrimitiveCreator(const AnfNodePtr &nod
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
std::unique_ptr<schema::PrimitiveT> SplitWithOverlapPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SplitWithOverlap>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
@ -982,6 +987,7 @@ RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator)
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
RegistryMSOps g_CumSumPrimitiveCreatorRegistry("CumSum", CumSumPrimitiveCreator);
RegistryMSOps g_SplitWithOverlapCreatorRegistry("SplitWithOverlap", SplitWithOverlapPrimitiveCreator);
std::unique_ptr<schema::PrimitiveT> CustomPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);

View File

@ -293,6 +293,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
${LITE_DIR}/tools/optimizer/parallel/operator_info_register.cc
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
${LITE_DIR}/tools/optimizer/parallel/conv2d_info.cc
${LITE_DIR}/tools/common/graph_util.cc
${LITE_DIR}/tools/common/tensor_util.cc
${LITE_DIR}/tools/common/node_util.cc

View File

@ -127,10 +127,10 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_LOG(DEBUG) << "Run ParallelPass start";
auto optimizer = std::make_shared<opt::GraphOptimizer>();
if (config->trainModel || !config->parallelMode) {
if (config->trainModel || static_cast<opt::SplitMode>(config->parallelMode) == opt::NoSplit) {
return RET_OK;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
// 1. deal with split strategy
std::unordered_map<std::string, opt::SplitStrategy> split_strategys =
ParserSplitStrategy(static_cast<opt::SplitMode>(config->parallelMode));

View File

@ -45,9 +45,9 @@ constexpr auto kPadDown = 1;
constexpr auto kPadLeft = 2;
constexpr auto kPadRight = 3;
lite::STATUS Conv2DInfo::GetAttrs() { return lite::RET_OK; }
int Conv2DInfo::GetAttrs() { return lite::RET_OK; }
lite::STATUS Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
int split_count = 0;
Strategys strategys = strategy.strategys;
@ -95,7 +95,7 @@ lite::STATUS Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
return lite::RET_OK;
}
AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_index,
AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
const std::vector<int64_t> &splits, bool trans_format) {
if (orig_node == nullptr) {
@ -123,7 +123,7 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_ind
split_prim->set_stride(0);
split_prim->set_pad_top(0);
}
std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<ops::SplitWithOverlap>())};
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
split_inputs.push_back(orig_node->input(input_index + 1));
auto split_cnode = func_graph_->NewCNode(split_inputs);
if (split_cnode == nullptr) {
@ -138,13 +138,23 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_ind
return split_cnode;
}
lite::STATUS Conv2DInfo::InferParallelCNodes() {
int Conv2DInfo::CheckConv2DPrimitiveType() {
if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) {
return lite::RET_ERROR;
}
if (CheckIfAnfNodeIsNull(cnode_) != lite::RET_OK) {
return lite::RET_ERROR;
}
if (!CheckPrimitiveType(cnode_, prim::kPrimConv2D) && !CheckPrimitiveType(cnode_, prim::kPrimConv2DFusion)) {
return RET_ERROR;
}
return RET_OK;
}
int Conv2DInfo::InferParallelCNodes() {
if (!CheckConv2DPrimitiveType()) {
return RET_OK;
}
Strategys strategys = strategy_.strategys;
size_t dev_num = strategy_.dev_num;
std::vector<AnfNodePtr> feature_split_outputs;
@ -213,15 +223,16 @@ lite::STATUS Conv2DInfo::InferParallelCNodes() {
}
name_ = orig_name;
return ConstructOutputCNodes(conv_prim);
return ConstructOutputCNodes(conv_prim, feature_split_outputs, kernel_split_outputs, bias_split_outputs);
}
lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion> conv_prim) {
int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &conv_prim,
const std::vector<AnfNodePtr> &feature_split_outputs,
const std::vector<AnfNodePtr> &kernel_split_outputs,
const std::vector<AnfNodePtr> &bias_split_outputs) {
Strategys strategys = strategy_.strategys;
size_t dev_num = strategy_.dev_num;
std::vector<AnfNodePtr> feature_split_outputs;
std::vector<AnfNodePtr> kernel_split_outputs;
std::vector<AnfNodePtr> bias_split_outputs;
int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0);
int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0);
std::string conv_cnode_name = cnode_->fullname_with_scope();
@ -235,6 +246,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion
}
// copy attr
auto prim = std::make_shared<ops::Conv2DFusion>();
prim->set_pad(conv_prim->get_pad());
prim->set_in_channel(conv_prim->get_in_channel());
prim->set_out_channel(conv_prim->get_out_channel());
prim->set_dilation(conv_prim->get_dilation());
@ -278,7 +290,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion
default:
break;
}
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(std::make_shared<ops::Conv2DFusion>())};
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
// if split Cout, feature will not be split
if (splitMode_ == SplitCOUT) {
conv_inputs.push_back(cnode_->input(1));
@ -310,7 +322,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion
return lite::RET_OK;
}
lite::STATUS Conv2DInfo::InferReplaceOp() {
int Conv2DInfo::InferReplaceOp() {
size_t dev_num = strategy_.dev_num;
if (splitMode_ == SplitCIN) {
MS_LOG(DEBUG) << name_ << " : Split Cin, infer Forward op.";
@ -334,7 +346,14 @@ lite::STATUS Conv2DInfo::InferReplaceOp() {
return lite::RET_OK;
}
lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
size_t split_num, const std::vector<int64_t> &splits,
bool trans_format) {
return nullptr;
}
int DepthwiseConv2DInfo::InferParallelCNodes() {
if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) {
return lite::RET_ERROR;
}
@ -387,6 +406,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
// copy attr
auto prim = std::make_shared<ops::Conv2DFusion>();
prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
prim->set_pad(conv_prim->get_pad());
prim->set_in_channel(conv_prim->get_in_channel());
prim->set_out_channel(conv_prim->get_out_channel());
prim->set_dilation(conv_prim->get_dilation());
@ -398,7 +418,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
prim->set_stride(conv_prim->get_stride());
prim->set_activation_type(conv_prim->get_activation_type());
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(std::make_shared<ops::Conv2DFusion>())};
std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
conv_inputs.push_back(feature_split_outputs[i]);
conv_inputs.push_back(cnode_->input(2));
if (cnode_->size() >= 4) {
@ -417,7 +437,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() {
return lite::RET_OK;
}
lite::STATUS DepthwiseConv2DInfo::InferReplaceOp() {
int DepthwiseConv2DInfo::InferReplaceOp() {
size_t dev_num = strategy_.dev_num;
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, kAxisN, dev_num, true);
if (replace_op_ == nullptr) {

View File

@ -34,14 +34,18 @@ class Conv2DInfo : public OperatorInfo {
~Conv2DInfo() override = default;
protected:
lite::STATUS CheckStrategy(const SplitStrategy &strategy) override;
lite::STATUS GetAttrs() override;
lite::STATUS InferReplaceOp() override;
lite::STATUS InferParallelCNodes() override;
lite::STATUS ConstructOutputCNodes(std::shared_ptr<ops::Conv2DFusion> conv_prim);
AnfNodePtr CreateOutputsOfSplit(CNodePtr orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
int CheckStrategy(const SplitStrategy &strategy) override;
int GetAttrs() override;
int InferReplaceOp() override;
int InferParallelCNodes() override;
int ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &conv_prim,
const std::vector<AnfNodePtr> &feature_split_outputs,
const std::vector<AnfNodePtr> &kernel_split_outputs,
const std::vector<AnfNodePtr> &bias_split_outputs);
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
bool trans_format);
bool trans_format) override;
int CheckConv2DPrimitiveType();
SplitMode splitMode_ = NoSplit;
bool format_NCHW_ = false;
@ -54,8 +58,11 @@ class DepthwiseConv2DInfo : public Conv2DInfo {
~DepthwiseConv2DInfo() override = default;
protected:
lite::STATUS InferReplaceOp() override;
lite::STATUS InferParallelCNodes() override;
int InferReplaceOp() override;
int InferParallelCNodes() override;
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
bool trans_format) override;
};
} // namespace opt

View File

@ -15,10 +15,13 @@
*/
#include "tools/optimizer/parallel/dynamic_creator.h"
#include "tools/optimizer/parallel/conv2d_info.h"
namespace mindspore {
namespace opt {
// operator register
REGISTER(Conv2DInfo);
REGISTER(DepthwiseConv2DInfo);
std::string GetDisOpName(const std::string &prim_name) {
std::string op_name = prim_name;
@ -31,7 +34,7 @@ std::string GetDisOpName(const std::string &prim_name) {
// create the OperatorInfo instance
OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name,
const SplitStrategy &strategy) {
if (type_name.length() == 0) {
if (type_name.empty()) {
MS_LOG(EXCEPTION) << "Length of name is zero!";
}
std::string distribute_opname = GetDisOpName(type_name);

View File

@ -101,7 +101,10 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
}
tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
outputs->push_back(tuple_getitem);
ptr_list.push_back(abstract_scalar);
auto type_id = static_cast<TypeId>(operator_type_id_);
auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector;
ptr_list.push_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
}
node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
return lite::RET_OK;
@ -151,7 +154,8 @@ AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std:
}
concat_cnode->set_fullname_with_scope("Concat_" + name_);
concat_cnode->set_scope(orig_node->scope());
std::vector<AnfNodePtr> outputs;
CreateMultipleOutputsOfAnfNode(concat_cnode, 1, &outputs);
return concat_cnode;
}

View File

@ -60,14 +60,14 @@ class OperatorInfo {
void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
void setFmk(const int32_t fmk_type) { fmk_type_ = fmk_type; }
AnfNodePtr replace_op() { return replace_op_; }
AnfNodePtr replace_op() const { return replace_op_; }
int Init();
protected:
int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
bool trans_format);
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
const std::vector<int64_t> &splits, bool trans_format);
AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t concat_dim, size_t input_nodes_num, bool trans_format);
AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
@ -84,6 +84,7 @@ class OperatorInfo {
FuncGraphPtr func_graph_{nullptr};
CNodePtr cnode_{nullptr};
int32_t fmk_type_{};
TypeId operator_type_id_ = kNumberTypeFloat32;
private:
int SetCNodeBackend();

View File

@ -75,7 +75,7 @@ AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &n
}
operator_->set_cnode(cnode);
operator_->set_func_graph(func_graph);
operator_->setFmk(FmkType_);
operator_->setFmk(fmk_type_);
if (operator_->Init() == RET_ERROR) {
MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed";
}

View File

@ -32,8 +32,8 @@ namespace mindspore {
namespace opt {
class ParallelPass : public opt::NodePass {
public:
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> strategys, const int32_t FmkType)
: NodePass("parallel_pass"), split_strategys_(strategys), FmkType_(FmkType) {}
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> &strategys, const int32_t fmk_type)
: NodePass("parallel_pass"), split_strategys_(strategys), fmk_type_(fmk_type) {}
~ParallelPass() override = default;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
@ -46,7 +46,7 @@ class ParallelPass : public opt::NodePass {
std::string type_name_;
std::unordered_map<std::string, SplitStrategy> split_strategys_;
int32_t FmkType_;
int32_t fmk_type_;
};
} // namespace opt

View File

@ -39,11 +39,11 @@ const std::vector<std::string> kSplitDevTypes = {"CPU", "GPU"};
using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
enum SplitMode {
SplitN = 0,
SplitH = 1,
SplitCIN = 2,
SplitCOUT = 3,
NoSplit = 4,
NoSplit = 0,
SplitN = 1,
SplitH = 2,
SplitCIN = 3,
SplitCOUT = 4,
};
struct SplitStrategy {