add multi conv parallel pass

z00512249 2021-05-12 10:32:49 +08:00
parent dec89be367
commit 5a5c498371
14 changed files with 908 additions and 114 deletions

View File

@@ -287,12 +287,15 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fisson/fisson_util.cc
${LITE_DIR}/tools/optimizer/fisson/iter_node_outputs.cc
${LITE_DIR}/tools/optimizer/fisson/node_out_shapes.cc
${LITE_DIR}/tools/optimizer/fisson/multi_conv_split_pass.cc
${LITE_DIR}/tools/optimizer/parallel/multi_node_split.cc
${LITE_DIR}/tools/optimizer/parallel/multi_conv_info.cc
${LITE_DIR}/tools/optimizer/parallel/parallel_pass.cc
${LITE_DIR}/tools/optimizer/parallel/operator_info.cc
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
${LITE_DIR}/tools/optimizer/parallel/operator_info_register.cc
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
${LITE_DIR}/tools/optimizer/parallel/conv2d_info.cc
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
${LITE_DIR}/tools/common/graph_util.cc
${LITE_DIR}/tools/common/tensor_util.cc
${LITE_DIR}/tools/common/node_util.cc

View File

@@ -65,12 +65,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fisson/fisson_util.cc
../optimizer/fisson/iter_node_outputs.cc
../optimizer/fisson/node_out_shapes.cc
../optimizer/fisson/multi_conv_split_pass.cc
../optimizer/parallel/multi_node_split.cc
../optimizer/parallel/multi_conv_info.cc
../optimizer/parallel/parallel_pass.cc
../optimizer/parallel/conv2d_info.cc
../optimizer/parallel/operator_info.cc
../optimizer/parallel/parallel_pass.cc
../optimizer/parallel/split_strategy.cc
../optimizer/parallel/operator_info_register.cc
../optimizer/parallel/spliter.cc
../optimizer/parallel/split_strategy.cc
../optimizer/graph/conv1d_inout_adjust_pass.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc

View File

@@ -71,6 +71,7 @@
#include "tools/optimizer/fisson/node_out_shapes.h"
#include "tools/optimizer/parallel/parallel_pass.h"
#include "tools/converter/registry/pass_registry.h"
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
using std::string;
namespace mindspore::lite {
@@ -128,13 +129,12 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_LOG(DEBUG) << "Run ParallelPass start";
if (config->trainModel || static_cast<opt::SplitMode>(config->parallelMode) == opt::NoSplit) {
if (config->trainModel || !config->parallelMode) {
return RET_OK;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
// 1. deal with split strategy
std::unordered_map<std::string, opt::SplitStrategy> split_strategys =
ParserSplitStrategy(static_cast<opt::SplitMode>(config->parallelMode));
std::unordered_map<std::string, opt::SplitStrategy> split_strategys = ParserSplitStrategy(opt::SplitH);
if (split_strategys.empty()) {
MS_LOG(ERROR) << "parse split_strategy error.";
return RET_OK;
@@ -144,7 +144,10 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
// 3. multi_conv parallel pass
parallel_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
auto strategy = split_strategys.begin()->second;
parallel_pm->AddPass(
std::make_shared<opt::MultiConvSplitPass>(strategy, schema::PrimitiveType_Conv2DFusion, config->fmk, 3));
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
// 4. single conv parallel pass
parallel_pm->AddPass(std::make_shared<opt::ParallelPass>(split_strategys, config->fmk));
optimizer->AddPassManager(parallel_pm);
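// A sketch of the schedule assembled above (pass names as registered; ordering
// matters because shapes must be recorded before the multi-conv split and
// refreshed before the single-conv pass):
//   IterNodeOutputs, NodeOutShapes       -> record per-node input/output shapes
//   RemoveRedundantOpPass                -> drop no-op nodes so conv chains are adjacent
//   MultiConvSplitPass(strategy, Conv2DFusion, fmk, 3)
//                                        -> split conv chains up to three deep
//   NodeOutShapes                        -> re-infer shapes for the new sub-convs
//   ParallelPass(split_strategys, fmk)   -> split the remaining single convs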

View File

@@ -80,7 +80,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
int quantWeightSize;
std::string bitNumIn;
int bitNum;
int parallelMode = 0;
bool parallelMode = false;
std::string configFile;
std::string quantWeightChannelStr;
int quantWeightChannel;

View File

@@ -15,31 +15,282 @@
*/
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include "tools/optimizer/fisson/fisson_util.h"
#include "base/core_ops.h"
#include "src/common/utils.h"
#include "mindspore/core/ops/split_with_overlap.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "ops/concat.h"
#include "tools/optimizer/parallel/spliter.h"
#include "tools/optimizer/parallel/split_strategy.h"
using mindspore::lite::converter::FmkType;
namespace mindspore {
using lite::converter::FmkType;
namespace opt {
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
namespace {
bool CalSplitOutputShape(int32_t splited_axis_value, const SplitInfo *split_info,
std::vector<int32_t> *split_axis_out_shape,
std::vector<int32_t> *split_axis_reduce_out_shape) {
// sum the original split ratios
int32_t split_num = split_info->size_splits.size();
int32_t split_len = 0;
for (int32_t i = 0; i < split_num; i++) {
split_len += split_info->size_splits[i];
}
if (split_len > splited_axis_value) {
return false;
}
// output shape of each segment after splitting
int32_t tmp_value = 0;
for (int32_t i = 0; i < split_num - 1; i++) {
int32_t tmp = (split_info->size_splits[i] * splited_axis_value) / split_len;
tmp_value += tmp;
split_axis_out_shape->push_back(tmp);
split_axis_reduce_out_shape->push_back(tmp_value);
}
split_axis_out_shape->push_back(splited_axis_value - tmp_value);
split_axis_reduce_out_shape->push_back(splited_axis_value);
return true;
}
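// A worked example of the ratio arithmetic above (hypothetical numbers): with
// an axis value of 224 and size_splits {1, 3} (ratio sum 4),
//   segment 0: (1 * 224) / 4 = 56    running reach: 56
//   segment 1: 224 - 56      = 168   running reach: 224
// so split_axis_out_shape = {56, 168} and split_axis_reduce_out_shape =
// {56, 224}; the last segment absorbs any rounding remainder.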
void CalSplitInShape(int32_t splited_axis_value, const SplitInfo *split_info,
const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int32_t idx_node,
std::vector<std::vector<int32_t>> *split_axis_inputs_shape,
std::vector<std::vector<int32_t>> *split_axis_reduce_inputs_shape) {
int32_t split_num = split_info->size_splits.size();
int32_t tmp = 0;
std::vector<int32_t> split_axis_shape;
std::vector<int32_t> split_axis_reduce_shape;
// iterate over the split segments
for (int32_t idx = 0; idx < split_num; idx++) {
// shape
if (split_info->axis == CuttingStragedy::CUT_H) { // H
if ((splited_axis_value + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_pad_list()[kPadDown] -
(ori_attr->get_kernel_size()[kAxisH] - 1)) %
ori_attr->get_stride()[kIndexH] ==
0) {
if (idx == 0) {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) +
(ori_attr->get_kernel_size()[kAxisH] - 1) - ori_attr->get_pad_list()[kPadUp];
} else if (idx == split_num - 1) {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) +
(ori_attr->get_kernel_size()[kAxisH] - 1) - ori_attr->get_pad_list()[kPadDown];
} else {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) +
(ori_attr->get_kernel_size()[kAxisH] - 1) - 0;
}
} else {
if (idx == 0) {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadUp] + ori_attr->get_kernel_size()[kAxisH];
} else if (idx == split_num - 1) {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadDown] + ori_attr->get_kernel_size()[kAxisH];
} else {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - 0 +
ori_attr->get_kernel_size()[kAxisH];
}
}
} else if (split_info->axis == CuttingStragedy::CUT_W) { // W
if (idx == 0) {
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadLeft] + ori_attr->get_kernel_size()[kAxisW];
} else if (idx == split_num - 1) {
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadRight] + ori_attr->get_kernel_size()[kAxisW];
} else {
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - 0 +
ori_attr->get_kernel_size()[kAxisW];
}
}
split_axis_shape.push_back(tmp);
// reduce (cumulative) input shape up to this segment
if (split_info->axis == CuttingStragedy::CUT_H) { // H
if ((splited_axis_value + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_pad_list()[kPadDown] -
(ori_attr->get_kernel_size()[kAxisH] - 1)) %
ori_attr->get_stride()[kIndexH] ==
0) {
if (idx == split_num - 1) {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx]) +
ori_attr->get_kernel_size()[kAxisH] - 1 - ori_attr->get_pad_list()[kPadDown] -
ori_attr->get_pad_list()[kPadUp];
} else {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx]) +
ori_attr->get_kernel_size()[kAxisH] - 1 - ori_attr->get_pad_list()[kPadUp];
}
} else {
if (idx == split_num - 1) {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadDown] - ori_attr->get_pad_list()[kPadUp] +
ori_attr->get_kernel_size()[kAxisH];
} else {
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadUp] + ori_attr->get_kernel_size()[kAxisH];
}
}
} else if (split_info->axis == CuttingStragedy::CUT_W) { // W
if (idx == split_num - 1) {
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadRight] - ori_attr->get_pad_list()[kPadLeft] +
ori_attr->get_kernel_size()[kAxisW];
} else {
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
ori_attr->get_pad_list()[kPadLeft] + ori_attr->get_kernel_size()[kAxisW];
}
}
split_axis_reduce_shape.push_back(tmp);
}
split_axis_inputs_shape->push_back(split_axis_shape);
split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape);
}
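// A worked H-axis example for the first branch above (hypothetical numbers):
// kernel 3, stride 2, pad_up = pad_down = 1, full axis value 8, and a segment
// output height of 2. Since (8 + 1 + 1 - 2) % 2 == 0, for idx == 0:
//   in = stride * out + (kernel - 1) - pad_up = 2 * 2 + 2 - 1 = 5
// and the forward conv formula confirms it: (5 + 1 - 3) / 2 + 1 = 2 output rows.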
bool CheckPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int32_t splited_axis_value) {
return !(splited_axis_value == ori_attr->get_kernel_size()[kAxisH] && ori_attr->get_pad_list()[kPadUp] == 0 &&
ori_attr->get_pad_list()[kPadDown] == 0);
}
} // namespace
bool IsConv2D(const AnfNodePtr &node) {
return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion));
}
std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr) {
auto prim = std::make_shared<ops::Conv2DFusion>();
prim->set_pad(ori_attr->get_pad());
prim->set_in_channel(ori_attr->get_in_channel());
prim->set_out_channel(ori_attr->get_out_channel());
prim->set_dilation(ori_attr->get_dilation());
prim->set_format(ori_attr->get_format());
prim->set_group(ori_attr->get_group());
prim->set_kernel_size(ori_attr->get_kernel_size());
prim->set_pad_mode(ori_attr->get_pad_mode());
prim->set_pad_list(ori_attr->get_pad_list());
prim->set_stride(ori_attr->get_stride());
prim->set_activation_type(ori_attr->get_activation_type());
return prim;
}
bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info) {
if (split_info->axis != CuttingStragedy::CUT_H) {
return false;
}
auto splited_axis = split_info->axis;
if (split_info->fmk_type == FmkType::FmkType_CAFFE ||
split_info->fmk_type == FmkType::FmkType_ONNX) { // NCHW models: H index shifts by 1
splited_axis += 1;
}
const int32_t node_size = conv_nodes.size();
int32_t idx_node = 0;
std::vector<std::vector<ShapeVector>> node_in_out_shapes;
while (idx_node < node_size) {
// conv_nodes is ordered [conv3, conv2, conv1] for the graph chain conv1->conv2->conv3
auto out_node_name = conv_nodes[idx_node]->fullname_with_scope();
auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name];
auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
// index 0 holds the out-shape, index 1 the in-shape
// each conv has exactly one input and one output
node_in_out_shapes.push_back({output_shapes.front(), input_shapes.front()});
idx_node++;
}
const int32_t splited_axis_value = node_in_out_shapes[0][1][splited_axis];
int32_t split_num = split_info->size_splits.size();
std::vector<int32_t> split_axis_out_shape;
std::vector<int32_t> split_axis_reduce_out_shape;
if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) {
return false;
}
// infer each segment's input shape after splitting
std::vector<std::vector<int32_t>> split_axis_inputs_shape{split_axis_out_shape};
std::vector<std::vector<int32_t>> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape};
idx_node = 0;
// iterate over the conv chain
while (idx_node < node_size) {
auto conv_cnode = conv_nodes[idx_node]->cast<CNodePtr>();
auto ori_attr = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(kAnfPrimitiveIndex));
if (!CheckPrim(ori_attr, splited_axis_value)) {
return false;
}
CalSplitInShape(splited_axis_value, split_info, ori_attr, idx_node, &split_axis_inputs_shape,
&split_axis_reduce_inputs_shape);
idx_node++;
}
// update the split ratios and overlaps
split_info->size_splits.clear();
split_info->extend_top.clear();
split_info->extend_bottom.clear();
int32_t top = 0;
int32_t bottom = 0;
split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]);
split_info->extend_top.push_back(top);
split_info->extend_bottom.push_back(bottom);
for (int32_t i = 1; i < split_num; i++) {
auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1;
top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1;
auto value = split_axis_inputs_shape[node_size][i] - top;
split_info->size_splits.push_back(value);
split_info->extend_top.push_back(top);
split_info->extend_bottom.push_back(bottom);
}
return true;
}
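// A worked example of the ratio update above (hypothetical shapes): suppose the
// first conv's input splits into two pieces with computed heights {5, 5} and
// cumulative reach {5, 8}, i.e. segment 1 reads input rows 4..8 (1-based).
// Then for i = 1:
//   begin = 8 - 5 + 1 = 4   // first row segment 1 reads
//   top   = 5 - 4 + 1 = 2   // rows shared with segment 0
//   value = 5 - 2     = 3   // fresh rows segment 1 contributes
// giving size_splits = {5, 3} and extend_top = {0, 2}; 5 + 3 covers all 8 rows.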
void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
std::vector<AnfNodePtr> *outputs) {
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return;
}
auto cnode = node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(cnode)) {
return;
}
for (size_t i = 0; i < output_num; i++) {
auto idx = NewValueNode(SizeToInt(i));
if (CheckIfValueNodeIsNull(idx)) {
return;
}
auto imm = std::make_shared<Int32Imm>(SizeToInt(i));
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
idx->set_abstract(abstract_scalar);
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
if (CheckIfCNodeIsNull(tuple_getitem)) {
return;
}
tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1));
outputs->push_back(tuple_getitem);
}
}
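// For output_num = 2 this produces two TupleGetItem nodes,
//   <node>_TupleGetItem_1 = TupleGetItem(node, 0)
//   <node>_TupleGetItem_2 = TupleGetItem(node, 1)
// which later passes use as the individual outputs of a multi-output op.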
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, SplitInfo *split_info,
const std::string &node_name) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(conv_cnode);
int32_t nodes_num = conv_outputs.size();
if (nodes_num != split_info.out_num) {
if (nodes_num != static_cast<int32_t>(split_info->out_num)) {
MS_LOG(ERROR) << "Conv outputs has wrong input size";
return nullptr;
}
auto concat_prim = std::make_shared<ops::Concat>();
concat_prim->set_axis(split_info->axis);
// the inputs of concat come from the outputs of the convs
std::vector<AnfNodePtr> concate_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
std::vector<AnfNodePtr> concate_inputs = {NewValueNode(concat_prim)};
for (int32_t i = 0; i < nodes_num; i++) {
concate_inputs.push_back(conv_outputs[i]);
}
@@ -49,78 +300,52 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr
concate_cnode->set_fullname_with_scope(node_name + "_Concat");
concate_cnode->set_scope(conv_cnode->scope());
std::vector<AnfNodePtr> outputs;
GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs);
return concate_cnode;
}
int32_t GetCOutAxis(int32_t format) {
switch (format) {
case schema::Format_KHWC:
return 0;
case schema::Format_CHWK:
return 3;
case schema::Format_NCHW:
return 0;
default:
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
return -1;
}
}
int32_t GetCInAxis(int32_t format) {
switch (format) {
case schema::Format_KHWC:
return 3;
case schema::Format_CHWK:
return 0;
default:
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
return -1;
}
}
int32_t GetAxis(int32_t axis, int32_t format, const SplitInfo &split_info) {
switch (split_info.primitive_type) {
case mindspore::schema::PrimitiveType_Conv2DFusion:
if (axis == CuttingStragedy::CUT_C_OUT) {
return GetCOutAxis(format);
} else if (axis == CuttingStragedy::CUT_C_IN) {
return GetCInAxis(format);
} else {
MS_LOG(ERROR) << "Only channel_in and channel_out need to transform.";
}
break;
default:
MS_LOG(ERROR) << "Now, do not support the type : " << split_info.primitive_type;
}
return -1;
}
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
const std::string &node_name) {
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node,
std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
const std::string &node_name) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(conv_cnode);
MS_EXCEPTION_IF_NULL(conv_node);
// attr of split
auto split_prim = std::make_shared<ops::SplitWithOverlap>();
split_prim->set_split_dim(split_info->axis);
split_prim->set_number_split(split_info->out_num);
split_prim->set_ratio(split_info->size_splits);
split_prim->set_extend_top(split_info->extend_top);
split_prim->set_extend_bottom(split_info->extend_bottom);
// default to format khwc or nhwc
split_prim->set_trans_format(true);
int32_t nodes_num = conv_outputs.size();
if (nodes_num != split_info.out_num) {
MS_LOG(ERROR) << "Conv outputs has wrong input size";
return nullptr;
// the input of split comes from the input of conv
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
auto conv_cnode = conv_node->cast<CNodePtr>();
// this conv has exactly one feature-map input, which was ensured earlier
split_inputs.push_back(conv_cnode->input(1));
auto split_cnode = func_graph->NewCNode(split_inputs);
MS_EXCEPTION_IF_NULL(split_cnode);
split_cnode->set_fullname_with_scope(node_name + "_Split");
// create the output nodes of split
GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info->out_num, split_outputs);
AbstractBasePtrList ptr_list;
for (size_t i = 0; i < split_info->out_num; i++) {
auto node = (*split_outputs)[i];
// set data_type the same as the weight
auto type_id = static_cast<TypeId>(kNumberTypeFloat32);
auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector;
auto value_node = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
ptr_list.push_back(value_node);
}
// the inputs of addn are from the outputs of conv
std::vector<AnfNodePtr> addn_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
for (int32_t i = 0; i < nodes_num; i++) {
addn_inputs.push_back(conv_outputs[i]);
}
auto addn_cnode = func_graph->NewCNode(addn_inputs);
MS_EXCEPTION_IF_NULL(addn_cnode);
addn_cnode->set_fullname_with_scope(node_name + "_AddN");
addn_cnode->set_scope(conv_cnode->scope());
return addn_cnode;
split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
}
} // namespace opt
} // namespace mindspore
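Taken together, the helpers in this file rewrite a conv chain into the subgraph sketched below (illustrative dataflow for a two-way split, not code from this commit):
//              input
//                |
//        SplitWithOverlap   (ratio = size_splits, overlap = extend_top/extend_bottom)
//          /           \
//   Conv2DFusion   Conv2DFusion   (per-device copies, pads adjusted per slice)
//          \           /
//            Concat   (axis = split axis)
//                |
//              output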

View File

@@ -20,47 +20,49 @@
#include <vector>
#include <string>
#include <unordered_map>
#include <memory>
#include "schema/inner/model_generated.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"
#include "mindspore/lite/include/context.h"
#include "mindspore/lite/include/lite_types.h"
#include "ops/fusion/conv2d_fusion.h"
namespace mindspore {
using mindspore::schema::PrimitiveType;
namespace opt {
struct SplitInfo {
int32_t axis;
int32_t out_num;
std::vector<int32_t> size_splits;
std::vector<int32_t> extend_top;
std::vector<int32_t> extend_bottom;
int64_t axis;
size_t out_num;
std::vector<int64_t> size_splits;
std::vector<int64_t> extend_top;
std::vector<int64_t> extend_bottom;
std::vector<mindspore::lite::DeviceType> dev_types;
int32_t in_num_conv;
int32_t fmk_type;
std::vector<int32_t> weight_channel;
int64_t in_num_conv;
int64_t fmk_type;
std::vector<int64_t> weight_channel;
PrimitiveType primitive_type;
};
typedef enum { CUT_N, CUT_H, CUT_W, CUT_C_IN, CUT_C_OUT, CUT_NONE } CuttingStragedy;
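// A hand-filled example of SplitInfo for a two-way H split across CPU and GPU
// (illustrative values; the overlaps are refined later by UpdateSplitInfo):
//   SplitInfo info;
//   info.axis = CUT_H;                // split along height
//   info.out_num = 2;                 // two parallel sub-convs
//   info.size_splits = {1, 1};        // relative compute ratio
//   info.extend_top = {0, 0};         // per-segment overlap rows
//   info.extend_bottom = {0, 0};
//   info.dev_types = {lite::DT_CPU, lite::DT_GPU};
//   info.in_num_conv = 3;             // depth of the conv chain to split
//   info.primitive_type = schema::PrimitiveType_Conv2DFusion;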
bool IsConv2D(const AnfNodePtr &node);
std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr);
bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info);
void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
std::vector<AnfNodePtr> *outputs);
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, SplitInfo *split_info,
const std::string &node_name);
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
std::vector<AnfNodePtr> *split_outputs, const SplitInfo &split_info,
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
const std::string &node_name);
void GetCNodeShapeInfo(const FuncGraphPtr &func_graph, int32_t fmk_type);
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
const std::string &node_name);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_

View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <memory>
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/base/base.h"
#include "mindspore/core/ops/fusion/conv2d_fusion.h"
using mindspore::lite::converter::FmkType;
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
namespace mindspore {
namespace opt {
const BaseRef MultiConvSplitPass::DefinePattern() const {
auto conv1_var = std::make_shared<CondVar>(IsConvNode);
auto conv1_other_var = std::make_shared<SeqVar>();
VectorRef res = VectorRef({conv1_var, conv1_other_var});
int32_t idx = 1;
while (idx < num_) {
auto tmp_var = std::make_shared<CondVar>(IsConvNode);
auto tmp_other_var = std::make_shared<SeqVar>();
res = VectorRef({tmp_var, res, tmp_other_var});
idx++;
}
return res;
}
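// With num_ = 3 the loop above nests the pattern roughly as
//   VectorRef({conv_c, VectorRef({conv_b, VectorRef({conv_a, SeqVar}), SeqVar}), SeqVar})
// i.e. it matches a conv whose feature-map input is a conv whose feature-map
// input is again a conv, while each SeqVar wildcard absorbs that conv's
// remaining (weight/bias) inputs.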
const AnfNodePtr MultiConvSplitPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_LOG(INFO) << "---Enter pass MultiConvSplit.";
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
auto device_type = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : kDeviceTypeNone;
if (device_type != kDeviceTypeNone) {
return node;
}
std::shared_ptr<MultiNodeSplitProxy> multi_node_split_proxy =
std::make_shared<MultiNodeSplitProxy>(strategy_, primitive_type_, fmk_type_, num_);
return multi_node_split_proxy->DoSplit(func_graph, node);
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/parallel/multi_node_split.h"
using mindspore::schema::PrimitiveType;
namespace mindspore {
namespace opt {
class MultiConvSplitPass : public PatternProcessPass {
public:
explicit MultiConvSplitPass(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
int32_t num = 3, bool multigraph = true)
: PatternProcessPass("multi_conv_split", multigraph),
strategy_(strategy),
primitive_type_(primitive_type),
fmk_type_(fmk_type),
num_(num) {}
~MultiConvSplitPass() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
SplitStrategy strategy_{};
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
int32_t fmk_type_{-1};
int32_t num_{0};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_

View File

@@ -30,20 +30,6 @@
using mindspore::schema::PrimitiveType_Conv2DFusion;
namespace mindspore {
namespace opt {
// strategy format is NHWC-KHWC
constexpr int32_t kAxisN = 0;
constexpr int32_t kAxisCIn = 3;
constexpr int32_t kAxisCOut = 0;
constexpr int32_t kAxisH = 1;
constexpr int32_t kAxisW = 2;
constexpr auto kIndexH = 0;
constexpr auto kIndexW = 1;
constexpr auto kPadUp = 0;
constexpr auto kPadDown = 1;
constexpr auto kPadLeft = 2;
constexpr auto kPadRight = 3;
int Conv2DInfo::GetAttrs() { return lite::RET_OK; }

View File

@@ -0,0 +1,218 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/parallel/multi_conv_info.h"
#include <string>
#include <algorithm>
#include "tools/optimizer/parallel/spliter.h"
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/parallel/split_strategy.h"
using mindspore::lite::converter::FmkType;
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
namespace mindspore {
namespace opt {
int MultiConvSplit::GenSplitInfo() {
split_info_.out_num = this->strategy_.dev_num;
for (const auto &dev_type : this->strategy_.dev_types) {
if (dev_type == "CPU") {
split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_CPU);
} else if (dev_type == "GPU") {
split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_GPU);
} else if (dev_type == "NPU") {
split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_NPU);
} else {
MS_LOG(ERROR) << "Do not support DeviceType:" << dev_type << "now.";
return RET_ERROR;
}
}
// the split axis can only be N, H, or C_in
std::vector<int64_t> tmp(split_info_.out_num, 0);
for (size_t i = 0; i < this->strategy_.strategys[0].size(); i++) {
if (this->strategy_.strategys[0][i] == tmp) {
continue;
}
split_info_.axis = i; // NHWC
split_info_.size_splits.clear();
split_info_.size_splits = this->strategy_.strategys[0][i];  // ratios derived from per-device compute capability
break;
}
split_info_.in_num_conv = num_;
split_info_.fmk_type = fmk_type_;
split_info_.extend_bottom = std::vector<int64_t>(split_info_.size_splits.size(), 0);
split_info_.extend_top = std::vector<int64_t>(split_info_.size_splits.size(), 0);
split_info_.primitive_type = primitive_type_;
return RET_OK;
}
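// The strategy tensor indexed above is laid out NHWC: strategys[0] holds the
// per-axis ratio vectors for the conv input, and the first non-zero vector
// selects the split axis. A hypothetical example with dev_num = 2:
//   strategys[0] = { {0, 0},     // N    : not split
//                    {1, 1},     // H    : split 1:1 -> axis = 1, size_splits {1, 1}
//                    {0, 0},     // W
//                    {0, 0} };   // C_in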
int MultiConvSplit::GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(conv_node);
// collect the chain of conv nodes to be split
// graph order: conv1 -> conv2 -> conv3 ...
// vector order: ... conv3, conv2, conv1
std::string conv_cnode_name = conv_node->fullname_with_scope();
MS_LOG(INFO) << "---node name:" << conv_cnode_name;
auto graph_node_outputs = Spliter::GetInstance()->graph_node_outputs();
auto it = graph_node_outputs.find(conv_cnode_name);
if (it == graph_node_outputs.end()) {
MS_LOG(INFO) << "This node may be the last node of graph,it do not has any out-nodes.";
return RET_ERROR;
}
conv_nodes_.push_back(conv_node);
int32_t idx = 0;
while (idx < split_info_.in_num_conv - 1) {
auto curr_node = conv_nodes_[idx];
auto curr_cnode = curr_node->cast<CNodePtr>();
auto tmp_node = curr_cnode->input(1);
if (!IsConv2D(tmp_node)) {  // stop once the producer is no longer a conv
break;
}
auto name = tmp_node->fullname_with_scope();
// stop if this producer feeds more than one consumer
it = graph_node_outputs.find(name);
if (it == graph_node_outputs.end()) {
return RET_ERROR;
}
if (it->second.size() > 1) {
break;
}
conv_nodes_.push_back(tmp_node);
idx++;
}
// chains shorter than two convs are left to the single-conv pass
if (conv_nodes_.size() < 2) {
return RET_ERROR;
}
return RET_OK;
}
AnfNodePtr MultiConvSplit::MultiConvNHSplit(const AnfNodePtr &node) {
std::string conv_cnode_name = node->fullname_with_scope();
// Create Split node and get outputs of Split
std::vector<AnfNodePtr> split_outputs;
CreateOutputsOfSplitWithOverlap(func_graph_, conv_nodes_[conv_nodes_.size() - 1], &split_outputs, &split_info_,
conv_cnode_name);
// Create the split conv nodes, walking the chain from its first conv
for (int32_t i = conv_nodes_.size() - 1; i >= 0; i--) {
std::vector<AnfNodePtr> outputs_node;
SplitSingleConv(conv_nodes_[i], split_outputs, {}, {}, &outputs_node);
split_outputs.clear();
std::copy(outputs_node.begin(), outputs_node.end(), std::back_inserter(split_outputs));
outputs_node.clear();
}
// Create concat node
auto concat_node = CreateOutputsOfConcat(func_graph_, node, split_outputs, &split_info_, conv_cnode_name);
split_outputs.clear();
return concat_node;
}
void MultiConvSplit::SplitSingleConv(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &inputs_node,
const std::vector<AnfNodePtr> &weight_node,
const std::vector<AnfNodePtr> &bias_nodes, std::vector<AnfNodePtr> *outputs_node) {
auto ori_conv_cnode = ori_node->cast<CNodePtr>();
auto ori_attr = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(ori_conv_cnode->input(kAnfPrimitiveIndex));
for (int32_t output_conv_index = 0; output_conv_index < static_cast<int32_t>(split_info_.out_num);
output_conv_index++) {
// Create Conv node attr
auto conv_prim = CopyConvPrim(ori_attr);
// adjust primitive
AdJustConvPrim(conv_prim, output_conv_index);
// node inputs
std::vector<AnfNodePtr> conv_inputs;
conv_inputs.push_back(NewValueNode(conv_prim));
AdJustInputs(ori_node, inputs_node, weight_node, bias_nodes, output_conv_index, &conv_inputs);
// create new conv node
CreateNewConvNode(ori_node, conv_inputs, output_conv_index, outputs_node);
}
}
void MultiConvSplit::AdJustInputs(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &new_inputs_node,
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs) {
auto ori_conv_cnode = ori_conv_node->cast<CNodePtr>();
// feature_map
conv_inputs->push_back(new_inputs_node[output_conv_index]);
// weight and bias inputs are reused from the original conv
for (size_t j = 2; j < ori_conv_cnode->size(); j++) {
conv_inputs->push_back(ori_conv_cnode->input(j));
}
}
void MultiConvSplit::CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &conv_inputs,
int output_conv_index, std::vector<AnfNodePtr> *outputs_node) {
auto ori_conv_cnode = ori_conv_node->cast<CNodePtr>();
std::string ori_cnode_name = ori_conv_cnode->fullname_with_scope();
// new conv_node
auto conv_cnode = func_graph_->NewCNode(conv_inputs);
conv_cnode->set_fullname_with_scope(ori_cnode_name + "_" + PARALLEL_NAME_SUFFIX +
std::to_string(output_conv_index + 1));
conv_cnode->AddAttr(mindspore::ops::kDeviceType,
MakeValue(static_cast<int>(split_info_.dev_types[output_conv_index])));
std::vector<AnfNodePtr> tmp_outputs;
// conv2d has a single output; unwrap the TupleGetItem to recover the conv itself
GetMultipleOutputsOfAnfNode(func_graph_, conv_cnode, 1, &tmp_outputs);
outputs_node->push_back(tmp_outputs[0]->cast<CNodePtr>()->input(1));
tmp_outputs.clear();
}
AnfNodePtr MultiConvSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
int ret = GenSplitInfo();
if (ret != RET_OK) {
return node;
}
func_graph_ = func_graph;
ret = GetMultiConvNodes(func_graph, node);
if (ret != RET_OK) {
return node;
}
return SplitMultiConv(node);
}
AnfNodePtr MultiConvSplitN::SplitMultiConv(const AnfNodePtr &node) {
if (conv_nodes_.size() == 2 && split_info_.axis == CuttingStragedy::CUT_N) {
return node;
}
return MultiConvNHSplit(node);
}
AnfNodePtr MultiConvSplitH::SplitMultiConv(const AnfNodePtr &node) {
// update split info; axis N needs no update and axis C is not supported
if (!UpdateSplitInfo(func_graph_, conv_nodes_, &split_info_)) {
return node;
}
return MultiConvNHSplit(node);
}
void MultiConvSplitH::AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &conv_prim, int output_conv_index) {
auto pad_list = conv_prim->get_pad_list();
if (output_conv_index == 0) {
pad_list[kPadDown] = 0;
} else if (output_conv_index == static_cast<int32_t>(split_info_.out_num - 1)) {
pad_list[kPadUp] = 0;
} else {
pad_list[kPadUp] = 0;
pad_list[kPadDown] = 0;
}
conv_prim->set_pad_list(pad_list);
}
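// Example for a three-way H split: slice 0 keeps pad_up and gets pad_down = 0,
// slice 1 (middle) gets pad_up = pad_down = 0, and slice 2 keeps pad_down and
// gets pad_up = 0; the overlap rows supplied by SplitWithOverlap take the
// place of the removed padding at the internal cut lines.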
AnfNodePtr MultiConvSplitCIN::SplitMultiConv(const AnfNodePtr &node) { return nullptr; }
AnfNodePtr MultiConvSplitCOUT::SplitMultiConv(const AnfNodePtr &node) { return nullptr; }
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,107 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H
#include <memory>
#include <vector>
#include "tools/optimizer/parallel/multi_node_split.h"
#include "tools/optimizer/fisson/fisson_util.h"
#include "ops/fusion/conv2d_fusion.h"
namespace mindspore {
namespace opt {
class MultiConvSplit : public MultiNodeSplit {
public:
explicit MultiConvSplit(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
int32_t num = 3)
: MultiNodeSplit(), strategy_(strategy), primitive_type_(primitive_type), fmk_type_(fmk_type), num_(num) {}
AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
virtual AnfNodePtr SplitMultiConv(const AnfNodePtr &node) = 0;
virtual void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) = 0;
virtual AnfNodePtr MultiConvNHSplit(const AnfNodePtr &node);
virtual void AdJustInputs(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &new_inputs_node,
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs);
virtual void CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &conv_inputs,
int output_conv_index, std::vector<AnfNodePtr> *outputs_node);
virtual void SplitSingleConv(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &inputs_node,
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
std::vector<AnfNodePtr> *outputs_node);
protected:
FuncGraphPtr func_graph_{nullptr};
SplitInfo split_info_{};
SplitStrategy strategy_{};
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
int32_t fmk_type_{-1};
int32_t num_{0};
std::vector<AnfNodePtr> conv_nodes_{};
private:
int GenSplitInfo();
int GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node);
};
class MultiConvSplitN final : public MultiConvSplit {
public:
MultiConvSplitN(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override {}
};
class MultiConvSplitCIN final : public MultiConvSplit {
public:
MultiConvSplitCIN(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override {}
};
class MultiConvSplitCOUT final : public MultiConvSplit {
public:
MultiConvSplitCOUT(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
int32_t num = 3)
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override {}
};
class MultiConvSplitH final : public MultiConvSplit {
public:
MultiConvSplitH(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H

View File

@@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/parallel/multi_node_split.h"
#include "tools/optimizer/parallel/multi_conv_info.h"
namespace mindspore {
namespace opt {
int MultiNodeSplitProxy::InitResource() {
switch (split_mode_) {
case SplitN:
multi_node_split_ = std::make_shared<MultiConvSplitN>(strategy_, primitive_type_, fmk_type_, num_);
return RET_OK;
case SplitH:
multi_node_split_ = std::make_shared<MultiConvSplitH>(strategy_, primitive_type_, fmk_type_, num_);
return RET_OK;
case SplitCIN:
multi_node_split_ = std::make_shared<MultiConvSplitCIN>(strategy_, primitive_type_, fmk_type_, num_);
return RET_OK;
case SplitCOUT:
multi_node_split_ = std::make_shared<MultiConvSplitCOUT>(strategy_, primitive_type_, fmk_type_, num_);
return RET_OK;
default:
return RET_ERROR;
}
}
int MultiNodeSplitProxy::FreeResource() {
multi_node_split_ = nullptr;
return RET_OK;
}
AnfNodePtr MultiNodeSplitProxy::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
int ret = InitResource();
if (ret != RET_OK) {
return node;
}
auto res_node = multi_node_split_->DoSplit(func_graph, node);
ret = FreeResource();
if (ret != RET_OK) {
return node;
}
return res_node;
}
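// A minimal usage sketch, mirroring how MultiConvSplitPass::Process drives the
// proxy (values are illustrative):
//   auto proxy = std::make_shared<MultiNodeSplitProxy>(
//       strategy, schema::PrimitiveType_Conv2DFusion, fmk_type, /*num=*/3);
//   AnfNodePtr result = proxy->DoSplit(func_graph, conv_node);
//   // on any failure (unsupported split mode, chain too short, ...) the
//   // original node is returned unchanged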
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
#include <utility>
#include <memory>
#include "tools/optimizer/parallel/split_strategy.h"
#include "schema/inner/model_generated.h"
#include "include/errorcode.h"
#include "base/base.h"
using mindspore::schema::PrimitiveType;
namespace mindspore {
namespace opt {
class MultiNodeSplit {
public:
MultiNodeSplit() = default;
virtual ~MultiNodeSplit() = default;
virtual AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
};
class MultiNodeSplitProxy : public MultiNodeSplit {
public:
explicit MultiNodeSplitProxy(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
int32_t num = 3)
: MultiNodeSplit(), strategy_(strategy), primitive_type_(primitive_type), fmk_type_(fmk_type), num_(num) {}
~MultiNodeSplitProxy() override = default;
AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
private:
int InitResource();
int FreeResource();
private:
SplitMode split_mode_{NoSplit};
SplitStrategy strategy_{};
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
int32_t fmk_type_{-1};
int32_t num_{0};
std::shared_ptr<MultiNodeSplit> multi_node_split_{nullptr};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H

View File

@@ -38,6 +38,22 @@ const std::vector<std::string> kSplitDevTypes = {"CPU", "GPU"};
using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
constexpr auto kDeviceTypeNone = -1;
// strategy format is NHWC-KHWC
constexpr int32_t kAxisN = 0;
constexpr int32_t kAxisCIn = 3;
constexpr int32_t kAxisCOut = 0;
constexpr int32_t kAxisH = 1;
constexpr int32_t kAxisW = 2;
constexpr auto kIndexH = 0;
constexpr auto kIndexW = 1;
constexpr auto kPadUp = 0;
constexpr auto kPadDown = 1;
constexpr auto kPadLeft = 2;
constexpr auto kPadRight = 3;
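// pad_list layout assumed by the four kPad* indices: {up, down, left, right}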
enum SplitMode {
NoSplit = 0,
SplitN = 1,