diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c index 6f5ce5eb7ce..4ae6f80bef1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c @@ -75,7 +75,7 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, borders[i + 1] = cur_border; } } - borders[number_split - 1] = split_dim_size; + borders[number_split] = split_dim_size; for (int i = 0; i < number_split; ++i) { int output_shape[MAX_SHAPE_SIZE]; diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index e4e08a3b748..0fda79e769f 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -35,7 +35,7 @@ std::unique_ptr GetPrimitiveT(const AnfNodePtr &node) { return nullptr; } - MS_LOG(INFO) << "export prim: " << prim->name(); + MS_LOG(DEBUG) << "export prim: " << prim->name(); auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(prim->name()); if (creator != nullptr) { return creator(node); @@ -766,6 +766,11 @@ std::unique_ptr CumSumPrimitiveCreator(const AnfNodePtr &nod return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +std::unique_ptr SplitWithOverlapPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} + RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); @@ -982,6 +987,7 @@ RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator) RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator); RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator); RegistryMSOps g_CumSumPrimitiveCreatorRegistry("CumSum", CumSumPrimitiveCreator); +RegistryMSOps g_SplitWithOverlapCreatorRegistry("SplitWithOverlap", SplitWithOverlapPrimitiveCreator); std::unique_ptr CustomPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index f07260965b6..bc75d1526a8 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -293,6 +293,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc ${LITE_DIR}/tools/optimizer/parallel/operator_info_register.cc ${LITE_DIR}/tools/optimizer/parallel/spliter.cc + ${LITE_DIR}/tools/optimizer/parallel/conv2d_info.cc ${LITE_DIR}/tools/common/graph_util.cc ${LITE_DIR}/tools/common/tensor_util.cc ${LITE_DIR}/tools/common/node_util.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index a38041f1a3f..b35209d1bc6 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -127,10 +127,10 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { MS_LOG(DEBUG) << "Run ParallelPass start"; - auto optimizer = std::make_shared(); - if (config->trainModel || !config->parallelMode) { + if (config->trainModel || static_cast(config->parallelMode) == opt::NoSplit) { return RET_OK; } + auto optimizer = std::make_shared(); // 1. deal with split strategy std::unordered_map split_strategys = ParserSplitStrategy(static_cast(config->parallelMode)); diff --git a/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc b/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc index 2bb82bfa0bc..2ffcd498f95 100644 --- a/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc +++ b/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc @@ -45,9 +45,9 @@ constexpr auto kPadDown = 1; constexpr auto kPadLeft = 2; constexpr auto kPadRight = 3; -lite::STATUS Conv2DInfo::GetAttrs() { return lite::RET_OK; } +int Conv2DInfo::GetAttrs() { return lite::RET_OK; } -lite::STATUS Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) { +int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) { int split_count = 0; Strategys strategys = strategy.strategys; @@ -95,7 +95,7 @@ lite::STATUS Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) { return lite::RET_OK; } -AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_index, +AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector *split_outputs, size_t split_dim, size_t split_num, const std::vector &splits, bool trans_format) { if (orig_node == nullptr) { @@ -123,7 +123,7 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_ind split_prim->set_stride(0); split_prim->set_pad_top(0); } - std::vector split_inputs = {NewValueNode(std::make_shared())}; + std::vector split_inputs = {NewValueNode(split_prim)}; split_inputs.push_back(orig_node->input(input_index + 1)); auto split_cnode = func_graph_->NewCNode(split_inputs); if (split_cnode == nullptr) { @@ -138,13 +138,23 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(CNodePtr orig_node, size_t input_ind return split_cnode; } -lite::STATUS Conv2DInfo::InferParallelCNodes() { +int Conv2DInfo::CheckConv2DPrimitiveType() { if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) { return lite::RET_ERROR; } if (CheckIfAnfNodeIsNull(cnode_) != lite::RET_OK) { return lite::RET_ERROR; } + if (!CheckPrimitiveType(cnode_, prim::kPrimConv2D) && !CheckPrimitiveType(cnode_, prim::kPrimConv2DFusion)) { + return RET_ERROR; + } + return RET_OK; +} + +int Conv2DInfo::InferParallelCNodes() { + if (!CheckConv2DPrimitiveType()) { + return RET_OK; + } Strategys strategys = strategy_.strategys; size_t dev_num = strategy_.dev_num; std::vector feature_split_outputs; @@ -213,15 +223,16 @@ lite::STATUS Conv2DInfo::InferParallelCNodes() { } name_ = orig_name; - return ConstructOutputCNodes(conv_prim); + return ConstructOutputCNodes(conv_prim, feature_split_outputs, kernel_split_outputs, bias_split_outputs); } -lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr conv_prim) { +int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr &conv_prim, + const std::vector &feature_split_outputs, + const std::vector &kernel_split_outputs, + const std::vector &bias_split_outputs) { Strategys strategys = strategy_.strategys; size_t dev_num = strategy_.dev_num; - std::vector feature_split_outputs; - std::vector kernel_split_outputs; - std::vector bias_split_outputs; + int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0); int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0); std::string conv_cnode_name = cnode_->fullname_with_scope(); @@ -235,6 +246,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr(); + prim->set_pad(conv_prim->get_pad()); prim->set_in_channel(conv_prim->get_in_channel()); prim->set_out_channel(conv_prim->get_out_channel()); prim->set_dilation(conv_prim->get_dilation()); @@ -278,7 +290,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr conv_inputs = {NewValueNode(std::make_shared())}; + std::vector conv_inputs = {NewValueNode(prim)}; // if split Cout, feature will not be splited if (splitMode_ == SplitCOUT) { conv_inputs.push_back(cnode_->input(1)); @@ -310,7 +322,7 @@ lite::STATUS Conv2DInfo::ConstructOutputCNodes(std::shared_ptr *split_outputs, size_t split_dim, + size_t split_num, const std::vector &splits, + bool trans_format) { + return nullptr; +} + +int DepthwiseConv2DInfo::InferParallelCNodes() { if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) { return lite::RET_ERROR; } @@ -387,6 +406,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() { // copy attr auto prim = std::make_shared(); prim->AddAttr(ops::kIsDepthWise, MakeValue(true)); + prim->set_pad(conv_prim->get_pad()); prim->set_in_channel(conv_prim->get_in_channel()); prim->set_out_channel(conv_prim->get_out_channel()); prim->set_dilation(conv_prim->get_dilation()); @@ -398,7 +418,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() { prim->set_stride(conv_prim->get_stride()); prim->set_activation_type(conv_prim->get_activation_type()); - std::vector conv_inputs = {NewValueNode(std::make_shared())}; + std::vector conv_inputs = {NewValueNode(prim)}; conv_inputs.push_back(feature_split_outputs[i]); conv_inputs.push_back(cnode_->input(2)); if (cnode_->size() >= 4) { @@ -417,7 +437,7 @@ lite::STATUS DepthwiseConv2DInfo::InferParallelCNodes() { return lite::RET_OK; } -lite::STATUS DepthwiseConv2DInfo::InferReplaceOp() { +int DepthwiseConv2DInfo::InferReplaceOp() { size_t dev_num = strategy_.dev_num; replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, kAxisN, dev_num, true); if (replace_op_ == nullptr) { diff --git a/mindspore/lite/tools/optimizer/parallel/conv2d_info.h b/mindspore/lite/tools/optimizer/parallel/conv2d_info.h index 90ee83698ed..f946821cf83 100644 --- a/mindspore/lite/tools/optimizer/parallel/conv2d_info.h +++ b/mindspore/lite/tools/optimizer/parallel/conv2d_info.h @@ -34,14 +34,18 @@ class Conv2DInfo : public OperatorInfo { ~Conv2DInfo() override = default; protected: - lite::STATUS CheckStrategy(const SplitStrategy &strategy) override; - lite::STATUS GetAttrs() override; - lite::STATUS InferReplaceOp() override; - lite::STATUS InferParallelCNodes() override; - lite::STATUS ConstructOutputCNodes(std::shared_ptr conv_prim); - AnfNodePtr CreateOutputsOfSplit(CNodePtr orig_node, size_t input_index, std::vector *split_outputs, + int CheckStrategy(const SplitStrategy &strategy) override; + int GetAttrs() override; + int InferReplaceOp() override; + int InferParallelCNodes() override; + int ConstructOutputCNodes(const std::shared_ptr &conv_prim, + const std::vector &feature_split_outputs, + const std::vector &kernel_split_outputs, + const std::vector &bias_split_outputs); + AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector *split_outputs, size_t split_dim, size_t split_num, const std::vector &splits, - bool trans_format); + bool trans_format) override; + int CheckConv2DPrimitiveType(); SplitMode splitMode_ = NoSplit; bool format_NCHW_ = false; @@ -54,8 +58,11 @@ class DepthwiseConv2DInfo : public Conv2DInfo { ~DepthwiseConv2DInfo() override = default; protected: - lite::STATUS InferReplaceOp() override; - lite::STATUS InferParallelCNodes() override; + int InferReplaceOp() override; + int InferParallelCNodes() override; + AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector *split_outputs, + size_t split_dim, size_t split_num, const std::vector &splits, + bool trans_format) override; }; } // namespace opt diff --git a/mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc b/mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc index eb741a1e07b..a6adfa50770 100644 --- a/mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc +++ b/mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc @@ -15,10 +15,13 @@ */ #include "tools/optimizer/parallel/dynamic_creator.h" +#include "tools/optimizer/parallel/conv2d_info.h" namespace mindspore { namespace opt { // operator register +REGISTER(Conv2DInfo); +REGISTER(DepthwiseConv2DInfo); std::string GetDisOpName(const std::string &prim_name) { std::string op_name = prim_name; @@ -31,7 +34,7 @@ std::string GetDisOpName(const std::string &prim_name) { // create the OperatorInfo instance OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name, const SplitStrategy &strategy) { - if (type_name.length() == 0) { + if (type_name.empty()) { MS_LOG(EXCEPTION) << "Length of name is zero!"; } std::string distribute_opname = GetDisOpName(type_name); diff --git a/mindspore/lite/tools/optimizer/parallel/operator_info.cc b/mindspore/lite/tools/optimizer/parallel/operator_info.cc index f60603bc3d0..5c36b56f735 100644 --- a/mindspore/lite/tools/optimizer/parallel/operator_info.cc +++ b/mindspore/lite/tools/optimizer/parallel/operator_info.cc @@ -101,7 +101,10 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t } tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i)); outputs->push_back(tuple_getitem); - ptr_list.push_back(abstract_scalar); + auto type_id = static_cast(operator_type_id_); + auto type_ptr = TypeIdToType(type_id); + std::vector shape_vector; + ptr_list.push_back(std::make_shared(type_ptr, shape_vector)); } node->set_abstract(std::make_shared(ptr_list)); return lite::RET_OK; @@ -151,7 +154,8 @@ AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std: } concat_cnode->set_fullname_with_scope("Concat_" + name_); concat_cnode->set_scope(orig_node->scope()); - + std::vector outputs; + CreateMultipleOutputsOfAnfNode(concat_cnode, 1, &outputs); return concat_cnode; } diff --git a/mindspore/lite/tools/optimizer/parallel/operator_info.h b/mindspore/lite/tools/optimizer/parallel/operator_info.h index 2cc1e145063..e5b9c37a867 100644 --- a/mindspore/lite/tools/optimizer/parallel/operator_info.h +++ b/mindspore/lite/tools/optimizer/parallel/operator_info.h @@ -60,14 +60,14 @@ class OperatorInfo { void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } void setFmk(const int32_t fmk_type) { fmk_type_ = fmk_type; } - AnfNodePtr replace_op() { return replace_op_; } + AnfNodePtr replace_op() const { return replace_op_; } int Init(); protected: int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector *outputs); - AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector *split_outputs, - size_t split_dim, size_t split_num, const std::vector &splits, - bool trans_format); + virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, + std::vector *split_outputs, size_t split_dim, size_t split_num, + const std::vector &splits, bool trans_format); AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector &input_nodes, int32_t concat_dim, size_t input_nodes_num, bool trans_format); AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector &input_nodes, int32_t reduce_dim, @@ -84,6 +84,7 @@ class OperatorInfo { FuncGraphPtr func_graph_{nullptr}; CNodePtr cnode_{nullptr}; int32_t fmk_type_{}; + TypeId operator_type_id_ = kNumberTypeFloat32; private: int SetCNodeBackend(); diff --git a/mindspore/lite/tools/optimizer/parallel/parallel_pass.cc b/mindspore/lite/tools/optimizer/parallel/parallel_pass.cc index 4d9dabb913c..d15338ca66c 100644 --- a/mindspore/lite/tools/optimizer/parallel/parallel_pass.cc +++ b/mindspore/lite/tools/optimizer/parallel/parallel_pass.cc @@ -75,7 +75,7 @@ AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &n } operator_->set_cnode(cnode); operator_->set_func_graph(func_graph); - operator_->setFmk(FmkType_); + operator_->setFmk(fmk_type_); if (operator_->Init() == RET_ERROR) { MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed"; } diff --git a/mindspore/lite/tools/optimizer/parallel/parallel_pass.h b/mindspore/lite/tools/optimizer/parallel/parallel_pass.h index d8af45ad765..35579423e92 100644 --- a/mindspore/lite/tools/optimizer/parallel/parallel_pass.h +++ b/mindspore/lite/tools/optimizer/parallel/parallel_pass.h @@ -32,8 +32,8 @@ namespace mindspore { namespace opt { class ParallelPass : public opt::NodePass { public: - explicit ParallelPass(const std::unordered_map strategys, const int32_t FmkType) - : NodePass("parallel_pass"), split_strategys_(strategys), FmkType_(FmkType) {} + explicit ParallelPass(const std::unordered_map &strategys, const int32_t fmk_type) + : NodePass("parallel_pass"), split_strategys_(strategys), fmk_type_(fmk_type) {} ~ParallelPass() override = default; AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; @@ -46,7 +46,7 @@ class ParallelPass : public opt::NodePass { std::string type_name_; std::unordered_map split_strategys_; - int32_t FmkType_; + int32_t fmk_type_; }; } // namespace opt diff --git a/mindspore/lite/tools/optimizer/parallel/split_strategy.h b/mindspore/lite/tools/optimizer/parallel/split_strategy.h index 7b69cc3197c..c7a214708e4 100644 --- a/mindspore/lite/tools/optimizer/parallel/split_strategy.h +++ b/mindspore/lite/tools/optimizer/parallel/split_strategy.h @@ -39,11 +39,11 @@ const std::vector kSplitDevTypes = {"CPU", "GPU"}; using Strategys = std::vector>>; enum SplitMode { - SplitN = 0, - SplitH = 1, - SplitCIN = 2, - SplitCOUT = 3, - NoSplit = 4, + NoSplit = 0, + SplitN = 1, + SplitH = 2, + SplitCIN = 3, + SplitCOUT = 4, }; struct SplitStrategy {