!16023 add multi_conv parallel split
From: @zoloft Reviewed-by: Signed-off-by:
This commit is contained in:
commit
e54c810c54
|
@ -287,12 +287,15 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/fisson/fisson_util.cc
|
||||
${LITE_DIR}/tools/optimizer/fisson/iter_node_outputs.cc
|
||||
${LITE_DIR}/tools/optimizer/fisson/node_out_shapes.cc
|
||||
${LITE_DIR}/tools/optimizer/fisson/multi_conv_split_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/multi_node_split.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/multi_conv_info.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/parallel_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/operator_info.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/operator_info_register.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/conv2d_info.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
|
||||
${LITE_DIR}/tools/common/graph_util.cc
|
||||
${LITE_DIR}/tools/common/tensor_util.cc
|
||||
${LITE_DIR}/tools/common/node_util.cc
|
||||
|
|
|
@ -65,12 +65,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/fisson/fisson_util.cc
|
||||
../optimizer/fisson/iter_node_outputs.cc
|
||||
../optimizer/fisson/node_out_shapes.cc
|
||||
../optimizer/fisson/multi_conv_split_pass.cc
|
||||
../optimizer/parallel/multi_node_split.cc
|
||||
../optimizer/parallel/multi_conv_info.cc
|
||||
../optimizer/parallel/parallel_pass.cc
|
||||
../optimizer/parallel/conv2d_info.cc
|
||||
../optimizer/parallel/operator_info.cc
|
||||
../optimizer/parallel/parallel_pass.cc
|
||||
../optimizer/parallel/split_strategy.cc
|
||||
../optimizer/parallel/operator_info_register.cc
|
||||
../optimizer/parallel/spliter.cc
|
||||
../optimizer/parallel/split_strategy.cc
|
||||
../optimizer/graph/conv1d_inout_adjust_pass.cc
|
||||
../optimizer/graph/weight_format_transform_pass.cc
|
||||
../optimizer/graph/weight_format_hardcode_pass.cc
|
||||
|
|
|
@ -71,6 +71,7 @@
|
|||
#include "tools/optimizer/fisson/node_out_shapes.h"
|
||||
#include "tools/optimizer/parallel/parallel_pass.h"
|
||||
#include "tools/converter/registry/pass_registry.h"
|
||||
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
|
||||
|
||||
using std::string;
|
||||
namespace mindspore::lite {
|
||||
|
@ -128,13 +129,12 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
|
||||
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
MS_LOG(DEBUG) << "Run ParallelPass start";
|
||||
if (config->trainModel || static_cast<opt::SplitMode>(config->parallelMode) == opt::NoSplit) {
|
||||
if (config->trainModel || !config->parallelMode) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
// 1. deal with split strategy
|
||||
std::unordered_map<std::string, opt::SplitStrategy> split_strategys =
|
||||
ParserSplitStrategy(static_cast<opt::SplitMode>(config->parallelMode));
|
||||
std::unordered_map<std::string, opt::SplitStrategy> split_strategys = ParserSplitStrategy(opt::SplitH);
|
||||
if (split_strategys.empty()) {
|
||||
MS_LOG(ERROR) << "parse split_strategy error.";
|
||||
return RET_OK;
|
||||
|
@ -144,7 +144,10 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
|
|||
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
|
||||
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
|
||||
// 3. multi_conv parallel pass
|
||||
parallel_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
|
||||
auto strategy = split_strategys.begin()->second;
|
||||
parallel_pm->AddPass(
|
||||
std::make_shared<opt::MultiConvSplitPass>(strategy, schema::PrimitiveType_Conv2DFusion, config->fmk, 3));
|
||||
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
|
||||
// 4. single conv parallel pass
|
||||
parallel_pm->AddPass(std::make_shared<opt::ParallelPass>(split_strategys, config->fmk));
|
||||
optimizer->AddPassManager(parallel_pm);
|
||||
|
|
|
@ -80,7 +80,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
int quantWeightSize;
|
||||
std::string bitNumIn;
|
||||
int bitNum;
|
||||
int parallelMode = 0;
|
||||
bool parallelMode = false;
|
||||
std::string configFile;
|
||||
std::string quantWeightChannelStr;
|
||||
int quantWeightChannel;
|
||||
|
|
|
@ -15,31 +15,282 @@
|
|||
*/
|
||||
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "mindspore/core/ops/split_with_overlap.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "ops/concat.h"
|
||||
#include "tools/optimizer/parallel/spliter.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore {
|
||||
using lite::converter::FmkType;
|
||||
|
||||
namespace opt {
|
||||
|
||||
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
namespace {
|
||||
|
||||
bool CalSplitOutputShape(int32_t splited_axis_value, const SplitInfo *split_info,
|
||||
std::vector<int32_t> *split_axis_out_shape,
|
||||
std::vector<int32_t> *split_axis_reduce_out_shape) {
|
||||
// ori ratio
|
||||
int32_t split_num = split_info->size_splits.size();
|
||||
int32_t split_len = 0;
|
||||
for (int32_t i = 0; i < split_num; i++) {
|
||||
split_len += split_info->size_splits[i];
|
||||
}
|
||||
if (split_len > splited_axis_value) {
|
||||
return false;
|
||||
}
|
||||
// out-shape after splited
|
||||
int32_t tmp_value = 0;
|
||||
for (int32_t i = 0; i < split_num - 1; i++) {
|
||||
int32_t tmp = (split_info->size_splits[i] * splited_axis_value) / split_len;
|
||||
tmp_value += tmp;
|
||||
split_axis_out_shape->push_back(tmp);
|
||||
split_axis_reduce_out_shape->push_back(tmp_value);
|
||||
}
|
||||
split_axis_out_shape->push_back(splited_axis_value - tmp_value);
|
||||
split_axis_reduce_out_shape->push_back(splited_axis_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
void CalSplitInShape(int32_t splited_axis_value, const SplitInfo *split_info,
|
||||
const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int32_t idx_node,
|
||||
std::vector<std::vector<int32_t>> *split_axis_inputs_shape,
|
||||
std::vector<std::vector<int32_t>> *split_axis_reduce_inputs_shape) {
|
||||
int32_t split_num = split_info->size_splits.size();
|
||||
int32_t tmp = 0;
|
||||
std::vector<int32_t> split_axis_shape;
|
||||
std::vector<int32_t> split_axis_reduce_shape;
|
||||
|
||||
// iter splited_num
|
||||
for (int32_t idx = 0; idx < split_num; idx++) {
|
||||
// shape
|
||||
if (split_info->axis == CuttingStragedy::CUT_H) { // H
|
||||
if ((splited_axis_value + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_pad_list()[kPadDown] -
|
||||
(ori_attr->get_kernel_size()[kAxisH] - 1)) %
|
||||
ori_attr->get_stride()[kIndexH] ==
|
||||
0) {
|
||||
if (idx == 0) {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) +
|
||||
(ori_attr->get_kernel_size()[kAxisH] - 1) - ori_attr->get_pad_list()[kPadUp];
|
||||
} else if (idx == split_num - 1) {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) +
|
||||
(ori_attr->get_kernel_size()[kAxisH] - 1) - ori_attr->get_pad_list()[kPadDown];
|
||||
} else {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx]) +
|
||||
(ori_attr->get_kernel_size()[kAxisH] - 1) - 0;
|
||||
}
|
||||
} else {
|
||||
if (idx == 0) {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadUp] + ori_attr->get_kernel_size()[kAxisH];
|
||||
} else if (idx == split_num - 1) {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadDown] + ori_attr->get_kernel_size()[kAxisH];
|
||||
} else {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - 0 +
|
||||
ori_attr->get_kernel_size()[kAxisH];
|
||||
}
|
||||
}
|
||||
|
||||
} else if (split_info->axis == CuttingStragedy::CUT_W) { // W
|
||||
if (idx == 0) {
|
||||
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadLeft] + ori_attr->get_kernel_size()[kAxisW];
|
||||
} else if (idx == split_num - 1) {
|
||||
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadRight] + ori_attr->get_kernel_size()[kAxisW];
|
||||
} else {
|
||||
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_inputs_shape)[idx_node][idx] - 1) - 0 +
|
||||
ori_attr->get_kernel_size()[kAxisW];
|
||||
}
|
||||
}
|
||||
split_axis_shape.push_back(tmp);
|
||||
|
||||
// reduce shape
|
||||
if (split_info->axis == CuttingStragedy::CUT_H) { // H
|
||||
if ((splited_axis_value + ori_attr->get_pad_list()[kPadUp] + ori_attr->get_pad_list()[kPadDown] -
|
||||
(ori_attr->get_kernel_size()[kAxisH] - 1)) %
|
||||
ori_attr->get_stride()[kIndexH] ==
|
||||
0) {
|
||||
if (idx == split_num - 1) {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx]) +
|
||||
ori_attr->get_kernel_size()[kAxisH] - 1 - ori_attr->get_pad_list()[kPadDown] -
|
||||
ori_attr->get_pad_list()[kPadUp];
|
||||
} else {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx]) +
|
||||
ori_attr->get_kernel_size()[kAxisH] - 1 - ori_attr->get_pad_list()[kPadUp];
|
||||
}
|
||||
} else {
|
||||
if (idx == split_num - 1) {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadDown] - ori_attr->get_pad_list()[kPadUp] +
|
||||
ori_attr->get_kernel_size()[kAxisH];
|
||||
} else {
|
||||
tmp = ori_attr->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadUp] + ori_attr->get_kernel_size()[kAxisH];
|
||||
}
|
||||
}
|
||||
} else if (split_info->axis == CuttingStragedy::CUT_W) { // W
|
||||
if (idx == split_num - 1) {
|
||||
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadRight] - ori_attr->get_pad_list()[kPadLeft] +
|
||||
ori_attr->get_kernel_size()[kAxisW];
|
||||
} else {
|
||||
tmp = ori_attr->get_stride()[kIndexW] * ((*split_axis_reduce_inputs_shape)[idx_node][idx] - 1) -
|
||||
ori_attr->get_pad_list()[kPadLeft] + ori_attr->get_kernel_size()[kAxisW];
|
||||
}
|
||||
}
|
||||
split_axis_reduce_shape.push_back(tmp);
|
||||
}
|
||||
split_axis_inputs_shape->push_back(split_axis_shape);
|
||||
split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape);
|
||||
}
|
||||
|
||||
bool CheckPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int32_t splited_axis_value) {
|
||||
return !(splited_axis_value == ori_attr->get_kernel_size()[kAxisH] && ori_attr->get_pad_list()[kPadUp] == 0 &&
|
||||
ori_attr->get_pad_list()[kPadDown] == 0);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool IsConv2D(const AnfNodePtr &node) {
|
||||
return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion));
|
||||
}
|
||||
|
||||
std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr) {
|
||||
auto prim = std::make_shared<ops::Conv2DFusion>();
|
||||
prim->set_pad(ori_attr->get_pad());
|
||||
prim->set_in_channel(ori_attr->get_in_channel());
|
||||
prim->set_out_channel(ori_attr->get_out_channel());
|
||||
prim->set_dilation(ori_attr->get_dilation());
|
||||
prim->set_format(ori_attr->get_format());
|
||||
prim->set_group(ori_attr->get_group());
|
||||
prim->set_kernel_size(ori_attr->get_kernel_size());
|
||||
prim->set_pad_mode(ori_attr->get_pad_mode());
|
||||
prim->set_pad_list(ori_attr->get_pad_list());
|
||||
prim->set_stride(ori_attr->get_stride());
|
||||
prim->set_activation_type(ori_attr->get_activation_type());
|
||||
prim->set_pad_list(prim->get_pad_list());
|
||||
return prim;
|
||||
}
|
||||
|
||||
bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info) {
|
||||
if (split_info->axis != CuttingStragedy::CUT_H) {
|
||||
return false;
|
||||
}
|
||||
auto splited_axis = split_info->axis;
|
||||
if (split_info->fmk_type == FmkType::FmkType_CAFFE ||
|
||||
split_info->fmk_type == FmkType::FmkType_ONNX) { // NHWC -> NCHW
|
||||
splited_axis += 1;
|
||||
}
|
||||
|
||||
const int32_t node_size = conv_nodes.size();
|
||||
int32_t idx_node = 0;
|
||||
std::vector<std::vector<ShapeVector>> node_in_out_shapes;
|
||||
while (idx_node < node_size) {
|
||||
// [conv3, conv2, conv1] conv1->conv2->conv3
|
||||
auto out_node_name = conv_nodes[idx_node]->fullname_with_scope();
|
||||
auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name];
|
||||
auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
|
||||
// 0-> in-shape 1->out-shape
|
||||
// only one in and one output
|
||||
node_in_out_shapes.push_back({output_shapes.front(), input_shapes.front()});
|
||||
idx_node++;
|
||||
}
|
||||
|
||||
const int32_t splited_axis_value = node_in_out_shapes[0][1][splited_axis];
|
||||
int32_t split_num = split_info->size_splits.size();
|
||||
std::vector<int32_t> split_axis_out_shape;
|
||||
std::vector<int32_t> split_axis_reduce_out_shape;
|
||||
if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) {
|
||||
return false;
|
||||
}
|
||||
// infer in-shape after splited
|
||||
std::vector<std::vector<int32_t>> split_axis_inputs_shape{split_axis_out_shape};
|
||||
std::vector<std::vector<int32_t>> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape};
|
||||
idx_node = 0;
|
||||
// iter node
|
||||
while (idx_node < node_size) {
|
||||
auto conv_cnode = conv_nodes[idx_node]->cast<CNodePtr>();
|
||||
auto ori_attr = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(kAnfPrimitiveIndex));
|
||||
if (!CheckPrim(ori_attr, splited_axis_value)) {
|
||||
return false;
|
||||
}
|
||||
CalSplitInShape(splited_axis_value, split_info, ori_attr, idx_node, &split_axis_inputs_shape,
|
||||
&split_axis_reduce_inputs_shape);
|
||||
idx_node++;
|
||||
}
|
||||
|
||||
// update ratio
|
||||
split_info->size_splits.clear();
|
||||
split_info->extend_top.clear();
|
||||
split_info->extend_bottom.clear();
|
||||
|
||||
int32_t top = 0;
|
||||
int32_t bottom = 0;
|
||||
split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]);
|
||||
split_info->extend_top.push_back(top);
|
||||
split_info->extend_bottom.push_back(bottom);
|
||||
|
||||
for (int32_t i = 1; i < split_num; i++) {
|
||||
auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1;
|
||||
top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1;
|
||||
auto value = split_axis_inputs_shape[node_size][i] - top;
|
||||
split_info->size_splits.push_back(value);
|
||||
split_info->extend_top.push_back(top);
|
||||
split_info->extend_bottom.push_back(bottom);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
|
||||
std::vector<AnfNodePtr> *outputs) {
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (CheckIfCNodeIsNull(cnode)) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
auto idx = NewValueNode(SizeToInt(i));
|
||||
if (CheckIfValueNodeIsNull(idx)) {
|
||||
return;
|
||||
}
|
||||
size_t temp = SizeToInt(i);
|
||||
auto imm = std::make_shared<Int32Imm>(temp);
|
||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
|
||||
idx->set_abstract(abstract_scalar);
|
||||
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
||||
if (CheckIfCNodeIsNull(tuple_getitem)) {
|
||||
return;
|
||||
}
|
||||
tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1));
|
||||
outputs->push_back(tuple_getitem);
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, SplitInfo *split_info,
|
||||
const std::string &node_name) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(conv_cnode);
|
||||
|
||||
int32_t nodes_num = conv_outputs.size();
|
||||
if (nodes_num != split_info.out_num) {
|
||||
if (nodes_num != static_cast<int32_t>(split_info->out_num)) {
|
||||
MS_LOG(ERROR) << "Conv outputs has wrong input size";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto concat_prim = std::make_shared<ops::Concat>();
|
||||
concat_prim->set_axis(split_info->axis);
|
||||
|
||||
// the inputs of concate are from the outputs of conv
|
||||
std::vector<AnfNodePtr> concate_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
|
||||
std::vector<AnfNodePtr> concate_inputs = {NewValueNode(concat_prim)};
|
||||
for (int32_t i = 0; i < nodes_num; i++) {
|
||||
concate_inputs.push_back(conv_outputs[i]);
|
||||
}
|
||||
|
@ -49,78 +300,52 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr
|
|||
|
||||
concate_cnode->set_fullname_with_scope(node_name + "_Concat");
|
||||
concate_cnode->set_scope(conv_cnode->scope());
|
||||
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs);
|
||||
return concate_cnode;
|
||||
}
|
||||
|
||||
int32_t GetCOutAxis(int32_t format) {
|
||||
switch (format) {
|
||||
case schema::Format_KHWC:
|
||||
return 0;
|
||||
case schema::Format_CHWK:
|
||||
return 3;
|
||||
case schema::Format_NCHW:
|
||||
return 0;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t GetCInAxis(int32_t format) {
|
||||
switch (format) {
|
||||
case schema::Format_KHWC:
|
||||
return 3;
|
||||
case schema::Format_CHWK:
|
||||
return 0;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t GetAxis(int32_t axis, int32_t format, const SplitInfo &split_info) {
|
||||
switch (split_info.primitive_type) {
|
||||
case mindspore::schema::PrimitiveType_Conv2DFusion:
|
||||
if (axis == CuttingStragedy::CUT_C_OUT) {
|
||||
return GetCOutAxis(format);
|
||||
} else if (axis == CuttingStragedy::CUT_C_IN) {
|
||||
return GetCInAxis(format);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only channel_in and channel_out need to transform.";
|
||||
}
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Now, do not support the type : " << split_info.primitive_type;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
const std::string &node_name) {
|
||||
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node,
|
||||
std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
|
||||
const std::string &node_name) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(conv_cnode);
|
||||
MS_EXCEPTION_IF_NULL(conv_node);
|
||||
// attr of split
|
||||
auto split_prim = std::make_shared<ops::SplitWithOverlap>();
|
||||
split_prim->set_split_dim(split_info->axis);
|
||||
split_prim->set_number_split(split_info->out_num);
|
||||
split_prim->set_ratio(split_info->size_splits);
|
||||
split_prim->set_extend_top(split_info->extend_top);
|
||||
split_prim->set_extend_bottom(split_info->extend_bottom);
|
||||
// default to format khwc or nhwc
|
||||
split_prim->set_trans_format(true);
|
||||
|
||||
int32_t nodes_num = conv_outputs.size();
|
||||
if (nodes_num != split_info.out_num) {
|
||||
MS_LOG(ERROR) << "Conv outputs has wrong input size";
|
||||
return nullptr;
|
||||
// the inputs of split is from the inputs of conv
|
||||
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
|
||||
auto conv_cnode = conv_node->cast<CNodePtr>();
|
||||
|
||||
// this conv only has one input, which has been ensured before
|
||||
split_inputs.push_back(conv_cnode->input(1));
|
||||
|
||||
auto split_cnode = func_graph->NewCNode(split_inputs);
|
||||
MS_EXCEPTION_IF_NULL(split_cnode);
|
||||
|
||||
split_cnode->set_fullname_with_scope(node_name + "_Split");
|
||||
// create outputs op split
|
||||
GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info->out_num, split_outputs);
|
||||
|
||||
AbstractBasePtrList ptr_list;
|
||||
for (size_t i = 0; i < split_info->out_num; i++) {
|
||||
auto node = (*split_outputs)[i];
|
||||
// set date_type same with weight
|
||||
auto type_id = static_cast<TypeId>(kNumberTypeFloat32);
|
||||
auto type_ptr = TypeIdToType(type_id);
|
||||
std::vector<int64_t> shape_vector;
|
||||
auto value_node = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
ptr_list.push_back(value_node);
|
||||
}
|
||||
|
||||
// the inputs of addn are from the outputs of conv
|
||||
std::vector<AnfNodePtr> addn_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
|
||||
for (int32_t i = 0; i < nodes_num; i++) {
|
||||
addn_inputs.push_back(conv_outputs[i]);
|
||||
}
|
||||
|
||||
auto addn_cnode = func_graph->NewCNode(addn_inputs);
|
||||
MS_EXCEPTION_IF_NULL(addn_cnode);
|
||||
|
||||
addn_cnode->set_fullname_with_scope(node_name + "_AddN");
|
||||
addn_cnode->set_scope(conv_cnode->scope());
|
||||
|
||||
return addn_cnode;
|
||||
split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,47 +20,49 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "mindspore/lite/include/context.h"
|
||||
#include "mindspore/lite/include/lite_types.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::schema::PrimitiveType;
|
||||
namespace opt {
|
||||
|
||||
struct SplitInfo {
|
||||
int32_t axis;
|
||||
int32_t out_num;
|
||||
std::vector<int32_t> size_splits;
|
||||
std::vector<int32_t> extend_top;
|
||||
std::vector<int32_t> extend_bottom;
|
||||
int64_t axis;
|
||||
size_t out_num;
|
||||
std::vector<int64_t> size_splits;
|
||||
std::vector<int64_t> extend_top;
|
||||
std::vector<int64_t> extend_bottom;
|
||||
std::vector<mindspore::lite::DeviceType> dev_types;
|
||||
int32_t in_num_conv;
|
||||
int32_t fmk_type;
|
||||
std::vector<int32_t> weight_channel;
|
||||
int64_t in_num_conv;
|
||||
int64_t fmk_type;
|
||||
std::vector<int64_t> weight_channel;
|
||||
PrimitiveType primitive_type;
|
||||
};
|
||||
|
||||
typedef enum { CUT_N, CUT_H, CUT_W, CUT_C_IN, CUT_C_OUT, CUT_NONE } CuttingStragedy;
|
||||
|
||||
bool IsConv2D(const AnfNodePtr &node);
|
||||
|
||||
std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr);
|
||||
|
||||
bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info);
|
||||
|
||||
void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
|
||||
std::vector<AnfNodePtr> *outputs);
|
||||
|
||||
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, SplitInfo *split_info,
|
||||
const std::string &node_name);
|
||||
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
std::vector<AnfNodePtr> *split_outputs, const SplitInfo &split_info,
|
||||
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
|
||||
std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
|
||||
const std::string &node_name);
|
||||
|
||||
void GetCNodeShapeInfo(const FuncGraphPtr &func_graph, int32_t fmk_type);
|
||||
|
||||
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
const std::string &node_name);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "mindspore/lite/tools/optimizer/fisson/multi_conv_split_pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/core/base/base.h"
|
||||
#include "mindspore/core/ops/fusion/conv2d_fusion.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef MultiConvSplitPass::DefinePattern() const {
|
||||
auto conv1_var = std::make_shared<CondVar>(IsConvNode);
|
||||
auto conv1_other_var = std::make_shared<SeqVar>();
|
||||
VectorRef res = VectorRef({conv1_var, conv1_other_var});
|
||||
int32_t idx = 1;
|
||||
while (idx < num_) {
|
||||
auto tmp_var = std::make_shared<CondVar>(IsConvNode);
|
||||
auto tmp_other_var = std::make_shared<SeqVar>();
|
||||
res = VectorRef({tmp_var, res, tmp_other_var});
|
||||
idx++;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
const AnfNodePtr MultiConvSplitPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_LOG(INFO) << "---Enter pass MultiConvSplit.";
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
|
||||
auto device_type = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : kDeviceTypeNone;
|
||||
if (device_type != kDeviceTypeNone) {
|
||||
return node;
|
||||
}
|
||||
std::shared_ptr<MultiNodeSplitProxy> multi_node_split_proxy =
|
||||
std::make_shared<MultiNodeSplitProxy>(strategy_, primitive_type_, fmk_type_, num_);
|
||||
return multi_node_split_proxy->DoSplit(func_graph, node);
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/optimizer/parallel/multi_node_split.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MultiConvSplitPass : public PatternProcessPass {
|
||||
public:
|
||||
explicit MultiConvSplitPass(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
|
||||
int32_t num = 3, bool multigraph = true)
|
||||
: PatternProcessPass("multi_conv_split", multigraph),
|
||||
strategy_(strategy),
|
||||
primitive_type_(primitive_type),
|
||||
fmk_type_(fmk_type),
|
||||
num_(num) {}
|
||||
~MultiConvSplitPass() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
SplitStrategy strategy_{};
|
||||
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
|
||||
int32_t fmk_type_{-1};
|
||||
int32_t num_{0};
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_MULTI_CONV_SPLIT_H_
|
|
@ -30,20 +30,6 @@
|
|||
using mindspore::schema::PrimitiveType_Conv2DFusion;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// strategy format is NHWC-KHWC
|
||||
constexpr int32_t kAxisN = 0;
|
||||
constexpr int32_t kAxisCIn = 3;
|
||||
constexpr int32_t kAxisCOut = 0;
|
||||
constexpr int32_t kAxisH = 1;
|
||||
constexpr int32_t kAxisW = 2;
|
||||
|
||||
constexpr auto kIndexH = 0;
|
||||
constexpr auto kIndexW = 1;
|
||||
|
||||
constexpr auto kPadUp = 0;
|
||||
constexpr auto kPadDown = 1;
|
||||
constexpr auto kPadLeft = 2;
|
||||
constexpr auto kPadRight = 3;
|
||||
|
||||
int Conv2DInfo::GetAttrs() { return lite::RET_OK; }
|
||||
|
||||
|
|
|
@ -0,0 +1,218 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/parallel/multi_conv_info.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "tools/optimizer/parallel/spliter.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
int MultiConvSplit ::GenSplitInfo() {
|
||||
split_info_.out_num = this->strategy_.dev_num;
|
||||
for (const auto &dev_type : this->strategy_.dev_types) {
|
||||
if (dev_type == "CPU") {
|
||||
split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_CPU);
|
||||
} else if (dev_type == "GPU") {
|
||||
split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_GPU);
|
||||
} else if (dev_type == "NPU") {
|
||||
split_info_.dev_types.push_back(mindspore::lite::DeviceType::DT_NPU);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Do not support DeviceType:" << dev_type << "now.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
// only can get N && H && CIN &&
|
||||
std::vector<int64_t> tmp(split_info_.out_num, 0);
|
||||
for (size_t i = 0; i < this->strategy_.strategys[0].size(); i++) {
|
||||
if (this->strategy_.strategys[0][i] == tmp) {
|
||||
continue;
|
||||
}
|
||||
split_info_.axis = i; // NHWC
|
||||
split_info_.size_splits.clear();
|
||||
split_info_.size_splits = this->strategy_.strategys[0][i]; // cal base on compute_cap
|
||||
break;
|
||||
}
|
||||
split_info_.in_num_conv = num_;
|
||||
split_info_.fmk_type = fmk_type_;
|
||||
split_info_.extend_bottom = std::vector<int64_t>(split_info_.size_splits.size(), 0);
|
||||
split_info_.extend_top = std::vector<int64_t>(split_info_.size_splits.size(), 0);
|
||||
split_info_.primitive_type = primitive_type_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MultiConvSplit::GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(conv_node);
|
||||
// get nodes to be splited
|
||||
// node in graph 1->2->3...
|
||||
// node in vector ...->3->2->1
|
||||
std::string conv_cnode_name = conv_node->fullname_with_scope();
|
||||
MS_LOG(INFO) << "---node name:" << conv_cnode_name;
|
||||
auto graph_node_outputs = Spliter::GetInstance()->graph_node_outputs();
|
||||
auto it = graph_node_outputs.find(conv_cnode_name);
|
||||
if (it == graph_node_outputs.end()) {
|
||||
MS_LOG(INFO) << "This node may be the last node of graph,it do not has any out-nodes.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_nodes_.push_back(conv_node);
|
||||
int32_t idx = 0;
|
||||
while (idx < split_info_.in_num_conv - 1) {
|
||||
auto curr_node = conv_nodes_[idx];
|
||||
auto curr_cnode = conv_nodes_[idx]->cast<CNodePtr>();
|
||||
auto tmp_node = curr_cnode->input(1);
|
||||
if (IsConv2D(tmp_node)) {
|
||||
break;
|
||||
}
|
||||
auto name = tmp_node->fullname_with_scope();
|
||||
// check outputs's bigger than two
|
||||
it = graph_node_outputs.find(name);
|
||||
if (it == graph_node_outputs.end()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (it->second.size() > 1) {
|
||||
break;
|
||||
}
|
||||
conv_nodes_.push_back(tmp_node);
|
||||
idx++;
|
||||
}
|
||||
|
||||
// no need split in multi_node_pass
|
||||
if (conv_nodes_.size() < 2) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
AnfNodePtr MultiConvSplit::MultiConvNHSplit(const AnfNodePtr &node) {
|
||||
std::string conv_cnode_name = node->fullname_with_scope();
|
||||
// Create Split node and get outputs of Split
|
||||
std::vector<AnfNodePtr> split_outputs;
|
||||
CreateOutputsOfSplitWithOverlap(func_graph_, conv_nodes_[conv_nodes_.size() - 1], &split_outputs, &split_info_,
|
||||
conv_cnode_name);
|
||||
// Create Conv node
|
||||
for (int32_t i = conv_nodes_.size() - 1; i >= 0; i--) {
|
||||
std::vector<AnfNodePtr> outputs_node;
|
||||
SplitSingleConv(conv_nodes_[i], split_outputs, {}, {}, &outputs_node);
|
||||
split_outputs.clear();
|
||||
std::copy(outputs_node.begin(), outputs_node.end(), std::back_inserter(split_outputs));
|
||||
outputs_node.clear();
|
||||
}
|
||||
// Create concate node
|
||||
auto concat_node = CreateOutputsOfConcat(func_graph_, node, split_outputs, &split_info_, conv_cnode_name);
|
||||
split_outputs.clear();
|
||||
return concat_node;
|
||||
}
|
||||
|
||||
void MultiConvSplit::SplitSingleConv(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &inputs_node,
|
||||
const std::vector<AnfNodePtr> &weight_node,
|
||||
const std::vector<AnfNodePtr> &bias_nodes, std::vector<AnfNodePtr> *outputs_node) {
|
||||
auto ori_conv_cnode = ori_node->cast<CNodePtr>();
|
||||
auto ori_attr = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(ori_conv_cnode->input(kAnfPrimitiveIndex));
|
||||
for (int32_t output_conv_index = 0; output_conv_index < static_cast<int32_t>(split_info_.out_num);
|
||||
output_conv_index++) {
|
||||
// Create Conv node attr
|
||||
auto conv_prim = CopyConvPrim(ori_attr);
|
||||
// adjust primitive
|
||||
AdJustConvPrim(conv_prim, output_conv_index);
|
||||
// node inputs
|
||||
std::vector<AnfNodePtr> conv_inputs;
|
||||
conv_inputs.push_back(NewValueNode(conv_prim));
|
||||
AdJustInputs(ori_node, inputs_node, weight_node, bias_nodes, output_conv_index, &conv_inputs);
|
||||
// create new conv node
|
||||
CreateNewConvNode(ori_node, conv_inputs, output_conv_index, outputs_node);
|
||||
}
|
||||
}
|
||||
|
||||
void MultiConvSplit::AdJustInputs(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &new_inputs_node,
|
||||
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
|
||||
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs) {
|
||||
auto ori_conv_cnode = ori_conv_node->cast<CNodePtr>();
|
||||
// feature_map
|
||||
conv_inputs->push_back(new_inputs_node[output_conv_index]);
|
||||
// W+bias
|
||||
for (size_t j = 2; j < ori_conv_cnode->size(); j++) {
|
||||
conv_inputs->push_back(ori_conv_cnode->input(j));
|
||||
}
|
||||
}
|
||||
|
||||
void MultiConvSplit::CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &conv_inputs,
|
||||
int output_conv_index, std::vector<AnfNodePtr> *outputs_node) {
|
||||
auto ori_conv_cnode = ori_conv_node->cast<CNodePtr>();
|
||||
std::string ori_cnode_name = ori_conv_cnode->fullname_with_scope();
|
||||
// new conv_node
|
||||
auto conv_cnode = func_graph_->NewCNode(conv_inputs);
|
||||
conv_cnode->set_fullname_with_scope(ori_cnode_name + "_" + PARALLEL_NAME_SUFFIX +
|
||||
std::to_string(output_conv_index + 1));
|
||||
conv_cnode->AddAttr(mindspore::ops::kDeviceType,
|
||||
MakeValue(static_cast<int>(split_info_.dev_types[output_conv_index])));
|
||||
std::vector<AnfNodePtr> tmp_outputs;
|
||||
// conv2d only has one output, set to output_nodes
|
||||
GetMultipleOutputsOfAnfNode(func_graph_, conv_cnode, 1, &tmp_outputs);
|
||||
outputs_node->push_back(tmp_outputs[0]->cast<CNodePtr>()->input(1));
|
||||
tmp_outputs.clear();
|
||||
}
|
||||
|
||||
AnfNodePtr MultiConvSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
int ret = GenSplitInfo();
|
||||
if (ret != RET_OK) {
|
||||
return node;
|
||||
}
|
||||
func_graph_ = func_graph;
|
||||
ret = GetMultiConvNodes(func_graph, node);
|
||||
if (ret != RET_OK) {
|
||||
return node;
|
||||
}
|
||||
return SplitMultiConv(node);
|
||||
}
|
||||
|
||||
AnfNodePtr MultiConvSplitN::SplitMultiConv(const AnfNodePtr &node) {
|
||||
if (conv_nodes_.size() == 2 && split_info_.axis == CuttingStragedy::CUT_N) {
|
||||
return node;
|
||||
}
|
||||
return MultiConvNHSplit(node);
|
||||
}
|
||||
|
||||
AnfNodePtr MultiConvSplitH::SplitMultiConv(const AnfNodePtr &node) {
|
||||
// update info, N do not need, C do not support
|
||||
if (!UpdateSplitInfo(func_graph_, conv_nodes_, &split_info_)) {
|
||||
return node;
|
||||
}
|
||||
return MultiConvNHSplit(node);
|
||||
}
|
||||
|
||||
void MultiConvSplitH::AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &conv_prim, int output_conv_index) {
|
||||
auto pad_list = conv_prim->get_pad_list();
|
||||
if (output_conv_index == 0) {
|
||||
pad_list[kPadDown] = 0;
|
||||
} else if (output_conv_index == static_cast<int32_t>(split_info_.out_num - 1)) {
|
||||
pad_list[kPadUp] = 0;
|
||||
} else {
|
||||
pad_list[kPadUp] = 0;
|
||||
pad_list[kPadDown] = 0;
|
||||
}
|
||||
conv_prim->set_pad_list(pad_list);
|
||||
}
|
||||
|
||||
AnfNodePtr MultiConvSplitCIN::SplitMultiConv(const AnfNodePtr &node) { return nullptr; }
|
||||
|
||||
AnfNodePtr MultiConvSplitCOUT::SplitMultiConv(const AnfNodePtr &node) { return nullptr; }
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tools/optimizer/parallel/multi_node_split.h"
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MultiConvSplit : public MultiNodeSplit {
|
||||
public:
|
||||
explicit MultiConvSplit(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
|
||||
int32_t num = 3)
|
||||
: MultiNodeSplit(), strategy_(strategy), primitive_type_(primitive_type), fmk_type_(fmk_type), num_(num) {}
|
||||
|
||||
AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
|
||||
virtual AnfNodePtr SplitMultiConv(const AnfNodePtr &node) = 0;
|
||||
|
||||
virtual void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) = 0;
|
||||
|
||||
virtual AnfNodePtr MultiConvNHSplit(const AnfNodePtr &node);
|
||||
|
||||
virtual void AdJustInputs(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &new_inputs_node,
|
||||
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
|
||||
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs);
|
||||
|
||||
virtual void CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &conv_inputs,
|
||||
int output_conv_index, std::vector<AnfNodePtr> *outputs_node);
|
||||
|
||||
virtual void SplitSingleConv(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &inputs_node,
|
||||
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
|
||||
std::vector<AnfNodePtr> *outputs_node);
|
||||
|
||||
protected:
|
||||
FuncGraphPtr func_graph_{nullptr};
|
||||
SplitInfo split_info_{};
|
||||
SplitStrategy strategy_{};
|
||||
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
|
||||
int32_t fmk_type_{-1};
|
||||
int32_t num_{0};
|
||||
std::vector<AnfNodePtr> conv_nodes_{};
|
||||
|
||||
private:
|
||||
int GenSplitInfo();
|
||||
int GetMultiConvNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node);
|
||||
};
|
||||
|
||||
class MultiConvSplitN final : public MultiConvSplit {
|
||||
public:
|
||||
MultiConvSplitN(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
|
||||
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
|
||||
|
||||
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
|
||||
|
||||
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override {}
|
||||
};
|
||||
|
||||
class MultiConvSplitCIN final : public MultiConvSplit {
|
||||
public:
|
||||
MultiConvSplitCIN(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
|
||||
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
|
||||
|
||||
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
|
||||
|
||||
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override {}
|
||||
};
|
||||
|
||||
class MultiConvSplitCOUT final : public MultiConvSplit {
|
||||
public:
|
||||
MultiConvSplitCOUT(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
|
||||
int32_t num = 3)
|
||||
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
|
||||
|
||||
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
|
||||
|
||||
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override {}
|
||||
};
|
||||
|
||||
class MultiConvSplitH final : public MultiConvSplit {
|
||||
public:
|
||||
MultiConvSplitH(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1, int32_t num = 3)
|
||||
: MultiConvSplit(strategy, primitive_type, fmk_type, num) {}
|
||||
|
||||
AnfNodePtr SplitMultiConv(const AnfNodePtr &node) override;
|
||||
|
||||
void AdJustConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_attr, int output_conv_index) override;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_CONV_INFO_H
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/parallel/multi_node_split.h"
|
||||
#include "tools/optimizer/parallel/multi_conv_info.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
int MultiNodeSplitProxy::InitResource() {
|
||||
switch (split_mode_) {
|
||||
case SplitN:
|
||||
multi_node_split_ = std::make_shared<MultiConvSplitN>(strategy_, primitive_type_, fmk_type_, num_);
|
||||
return RET_OK;
|
||||
case SplitH:
|
||||
multi_node_split_ = std::make_shared<MultiConvSplitH>(strategy_, primitive_type_, fmk_type_, num_);
|
||||
return RET_OK;
|
||||
case SplitCIN:
|
||||
multi_node_split_ = std::make_shared<MultiConvSplitCIN>(strategy_, primitive_type_, fmk_type_, num_);
|
||||
return RET_OK;
|
||||
case SplitCOUT:
|
||||
multi_node_split_ = std::make_shared<MultiConvSplitCOUT>(strategy_, primitive_type_, fmk_type_, num_);
|
||||
return RET_OK;
|
||||
default:
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
int MultiNodeSplitProxy::FreeResource() {
|
||||
multi_node_split_ = nullptr;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
AnfNodePtr MultiNodeSplitProxy::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
int ret = InitResource();
|
||||
if (ret != RET_OK) {
|
||||
return node;
|
||||
}
|
||||
auto res_node = multi_node_split_->DoSplit(func_graph, node);
|
||||
ret = FreeResource();
|
||||
if (ret != RET_OK) {
|
||||
return node;
|
||||
}
|
||||
return res_node;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "base/base.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MultiNodeSplit {
|
||||
public:
|
||||
MultiNodeSplit() = default;
|
||||
|
||||
virtual ~MultiNodeSplit() = default;
|
||||
|
||||
virtual AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
|
||||
};
|
||||
|
||||
class MultiNodeSplitProxy : public MultiNodeSplit {
|
||||
public:
|
||||
explicit MultiNodeSplitProxy(const SplitStrategy &strategy, PrimitiveType primitive_type, int32_t fmk_type = -1,
|
||||
int32_t num = 3)
|
||||
: MultiNodeSplit(), strategy_(strategy), primitive_type_(primitive_type), fmk_type_(fmk_type), num_(num) {}
|
||||
|
||||
~MultiNodeSplitProxy() override = default;
|
||||
|
||||
AnfNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
int InitResource();
|
||||
int FreeResource();
|
||||
|
||||
private:
|
||||
SplitMode split_mode_{NoSplit};
|
||||
SplitStrategy strategy_{};
|
||||
PrimitiveType primitive_type_{schema::PrimitiveType_NONE};
|
||||
int32_t fmk_type_{-1};
|
||||
int32_t num_{0};
|
||||
std::shared_ptr<MultiNodeSplit> multi_node_split_{nullptr};
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_MULTI_NODE_SPLIT_H
|
|
@ -38,6 +38,22 @@ const std::vector<std::string> kSplitDevTypes = {"CPU", "GPU"};
|
|||
|
||||
using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
|
||||
|
||||
constexpr auto kDeviceTypeNone = -1;
|
||||
// strategy format is NHWC-KHWC
|
||||
constexpr int32_t kAxisN = 0;
|
||||
constexpr int32_t kAxisCIn = 3;
|
||||
constexpr int32_t kAxisCOut = 0;
|
||||
constexpr int32_t kAxisH = 1;
|
||||
constexpr int32_t kAxisW = 2;
|
||||
|
||||
constexpr auto kIndexH = 0;
|
||||
constexpr auto kIndexW = 1;
|
||||
|
||||
constexpr auto kPadUp = 0;
|
||||
constexpr auto kPadDown = 1;
|
||||
constexpr auto kPadLeft = 2;
|
||||
constexpr auto kPadRight = 3;
|
||||
|
||||
enum SplitMode {
|
||||
NoSplit = 0,
|
||||
SplitN = 1,
|
||||
|
|
Loading…
Reference in New Issue