From 5a5c498371e74e3aa0df60d21d1e97a8cf04e063 Mon Sep 17 00:00:00 2001 From: z00512249 Date: Wed, 12 May 2021 10:32:49 +0800 Subject: [PATCH] add multi conv parallel pass --- mindspore/lite/test/CMakeLists.txt | 7 +- mindspore/lite/tools/converter/CMakeLists.txt | 7 +- .../lite/tools/converter/anf_transform.cc | 11 +- .../lite/tools/converter/converter_flags.h | 2 +- .../tools/optimizer/fisson/fisson_util.cc | 371 ++++++++++++++---- .../lite/tools/optimizer/fisson/fisson_util.h | 38 +- .../optimizer/fisson/multi_conv_split_pass.cc | 59 +++ .../optimizer/fisson/multi_conv_split_pass.h | 51 +++ .../tools/optimizer/parallel/conv2d_info.cc | 14 - .../optimizer/parallel/multi_conv_info.cc | 218 ++++++++++ .../optimizer/parallel/multi_conv_info.h | 107 +++++ .../optimizer/parallel/multi_node_split.cc | 60 +++ .../optimizer/parallel/multi_node_split.h | 61 +++ .../tools/optimizer/parallel/split_strategy.h | 16 + 14 files changed, 908 insertions(+), 114 deletions(-) create mode 100644 mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.cc create mode 100644 mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h create mode 100644 mindspore/lite/tools/optimizer/parallel/multi_conv_info.cc create mode 100644 mindspore/lite/tools/optimizer/parallel/multi_conv_info.h create mode 100644 mindspore/lite/tools/optimizer/parallel/multi_node_split.cc create mode 100644 mindspore/lite/tools/optimizer/parallel/multi_node_split.h diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 6a29494e69b..0a5ac862769 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -287,12 +287,15 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fisson/fisson_util.cc ${LITE_DIR}/tools/optimizer/fisson/iter_node_outputs.cc ${LITE_DIR}/tools/optimizer/fisson/node_out_shapes.cc + ${LITE_DIR}/tools/optimizer/fisson/multi_conv_split_pass.cc + ${LITE_DIR}/tools/optimizer/parallel/multi_node_split.cc + 
${LITE_DIR}/tools/optimizer/parallel/multi_conv_info.cc ${LITE_DIR}/tools/optimizer/parallel/parallel_pass.cc ${LITE_DIR}/tools/optimizer/parallel/operator_info.cc - ${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc ${LITE_DIR}/tools/optimizer/parallel/operator_info_register.cc - ${LITE_DIR}/tools/optimizer/parallel/spliter.cc ${LITE_DIR}/tools/optimizer/parallel/conv2d_info.cc + ${LITE_DIR}/tools/optimizer/parallel/spliter.cc + ${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc ${LITE_DIR}/tools/common/graph_util.cc ${LITE_DIR}/tools/common/tensor_util.cc ${LITE_DIR}/tools/common/node_util.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 7d11dab9645..a864426782a 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -65,12 +65,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fisson/fisson_util.cc ../optimizer/fisson/iter_node_outputs.cc ../optimizer/fisson/node_out_shapes.cc + ../optimizer/fisson/multi_conv_split_pass.cc + ../optimizer/parallel/multi_node_split.cc + ../optimizer/parallel/multi_conv_info.cc + ../optimizer/parallel/parallel_pass.cc ../optimizer/parallel/conv2d_info.cc ../optimizer/parallel/operator_info.cc - ../optimizer/parallel/parallel_pass.cc - ../optimizer/parallel/split_strategy.cc ../optimizer/parallel/operator_info_register.cc ../optimizer/parallel/spliter.cc + ../optimizer/parallel/split_strategy.cc ../optimizer/graph/conv1d_inout_adjust_pass.cc ../optimizer/graph/weight_format_transform_pass.cc ../optimizer/graph/weight_format_hardcode_pass.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 82b1fea7b0d..7317ffa02dd 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -71,6 +71,7 @@ #include "tools/optimizer/fisson/node_out_shapes.h" 
#include "tools/optimizer/parallel/parallel_pass.h" #include "tools/converter/registry/pass_registry.h" +#include "tools/optimizer/fisson/multi_conv_split_pass.h" using std::string; namespace mindspore::lite { @@ -128,13 +129,12 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { MS_LOG(DEBUG) << "Run ParallelPass start"; - if (config->trainModel || static_cast(config->parallelMode) == opt::NoSplit) { + if (config->trainModel || !config->parallelMode) { return RET_OK; } auto optimizer = std::make_shared(); // 1. deal with split strategy - std::unordered_map split_strategys = - ParserSplitStrategy(static_cast(config->parallelMode)); + std::unordered_map split_strategys = ParserSplitStrategy(opt::SplitH); if (split_strategys.empty()) { MS_LOG(ERROR) << "parse split_strategy error."; return RET_OK; @@ -144,7 +144,10 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter parallel_pm->AddPass(std::make_shared()); parallel_pm->AddPass(std::make_shared()); // 3. multi_conv parallel pass - parallel_pm->AddPass(std::make_shared()); + auto strategy = split_strategys.begin()->second; + parallel_pm->AddPass( + std::make_shared(strategy, schema::PrimitiveType_Conv2DFusion, config->fmk, 3)); + parallel_pm->AddPass(std::make_shared()); // 4. 
single conv parallel pass parallel_pm->AddPass(std::make_shared(split_strategys, config->fmk)); optimizer->AddPassManager(parallel_pm); diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index a16269b5069..60b515407e8 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -80,7 +80,7 @@ class Flags : public virtual mindspore::lite::FlagParser { int quantWeightSize; std::string bitNumIn; int bitNum; - int parallelMode = 0; + bool parallelMode = false; std::string configFile; std::string quantWeightChannelStr; int quantWeightChannel; diff --git a/mindspore/lite/tools/optimizer/fisson/fisson_util.cc b/mindspore/lite/tools/optimizer/fisson/fisson_util.cc index 2d35328a373..d799e941eba 100644 --- a/mindspore/lite/tools/optimizer/fisson/fisson_util.cc +++ b/mindspore/lite/tools/optimizer/fisson/fisson_util.cc @@ -15,31 +15,282 @@ */ #include -#include #include #include "tools/optimizer/fisson/fisson_util.h" #include "base/core_ops.h" #include "src/common/utils.h" +#include "mindspore/core/ops/split_with_overlap.h" +#include "tools/common/node_util.h" +#include "tools/common/tensor_util.h" +#include "ops/concat.h" +#include "tools/optimizer/parallel/spliter.h" +#include "tools/optimizer/parallel/split_strategy.h" +using mindspore::lite::converter::FmkType; namespace mindspore { -using lite::converter::FmkType; - namespace opt { -AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, - const std::vector &conv_outputs, const SplitInfo &split_info, +namespace { + +bool CalSplitOutputShape(int32_t splited_axis_value, const SplitInfo *split_info, + std::vector *split_axis_out_shape, + std::vector *split_axis_reduce_out_shape) { + // ori ratio + int32_t split_num = split_info->size_splits.size(); + int32_t split_len = 0; + for (int32_t i = 0; i < split_num; i++) { + split_len += split_info->size_splits[i]; + } + if 
(split_len > splited_axis_value) { + return false; + } + // out-shape after splited + int32_t tmp_value = 0; + for (int32_t i = 0; i < split_num - 1; i++) { + int32_t tmp = (split_info->size_splits[i] * splited_axis_value) / split_len; + tmp_value += tmp; + split_axis_out_shape->push_back(tmp); + split_axis_reduce_out_shape->push_back(tmp_value); + } + split_axis_out_shape->push_back(splited_axis_value - tmp_value); + split_axis_reduce_out_shape->push_back(splited_axis_value); + return true; +} + +void CalSplitInShape(int32_t splited_axis_value, const SplitInfo *split_info, + const std::shared_ptr &ori_attr, int32_t idx_node, + std::vector> *split_axis_inputs_shape, + std::vector> *split_axis_reduce_inputs_shape) { + int32_t split_num = split_info->size_splits.size(); + int32_t tmp = 0; + std::vector split_axis_shape; + std::vector split_axis_reduce_shape; + + // iter splited_num + for (int32_t idx = 0; idx < split_num; idx++) { + // shape + if (split_info->axis == CuttingStragedy::CUT_H) { // H + if ((splited_axis_value + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_pad_list()[kPadDown] - + (ori_attr->get_kernel_size()[kAxisH] - 1)) % + ori_attr->get_stride()[kIndexH] == + 0) { + if (idx == 0) { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) + + (ori_attr->get_kernel_size()[kAxisH] - 1) - ori_attr->get_pad_list()[kPadUp]; + } else if (idx == split_num - 1) { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) + + (ori_attr->get_kernel_size()[kAxisH] - 1) - ori_attr->get_pad_list()[kPadDown]; + } else { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) + + (ori_attr->get_kernel_size()[kAxisH] - 1) - 0; + } + } else { + if (idx == 0) { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_kernel_size()[kAxisH]; + } else if (idx == split_num - 1) { + tmp = 
ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadDown] + ori_attr->get_kernel_size()[kAxisH]; + } else { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - 0 + + ori_attr->get_kernel_size()[kAxisH]; + } + } + + } else if (split_info->axis == CuttingStragedy::CUT_W) { // W + if (idx == 0) { + tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadLeft] + ori_attr->get_kernel_size()[kAxisW]; + } else if (idx == split_num - 1) { + tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadRight] + ori_attr->get_kernel_size()[kAxisW]; + } else { + tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - 0 + + ori_attr->get_kernel_size()[kAxisW]; + } + } + split_axis_shape.push_back(tmp); + + // reduce shape + if (split_info->axis == CuttingStragedy::CUT_H) { // H + if ((splited_axis_value + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_pad_list()[kPadDown] - + (ori_attr->get_kernel_size()[kAxisH] - 1)) % + ori_attr->get_stride()[kIndexH] == + 0) { + if (idx == split_num - 1) { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx]) + + ori_attr->get_kernel_size()[kAxisH] - 1 - ori_attr->get_pad_list()[kPadDown] - + ori_attr->get_pad_list()[kPadUp]; + } else { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx]) + + ori_attr->get_kernel_size()[kAxisH] - 1 - ori_attr->get_pad_list()[kPadUp]; + } + } else { + if (idx == split_num - 1) { + tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadDown] - ori_attr->get_pad_list()[kPadUp] + + ori_attr->get_kernel_size()[kAxisH]; + } else { + tmp = ori_attr->get_stride()[kIndexH] * 
((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_kernel_size()[kAxisH]; + } + } + } else if (split_info->axis == CuttingStragedy::CUT_W) { // W + if (idx == split_num - 1) { + tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadRight] - ori_attr->get_pad_list()[kPadLeft] + + ori_attr->get_kernel_size()[kAxisW]; + } else { + tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) - + ori_attr->get_pad_list()[kPadLeft] + ori_attr->get_kernel_size()[kAxisW]; + } + } + split_axis_reduce_shape.push_back(tmp); + } + split_axis_inputs_shape->push_back(split_axis_shape); + split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape); +} + +bool CheckPrim(const std::shared_ptr &ori_attr, int32_t splited_axis_value) { + return !(splited_axis_value == ori_attr->get_kernel_size()[kAxisH] && ori_attr->get_pad_list()[kPadUp] == 0 && + ori_attr->get_pad_list()[kPadDown] == 0); +} +} // namespace + +bool IsConv2D(const AnfNodePtr &node) { + return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion)); +} + +std::shared_ptr CopyConvPrim(const std::shared_ptr &ori_attr) { + auto prim = std::make_shared(); + prim->set_pad(ori_attr->get_pad()); + prim->set_in_channel(ori_attr->get_in_channel()); + prim->set_out_channel(ori_attr->get_out_channel()); + prim->set_dilation(ori_attr->get_dilation()); + prim->set_format(ori_attr->get_format()); + prim->set_group(ori_attr->get_group()); + prim->set_kernel_size(ori_attr->get_kernel_size()); + prim->set_pad_mode(ori_attr->get_pad_mode()); + prim->set_pad_list(ori_attr->get_pad_list()); + prim->set_stride(ori_attr->get_stride()); + prim->set_activation_type(ori_attr->get_activation_type()); + prim->set_pad_list(prim->get_pad_list()); + return prim; +} + +bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector 
&conv_nodes, SplitInfo *split_info) { + if (split_info->axis != CuttingStragedy::CUT_H) { + return false; + } + auto splited_axis = split_info->axis; + if (split_info->fmk_type == FmkType::FmkType_CAFFE || + split_info->fmk_type == FmkType::FmkType_ONNX) { // NHWC -> NCHW + splited_axis += 1; + } + + const int32_t node_size = conv_nodes.size(); + int32_t idx_node = 0; + std::vector> node_in_out_shapes; + while (idx_node < node_size) { + // [conv3, conv2, conv1] conv1->conv2->conv3 + auto out_node_name = conv_nodes[idx_node]->fullname_with_scope(); + auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name]; + auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name]; + // 0-> in-shape 1->out-shape + // only one in and one output + node_in_out_shapes.push_back({output_shapes.front(), input_shapes.front()}); + idx_node++; + } + + const int32_t splited_axis_value = node_in_out_shapes[0][1][splited_axis]; + int32_t split_num = split_info->size_splits.size(); + std::vector split_axis_out_shape; + std::vector split_axis_reduce_out_shape; + if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) { + return false; + } + // infer in-shape after splited + std::vector> split_axis_inputs_shape{split_axis_out_shape}; + std::vector> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape}; + idx_node = 0; + // iter node + while (idx_node < node_size) { + auto conv_cnode = conv_nodes[idx_node]->cast(); + auto ori_attr = GetValueNode>(conv_cnode->input(kAnfPrimitiveIndex)); + if (!CheckPrim(ori_attr, splited_axis_value)) { + return false; + } + CalSplitInShape(splited_axis_value, split_info, ori_attr, idx_node, &split_axis_inputs_shape, + &split_axis_reduce_inputs_shape); + idx_node++; + } + + // update ratio + split_info->size_splits.clear(); + split_info->extend_top.clear(); + split_info->extend_bottom.clear(); + + int32_t top = 0; + int32_t bottom = 0; + 
split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]); + split_info->extend_top.push_back(top); + split_info->extend_bottom.push_back(bottom); + + for (int32_t i = 1; i < split_num; i++) { + auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1; + top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1; + auto value = split_axis_inputs_shape[node_size][i] - top; + split_info->size_splits.push_back(value); + split_info->extend_top.push_back(top); + split_info->extend_bottom.push_back(bottom); + } + return true; +} + +void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num, + std::vector *outputs) { + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + return; + } + auto cnode = node->cast(); + if (CheckIfCNodeIsNull(cnode)) { + return; + } + for (size_t i = 0; i < output_num; i++) { + auto idx = NewValueNode(SizeToInt(i)); + if (CheckIfValueNodeIsNull(idx)) { + return; + } + size_t temp = SizeToInt(i); + auto imm = std::make_shared(temp); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + if (CheckIfCNodeIsNull(tuple_getitem)) { + return; + } + tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1)); + outputs->push_back(tuple_getitem); + } +} + +AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode, + const std::vector &conv_outputs, SplitInfo *split_info, const std::string &node_name) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(conv_cnode); int32_t nodes_num = conv_outputs.size(); - if (nodes_num != split_info.out_num) { + if (nodes_num != static_cast(split_info->out_num)) { MS_LOG(ERROR) << "Conv outputs has wrong input size"; return nullptr; } 
+ auto concat_prim = std::make_shared(); + concat_prim->set_axis(split_info->axis); + // the inputs of concate are from the outputs of conv - std::vector concate_inputs = {NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + std::vector concate_inputs = {NewValueNode(concat_prim)}; for (int32_t i = 0; i < nodes_num; i++) { concate_inputs.push_back(conv_outputs[i]); } @@ -49,78 +300,52 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr concate_cnode->set_fullname_with_scope(node_name + "_Concat"); concate_cnode->set_scope(conv_cnode->scope()); - + std::vector outputs; + GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs); return concate_cnode; } -int32_t GetCOutAxis(int32_t format) { - switch (format) { - case schema::Format_KHWC: - return 0; - case schema::Format_CHWK: - return 3; - case schema::Format_NCHW: - return 0; - default: - MS_LOG(ERROR) << "Do not support format: " << format << " now."; - return -1; - } -} - -int32_t GetCInAxis(int32_t format) { - switch (format) { - case schema::Format_KHWC: - return 3; - case schema::Format_CHWK: - return 0; - default: - MS_LOG(ERROR) << "Do not support format: " << format << " now."; - return -1; - } -} - -int32_t GetAxis(int32_t axis, int32_t format, const SplitInfo &split_info) { - switch (split_info.primitive_type) { - case mindspore::schema::PrimitiveType_Conv2DFusion: - if (axis == CuttingStragedy::CUT_C_OUT) { - return GetCOutAxis(format); - } else if (axis == CuttingStragedy::CUT_C_IN) { - return GetCInAxis(format); - } else { - MS_LOG(ERROR) << "Only channel_in and channel_out need to transform."; - } - break; - default: - MS_LOG(ERROR) << "Now, do not support the type : " << split_info.primitive_type; - } - return -1; -} - -AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, - const std::vector &conv_outputs, const SplitInfo &split_info, - const std::string &node_name) { +void CreateOutputsOfSplitWithOverlap(const 
FuncGraphPtr &func_graph, const AnfNodePtr &conv_node, + std::vector *split_outputs, SplitInfo *split_info, + const std::string &node_name) { MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(conv_cnode); + MS_EXCEPTION_IF_NULL(conv_node); + // attr of split + auto split_prim = std::make_shared(); + split_prim->set_split_dim(split_info->axis); + split_prim->set_number_split(split_info->out_num); + split_prim->set_ratio(split_info->size_splits); + split_prim->set_extend_top(split_info->extend_top); + split_prim->set_extend_bottom(split_info->extend_bottom); + // default to format khwc or nhwc + split_prim->set_trans_format(true); - int32_t nodes_num = conv_outputs.size(); - if (nodes_num != split_info.out_num) { - MS_LOG(ERROR) << "Conv outputs has wrong input size"; - return nullptr; + // the inputs of split is from the inputs of conv + std::vector split_inputs = {NewValueNode(split_prim)}; + auto conv_cnode = conv_node->cast(); + + // this conv only has one input, which has been ensured before + split_inputs.push_back(conv_cnode->input(1)); + + auto split_cnode = func_graph->NewCNode(split_inputs); + MS_EXCEPTION_IF_NULL(split_cnode); + + split_cnode->set_fullname_with_scope(node_name + "_Split"); + // create outputs op split + GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info->out_num, split_outputs); + + AbstractBasePtrList ptr_list; + for (size_t i = 0; i < split_info->out_num; i++) { + auto node = (*split_outputs)[i]; + // set date_type same with weight + auto type_id = static_cast(kNumberTypeFloat32); + auto type_ptr = TypeIdToType(type_id); + std::vector shape_vector; + auto value_node = std::make_shared(type_ptr, shape_vector); + ptr_list.push_back(value_node); } - - // the inputs of addn are from the outputs of conv - std::vector addn_inputs = {NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; - for (int32_t i = 0; i < nodes_num; i++) { - addn_inputs.push_back(conv_outputs[i]); - } - - auto addn_cnode = 
func_graph->NewCNode(addn_inputs); - MS_EXCEPTION_IF_NULL(addn_cnode); - - addn_cnode->set_fullname_with_scope(node_name + "_AddN"); - addn_cnode->set_scope(conv_cnode->scope()); - - return addn_cnode; + split_cnode->set_abstract(std::make_shared(ptr_list)); } + } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fisson/fisson_util.h b/mindspore/lite/tools/optimizer/fisson/fisson_util.h index e6080406511..532d71c4735 100644 --- a/mindspore/lite/tools/optimizer/fisson/fisson_util.h +++ b/mindspore/lite/tools/optimizer/fisson/fisson_util.h @@ -20,47 +20,49 @@ #include #include #include +#include #include "schema/inner/model_generated.h" #include "mindspore/ccsrc/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/converter_flags.h" #include "mindspore/lite/include/context.h" #include "mindspore/lite/include/lite_types.h" +#include "ops/fusion/conv2d_fusion.h" namespace mindspore { using mindspore::schema::PrimitiveType; namespace opt { struct SplitInfo { - int32_t axis; - int32_t out_num; - std::vector size_splits; - std::vector extend_top; - std::vector extend_bottom; + int64_t axis; + size_t out_num; + std::vector size_splits; + std::vector extend_top; + std::vector extend_bottom; std::vector dev_types; - int32_t in_num_conv; - int32_t fmk_type; - std::vector weight_channel; + int64_t in_num_conv; + int64_t fmk_type; + std::vector weight_channel; PrimitiveType primitive_type; }; typedef enum { CUT_N, CUT_H, CUT_W, CUT_C_IN, CUT_C_OUT, CUT_NONE } CuttingStragedy; +bool IsConv2D(const AnfNodePtr &node); + +std::shared_ptr CopyConvPrim(const std::shared_ptr &ori_attr); + +bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector &conv_nodes, SplitInfo *split_info); + void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num, std::vector *outputs); -AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr 
&conv_cnode, - const std::vector &conv_outputs, const SplitInfo &split_info, +AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode, + const std::vector &conv_outputs, SplitInfo *split_info, const std::string &node_name); -void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, - std::vector *split_outputs, const SplitInfo &split_info, +void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode, + std::vector *split_outputs, SplitInfo *split_info, const std::string &node_name); - -void GetCNodeShapeInfo(const FuncGraphPtr &func_graph, int32_t fmk_type); - -AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, - const std::vector &conv_outputs, const SplitInfo &split_info, - const std::string &node_name); } // namespace opt } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_ diff --git a/mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.cc b/mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.cc new file mode 100644 index 00000000000..68d8e17d8e7 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include "mindspore/ccsrc/utils/utils.h" +#include "mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "mindspore/core/base/base.h" +#include "mindspore/core/ops/fusion/conv2d_fusion.h" + +using mindspore::lite::converter::FmkType; +using mindspore::schema::PrimitiveType_Conv2dTransposeFusion; +namespace mindspore { +namespace opt { +const BaseRef MultiConvSplitPass::DefinePattern() const { + auto conv1_var = std::make_shared(IsConvNode); + auto conv1_other_var = std::make_shared(); + VectorRef res = VectorRef({conv1_var, conv1_other_var}); + int32_t idx = 1; + while (idx < num_) { + auto tmp_var = std::make_shared(IsConvNode); + auto tmp_other_var = std::make_shared(); + res = VectorRef({tmp_var, res, tmp_other_var}); + idx++; + } + return res; +} + +const AnfNodePtr MultiConvSplitPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(INFO) << "---Enter pass MultiConvSplit."; + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType); + auto device_type = (device_type_attr != nullptr) ? 
GetValue(device_type_attr) : kDeviceTypeNone; + if (device_type != kDeviceTypeNone) { + return node; + } + std::shared_ptr multi_node_split_proxy = + std::make_shared(strategy_, primitive_type_, fmk_type_, num_); + return multi_node_split_proxy->DoSplit(func_graph, node); +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h b/mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h new file mode 100644 index 00000000000..922de3becbb --- /dev/null +++ b/mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "tools/optimizer/fisson/fisson_util.h" +#include "tools/optimizer/parallel/split_strategy.h" +#include "schema/inner/model_generated.h" +#include "tools/optimizer/parallel/multi_node_split.h" + +using mindspore::schema::PrimitiveType; +namespace mindspore { +namespace opt { +class MultiConvSplitPass : public PatternProcessPass { + public: + explicit MultiConvSplitPass(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, + int32_t num = 3, bool multigraph = true) + : PatternProcessPass("multi_conv_split", multigraph), + strategy_(strategy), + primitive_type_(primitive_type), + fmk_type_(fmk_type), + num_(num) {} + ~MultiConvSplitPass() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + SplitStrategy strategy_{}; + PrimitiveType primitive_type_{schema::PrimitiveType_NONE}; + int32_t fmk_type_{-1}; + int32_t num_{0}; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_ diff --git a/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc b/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc index a1c7db1f2b3..640a3e551d2 100644 --- a/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc +++ b/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc @@ -30,20 +30,6 @@ using mindspore::schema::PrimitiveType_Conv2DFusion; namespace mindspore { namespace opt { -// strategy format is NHWC-KHWC -constexpr int32_t kAxisN = 0; -constexpr int32_t kAxisCIn = 3; -constexpr int32_t kAxisCOut = 0; -constexpr int32_t kAxisH = 1; -constexpr int32_t kAxisW = 2; - -constexpr auto kIndexH = 0; -constexpr auto kIndexW = 1; - -constexpr auto kPadUp = 
0; -constexpr auto kPadDown = 1; -constexpr auto kPadLeft = 2; -constexpr auto kPadRight = 3; int Conv2DInfo::GetAttrs() { return lite::RET_OK; } diff --git a/mindspore/lite/tools/optimizer/parallel/multi_conv_info.cc b/mindspore/lite/tools/optimizer/parallel/multi_conv_info.cc new file mode 100644 index 00000000000..9bcf55599c5 --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/multi_conv_info.cc @@ -0,0 +1,218 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "tools/optimizer/parallel/multi_conv_info.h"
+#include
+#include
+#include "tools/optimizer/parallel/spliter.h"
+#include "ops/fusion/conv2d_fusion.h"
+#include "tools/optimizer/parallel/split_strategy.h"
+
+using mindspore::lite::converter::FmkType;
+using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
+namespace mindspore {
+namespace opt {
+// Fills split_info_ from strategy_ and the constructor arguments.
+// - out_num is taken from strategy_.dev_num.
+// - Each entry of strategy_.dev_types ("CPU", "GPU" or "NPU") is mapped to the matching lite::DeviceType.
+// - axis / size_splits come from the first entry of strategy_.strategys[0] that is not all zeros.
+// Returns RET_ERROR for an empty strategy or an unknown device type string, RET_OK otherwise.
+int MultiConvSplit::GenSplitInfo() {
+  // strategys[0] is indexed below; reject an empty strategy instead of reading out of range.
+  if (this->strategy_.strategys.empty()) {
+    MS_LOG(ERROR) << "Split strategy is empty.";
+    return RET_ERROR;
+  }
+  split_info_.out_num = this->strategy_.dev_num;
+  for (const auto &dev_type : this->strategy_.dev_types) {
+    if (dev_type == "CPU") {
+      split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_CPU);
+    } else if (dev_type == "GPU") {
+      split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_GPU);
+    } else if (dev_type == "NPU") {
+      split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_NPU);
+    } else {
+      MS_LOG(ERROR) << "Do not support DeviceType:" << dev_type << " now.";
+      return RET_ERROR;
+    }
+  }
+  // An all-zero entry means "this axis is not split"; the first non-zero entry selects the split axis.
+  std::vector tmp(split_info_.out_num, 0);
+  for (size_t i = 0; i < this->strategy_.strategys[0].size(); i++) {
+    if (this->strategy_.strategys[0][i] == tmp) {
+      continue;
+    }
+    split_info_.axis = i;  // NHWC
+    split_info_.size_splits.clear();
+    split_info_.size_splits = this->strategy_.strategys[0][i];  // cal base on compute_cap
+    break;
+  }
+  split_info_.in_num_conv = num_;
+  split_info_.fmk_type = fmk_type_;
+  split_info_.extend_bottom = std::vector(split_info_.size_splits.size(), 0);
+  split_info_.extend_top = std::vector(split_info_.size_splits.size(), 0);
+  split_info_.primitive_type = primitive_type_;
+  return RET_OK;
+}
+
+// Collects into conv_nodes_ the chain of conv nodes that ends at conv_node.
+// conv_nodes_[0] is conv_node itself; each following entry is input(1) of the previous one, so the vector
+// is ordered from the last conv in the graph back towards the first. At most split_info_.in_num_conv nodes
+// are collected.
+// Returns RET_ERROR when a node has no entry in the Spliter output table, when a chain node is not a
+// CNode, or when fewer than two nodes were collected; RET_OK otherwise.
+int MultiConvSplit::GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(conv_node);
+  // get nodes to be split
+  // node in graph 1->2->3...
+  // node in vector ...->3->2->1
+  std::string conv_cnode_name = conv_node->fullname_with_scope();
+  MS_LOG(INFO) << "---node name:" << conv_cnode_name;
+  auto graph_node_outputs = Spliter::GetInstance()->graph_node_outputs();
+  auto it = graph_node_outputs.find(conv_cnode_name);
+  if (it == graph_node_outputs.end()) {
+    MS_LOG(INFO) << "This node may be the last node of graph,it do not has any out-nodes.";
+    return RET_ERROR;
+  }
+  conv_nodes_.push_back(conv_node);
+  int32_t idx = 0;
+  while (idx < split_info_.in_num_conv - 1) {
+    auto curr_cnode = conv_nodes_[idx]->cast();
+    if (curr_cnode == nullptr) {
+      MS_LOG(ERROR) << "Node in conv chain is not a CNode.";
+      return RET_ERROR;
+    }
+    auto tmp_node = curr_cnode->input(1);
+    // Every collected node is rebuilt as a conv in SplitSingleConv, so the chain has to stop at the first
+    // producer that is NOT a conv (the previous check broke on conv producers instead).
+    if (!IsConv2D(tmp_node)) {
+      break;
+    }
+    auto name = tmp_node->fullname_with_scope();
+    // stop when the producer has more than one consumer
+    it = graph_node_outputs.find(name);
+    if (it == graph_node_outputs.end()) {
+      return RET_ERROR;
+    }
+    if (it->second.size() > 1) {
+      break;
+    }
+    conv_nodes_.push_back(tmp_node);
+    idx++;
+  }
+
+  // a single conv is not handled by the multi-node pass
+  if (conv_nodes_.size() < 2) {
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Rewrites the collected chain as: split -> per-device copies of every conv -> concat.
+// The split is created in front of the last entry of conv_nodes_ (the first conv in graph order); the
+// convs are then re-created from the back of conv_nodes_ to the front, each stage consuming the outputs
+// of the previous one. Returns the concat node built from the final stage.
+AnfNodePtr MultiConvSplit::MultiConvNHSplit(const AnfNodePtr &node) {
+  std::string conv_cnode_name = node->fullname_with_scope();
+  // Create Split node and get outputs of Split
+  std::vector split_outputs;
+  CreateOutputsOfSplitWithOverlap(func_graph_, conv_nodes_[conv_nodes_.size() - 1], &split_outputs, &split_info_,
+                                  conv_cnode_name);
+  // Create Conv node
+  for (int32_t i = conv_nodes_.size() - 1; i >= 0; i--) {
+    std::vector outputs_node;
+    SplitSingleConv(conv_nodes_[i], split_outputs, {}, {}, &outputs_node);
+    // outputs of this stage become the inputs of the next conv in the chain
+    split_outputs.clear();
+    std::copy(outputs_node.begin(), outputs_node.end(), std::back_inserter(split_outputs));
+    outputs_node.clear();
+  }
+  // Create concat node
+  auto concat_node = CreateOutputsOfConcat(func_graph_, node, split_outputs, &split_info_, conv_cnode_name);
+  split_outputs.clear();
+  return concat_node;
+}
+
+// Creates split_info_.out_num copies of ori_node, one per output slice.
+// For slice k the primitive is copied and adjusted via AdJustConvPrim, the inputs are assembled via
+// AdJustInputs (inputs_node[k] is the feature map), and the new node's output is appended to outputs_node.
+void MultiConvSplit::SplitSingleConv(const AnfNodePtr &ori_node, const std::vector &inputs_node,
+                                     const std::vector &weight_node,
+                                     const std::vector &bias_nodes, std::vector *outputs_node) {
+  auto ori_conv_cnode = ori_node->cast();
+  auto ori_attr = GetValueNode>(ori_conv_cnode->input(kAnfPrimitiveIndex));
+  for (int32_t output_conv_index = 0; output_conv_index < static_cast(split_info_.out_num);
+       output_conv_index++) {
+    // Create Conv node attr
+    auto conv_prim = CopyConvPrim(ori_attr);
+    // adjust primitive
+    AdJustConvPrim(conv_prim, output_conv_index);
+    // node inputs
+    std::vector conv_inputs;
+    conv_inputs.push_back(NewValueNode(conv_prim));
+    AdJustInputs(ori_node, inputs_node, weight_node, bias_nodes, output_conv_index, &conv_inputs);
+    // create new conv node
+    CreateNewConvNode(ori_node, conv_inputs, output_conv_index, outputs_node);
+  }
+}
+
+// Appends the inputs of one split conv to conv_inputs: first new_inputs_node[output_conv_index] as the
+// feature map, then every input of the original conv from index 2 on (weight and bias), unchanged.
+// weight_node and bias_nodes are not read by this implementation.
+void MultiConvSplit::AdJustInputs(const AnfNodePtr &ori_conv_node, const std::vector &new_inputs_node,
+                                  const std::vector &weight_node, const std::vector &bias_nodes,
+                                  int output_conv_index, std::vector *conv_inputs) {
+  auto ori_conv_cnode = ori_conv_node->cast();
+  // feature_map
+  conv_inputs->push_back(new_inputs_node[output_conv_index]);
+  // W+bias
+  for (size_t j = 2; j < ori_conv_cnode->size(); j++) {
+    conv_inputs->push_back(ori_conv_cnode->input(j));
+  }
+}
+
+// Creates the conv CNode for slice output_conv_index from conv_inputs.
+// The node is named "<original name>_" + PARALLEL_NAME_SUFFIX + (output_conv_index + 1) and tagged with
+// the device type stored in split_info_.dev_types[output_conv_index]; its output is appended to
+// outputs_node.
+void MultiConvSplit::CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector &conv_inputs,
+                                       int output_conv_index, std::vector *outputs_node) {
+  auto ori_conv_cnode = ori_conv_node->cast();
+  std::string ori_cnode_name = ori_conv_cnode->fullname_with_scope();
+  // new conv_node
+  auto conv_cnode = func_graph_->NewCNode(conv_inputs);
+  conv_cnode->set_fullname_with_scope(ori_cnode_name + "_" + PARALLEL_NAME_SUFFIX +
+                                      std::to_string(output_conv_index + 1));
+  conv_cnode->AddAttr(mindspore::ops::kDeviceType,
+                      MakeValue(static_cast(split_info_.dev_types[output_conv_index])));
+  std::vector tmp_outputs;
+  // conv2d only has one output, set to output_nodes
+  GetMultipleOutputsOfAnfNode(func_graph_, conv_cnode, 1, &tmp_outputs);
+  outputs_node->push_back(tmp_outputs[0]->cast()->input(1));
+  tmp_outputs.clear();
+}
+
+// Entry point: builds the split info, collects the conv chain ending at node, then delegates to the
+// subclass' SplitMultiConv. Returns node unchanged when either preparation step fails.
+AnfNodePtr MultiConvSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  int ret = GenSplitInfo();
+  if (ret != RET_OK) {
+    return node;
+  }
+  func_graph_ = func_graph;
+  ret = GetMultiConvNodes(func_graph, node);
+  if (ret != RET_OK) {
+    return node;
+  }
+  return SplitMultiConv(node);
+}
+
+// Split along N: a chain of exactly two nodes cut on CUT_N is left untouched, anything else goes through
+// MultiConvNHSplit.
+AnfNodePtr MultiConvSplitN::SplitMultiConv(const AnfNodePtr &node) {
+  if (conv_nodes_.size() == 2 && split_info_.axis == CuttingStragedy::CUT_N) {
+    return node;
+  }
+  return MultiConvNHSplit(node);
+}
+
+// Split along H: split_info_ is first refreshed by UpdateSplitInfo; node is returned unchanged when that
+// fails.
+AnfNodePtr MultiConvSplitH::SplitMultiConv(const AnfNodePtr &node) {
+  // update info, N do not need, C do not support
+  if (!UpdateSplitInfo(func_graph_, conv_nodes_, &split_info_)) {
+    return node;
+  }
+  return MultiConvNHSplit(node);
+}
+
+// Clears the vertical padding on the cut edges of slice output_conv_index: the first slice drops its
+// bottom pad, the last slice drops its top pad, slices in between drop both.
+void MultiConvSplitH::AdJustConvPrim(const std::shared_ptr &conv_prim, int output_conv_index) {
+  auto pad_list = conv_prim->get_pad_list();
+  if (output_conv_index == 0) {
+    pad_list[kPadDown] = 0;
+  } else if (output_conv_index == static_cast(split_info_.out_num - 1)) {
+    pad_list[kPadUp] = 0;
+  } else {
+    pad_list[kPadUp] = 0;
+    pad_list[kPadDown] = 0;
+  }
+  conv_prim->set_pad_list(pad_list);
+}
+
+// Splitting along the input-channel axis is not implemented; always returns nullptr.
+AnfNodePtr MultiConvSplitCIN::SplitMultiConv(const AnfNodePtr &node) { return nullptr; }
+
+// Splitting along the output-channel axis is not implemented; always returns nullptr.
+AnfNodePtr MultiConvSplitCOUT::SplitMultiConv(const AnfNodePtr &node) { return nullptr; }
+
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/parallel/multi_conv_info.h b/mindspore/lite/tools/optimizer/parallel/multi_conv_info.h
new file mode 100644
index 00000000000..1fae785ddf6
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/parallel/multi_conv_info.h
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file
except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H
+#include
+#include
+#include "tools/optimizer/parallel/multi_node_split.h"
+#include "tools/optimizer/fisson/fisson_util.h"
+#include "ops/fusion/conv2d_fusion.h"
+namespace mindspore {
+namespace opt {
+// Base class for splitting a chain of nodes that ends at a conv into per-device slices.
+// DoSplit() fills split_info_, collects the chain into conv_nodes_ and then calls SplitMultiConv(), which
+// each subclass implements for one split axis.
+class MultiConvSplit : public MultiNodeSplit {
+ public:
+  // strategy:       split strategy (device count, device type names, per-axis split sizes).
+  // primitive_type: stored and copied into split_info_.primitive_type.
+  // fmk_type:       stored and copied into split_info_.fmk_type.
+  // num:            copied into split_info_.in_num_conv; bounds how many chain nodes are collected.
+  explicit MultiConvSplit(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
+                          int32_t num = 3)
+      : MultiNodeSplit(), strategy_(strategy), primitive_type_(primitive_type), fmk_type_(fmk_type), num_(num) {}
+
+  // Returns the result of SplitMultiConv(node), or node itself when the split info or the node chain
+  // cannot be built.
+  AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
+
+  // Axis-specific split; called by DoSplit() once split_info_ and conv_nodes_ are populated.
+  virtual AnfNodePtr SplitMultiConv(const AnfNodePtr &node) = 0;
+
+  // Hook to adjust the copied conv primitive of slice output_conv_index before its node is created.
+  virtual void AdJustConvPrim(const std::shared_ptr &ori_attr, int output_conv_index) = 0;
+
+  // Shared N/H implementation: split node -> per-slice copies of every collected node -> concat node.
+  virtual AnfNodePtr MultiConvNHSplit(const AnfNodePtr &node);
+
+  // Appends the inputs of slice output_conv_index to conv_inputs.
+  virtual void AdJustInputs(const AnfNodePtr &ori_node, const std::vector &new_inputs_node,
+                            const std::vector &weight_node, const std::vector &bias_nodes,
+                            int output_conv_index, std::vector *conv_inputs);
+
+  // Creates the CNode of slice output_conv_index and appends its output to outputs_node.
+  virtual void CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector &conv_inputs,
+                                 int output_conv_index, std::vector *outputs_node);
+
+  // Creates one copy of ori_node per output slice; inputs_node[k] feeds slice k.
+  virtual void SplitSingleConv(const AnfNodePtr &ori_node, const std::vector &inputs_node,
+                               const std::vector &weight_node, const std::vector &bias_nodes,
+                               std::vector *outputs_node);
+
+ protected:
+  FuncGraphPtr func_graph_{nullptr};  // graph being rewritten; set by DoSplit()
+  SplitInfo split_info_{};            // filled by GenSplitInfo()
+  SplitStrategy strategy_{};
+  PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
+  int32_t fmk_type_{-1};
+  int32_t num_{0};
+  // Collected chain: index 0 is the node passed to DoSplit(), later entries are successive input(1) nodes.
+  std::vector conv_nodes_{};
+
+ private:
+  // Builds split_info_ from strategy_; returns RET_OK or RET_ERROR.
+  int GenSplitInfo();
+  // Fills conv_nodes_ starting at conv_node; returns RET_ERROR when fewer than two nodes are found.
+  int GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node);
+};
+
+// Split along the batch (N) axis. The conv primitive needs no adjustment.
+class MultiConvSplitN final : public MultiConvSplit {
+ public:
+  MultiConvSplitN(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
+      : MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
+
+  AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
+
+  void AdJustConvPrim(const std::shared_ptr &ori_attr, int output_conv_index) override {}
+};
+
+// Split along the input-channel axis. SplitMultiConv() is a stub that returns nullptr.
+class MultiConvSplitCIN final : public MultiConvSplit {
+ public:
+  MultiConvSplitCIN(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
+      : MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
+
+  AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
+
+  void AdJustConvPrim(const std::shared_ptr &ori_attr, int output_conv_index) override {}
+};
+
+// Split along the output-channel axis. SplitMultiConv() is a stub that returns nullptr.
+class MultiConvSplitCOUT final : public MultiConvSplit {
+ public:
+  MultiConvSplitCOUT(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
+                     int32_t num = 3)
+      : MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
+
+  AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
+
+  void AdJustConvPrim(const std::shared_ptr &ori_attr, int output_conv_index) override {}
+};
+
+// Split along the height (H) axis. AdJustConvPrim() zeroes the top/bottom padding on the cut edges.
+class MultiConvSplitH final : public MultiConvSplit {
+ public:
+  MultiConvSplitH(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
+      : MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
+
+  AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
+
+  void AdJustConvPrim(const std::shared_ptr &ori_attr, int output_conv_index) override;
+};
+
+}  //
namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H diff --git a/mindspore/lite/tools/optimizer/parallel/multi_node_split.cc b/mindspore/lite/tools/optimizer/parallel/multi_node_split.cc new file mode 100644 index 00000000000..ebf1d5f5537 --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/multi_node_split.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "tools/optimizer/parallel/multi_node_split.h"
+#include "tools/optimizer/parallel/multi_conv_info.h"
+namespace mindspore {
+namespace opt {
+
+// Creates the concrete splitter for split_mode_ and stores it in multi_node_split_.
+// Returns RET_OK for SplitN / SplitH / SplitCIN / SplitCOUT and RET_ERROR for any other mode.
+// NOTE(review): split_mode_ is declared with the initializer NoSplit and nothing in this file assigns it,
+// so as written the default branch looks like the only reachable one - confirm where the mode is meant to
+// come from.
+int MultiNodeSplitProxy::InitResource() {
+  switch (split_mode_) {
+    case SplitN:
+      multi_node_split_ = std::make_shared(strategy_, primitive_type_, fmk_type_, num_);
+      return RET_OK;
+    case SplitH:
+      multi_node_split_ = std::make_shared(strategy_, primitive_type_, fmk_type_, num_);
+      return RET_OK;
+    case SplitCIN:
+      multi_node_split_ = std::make_shared(strategy_, primitive_type_, fmk_type_, num_);
+      return RET_OK;
+    case SplitCOUT:
+      multi_node_split_ = std::make_shared(strategy_, primitive_type_, fmk_type_, num_);
+      return RET_OK;
+    default:
+      return RET_ERROR;
+  }
+}
+
+// Drops the splitter created by InitResource(). Always returns RET_OK.
+int MultiNodeSplitProxy::FreeResource() {
+  multi_node_split_ = nullptr;
+  return RET_OK;
+}
+
+// Creates a splitter, forwards the call to it and releases it again.
+// Returns node unchanged when the splitter cannot be created, otherwise whatever the splitter returned.
+AnfNodePtr MultiNodeSplitProxy::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  int ret = InitResource();
+  if (ret != RET_OK) {
+    return node;
+  }
+  auto res_node = multi_node_split_->DoSplit(func_graph, node);
+  ret = FreeResource();
+  if (ret != RET_OK) {
+    return node;
+  }
+  return res_node;
+}
+
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/parallel/multi_node_split.h b/mindspore/lite/tools/optimizer/parallel/multi_node_split.h
new file mode 100644
index 00000000000..606bf587644
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/parallel/multi_node_split.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
+#include
+#include
+#include "tools/optimizer/parallel/split_strategy.h"
+#include "schema/inner/model_generated.h"
+#include "include/errorcode.h"
+#include "base/base.h"
+
+using mindspore::schema::PrimitiveType;
+namespace mindspore {
+namespace opt {
+// Interface of a pass helper that replaces a group of nodes ending at `node` with a split version.
+class MultiNodeSplit {
+ public:
+  MultiNodeSplit() = default;
+
+  virtual ~MultiNodeSplit() = default;
+
+  // Returns the node that should stand in for `node` after splitting.
+  virtual AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
+};
+
+// Proxy that creates the concrete splitter for the configured split mode on every DoSplit() call,
+// forwards the call to it and releases it afterwards.
+class MultiNodeSplitProxy : public MultiNodeSplit {
+ public:
+  // The arguments are stored and handed on to the concrete splitter.
+  // NOTE(review): split_mode_ is not part of the initializer list, so it keeps its NoSplit default here.
+  explicit MultiNodeSplitProxy(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
+                               int32_t num = 3)
+      : MultiNodeSplit(), strategy_(strategy), primitive_type_(primitive_type), fmk_type_(fmk_type), num_(num) {}
+
+  ~MultiNodeSplitProxy() override = default;
+
+  AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
+
+ private:
+  // Creates / drops multi_node_split_; both return RET_OK or RET_ERROR.
+  int InitResource();
+  int FreeResource();
+
+ private:
+  SplitMode split_mode_{NoSplit};  // selects the concrete splitter in InitResource()
+  SplitStrategy strategy_{};
+  PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
+  int32_t fmk_type_{-1};
+  int32_t num_{0};
+  std::shared_ptr multi_node_split_{nullptr};  // only set between InitResource() and FreeResource()
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
diff --git a/mindspore/lite/tools/optimizer/parallel/split_strategy.h b/mindspore/lite/tools/optimizer/parallel/split_strategy.h
index c7a214708e4..d5afa5012a4 100644
--- a/mindspore/lite/tools/optimizer/parallel/split_strategy.h
+++ b/mindspore/lite/tools/optimizer/parallel/split_strategy.h
@@ -38,6 +38,22 @@ const std::vector kSplitDevTypes = {"CPU", "GPU"};
 
 using Strategys = std::vector>>;
 
+constexpr auto kDeviceTypeNone = -1;
+// 
strategy format is NHWC-KHWC +constexpr int32_t kAxisN = 0; +constexpr int32_t kAxisCIn = 3; +constexpr int32_t kAxisCOut = 0; +constexpr int32_t kAxisH = 1; +constexpr int32_t kAxisW = 2; + +constexpr auto kIndexH = 0; +constexpr auto kIndexW = 1; + +constexpr auto kPadUp = 0; +constexpr auto kPadDown = 1; +constexpr auto kPadLeft = 2; +constexpr auto kPadRight = 3; + enum SplitMode { NoSplit = 0, SplitN = 1,