diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c index 354bc8b15c1..92499aac6b0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c @@ -15,32 +15,22 @@ */ #include "nnacl/base/split_with_over_lap_base.h" -#include "nnacl/split_parameter.h" #include #include "nnacl/errorcode.h" -int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes, - int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices) { - int input_stride = split_dim_size * inner_stride * element_bytes; - for (int slice_idx = 0; slice_idx < num_split; slice_idx++) { - int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes; - char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes; - for (int out_idx = 0; out_idx < outer_total_dim; out_idx++) { - (void)(memcpy(out_data[slice_idx] + out_idx * out_stride, src_ptr, out_stride)); - src_ptr += input_stride; - } - } - return NNACL_OK; -} +int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, SplitWithOverlapParameter *param, + const int *start_indices, const int *end_indices) { + int start_index = start_indices[slice_idx]; + int end_index = end_indices[slice_idx]; -int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes, - int outer_total_dim, int inner_stride, const int *start_indices, - const int *end_indices) { - int input_stride = split_dim_size * inner_stride * element_bytes; - int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes; - char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes; - for (int i = 0; i < outer_total_dim; i++) { - (void)memcpy(out_data[slice_idx] + i * out_stride, src_ptr, out_stride); + int input_stride = param->split_dim_size_ * param->inner_stride_ * param->element_bytes_; + int out_stride = (end_index - start_index) * param->inner_stride_ * param->element_bytes_; + + char *src_ptr = in_data + start_index * param->inner_stride_ * param->element_bytes_; + char *dst_ptr = out_data[slice_idx]; + + for (int i = 0; i < param->outer_total_dim_; i++) { + (void)memcpy(dst_ptr + i * out_stride, src_ptr, out_stride); src_ptr += input_stride; } return NNACL_OK; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h index 140a22a727e..2bd32cc9c8d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h @@ -23,12 +23,8 @@ #ifdef __cplusplus extern "C" { #endif -int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes, - int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices); - -int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes, - int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices); - +int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, SplitWithOverlapParameter *param, + const int *start_indices, const int *end_indices); #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c index 44fe16de018..ea8ddc88861 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c @@ -29,13 +29,12 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, } const TensorC *input = inputs[0]; SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter; + + int split_dim = param->split_dim_; int number_split = param->num_split_; if (outputs_size != number_split) { return NNACL_ERR; } - int stride = param->split_stride_; - int pad_top = param->pad_top_; - int split_dim = param->split_dim_; int ratio[SPLIT_MAX_SLICE_NUM]; int extend_top[SPLIT_MAX_SLICE_NUM]; @@ -58,15 +57,8 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, int visited_block = 0; for (int i = 0; i < number_split - 1; i++) { visited_block += ratio[i]; - int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); - if (stride != 0) { - // make sure border align with stride - cur_border = UP_ROUND(cur_border + pad_top, stride); - borders[i + 1] = cur_border - pad_top; - } else { - borders[i + 1] = cur_border; - } + borders[i + 1] = cur_border; } borders[number_split] = split_dim_size; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h index 4882abb8297..84f38fb2d74 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h @@ -49,11 +49,15 @@ typedef struct SplitWithOverlapParameter { OpParameter op_parameter_; int num_split_; int split_dim_; - int split_stride_; - int pad_top_; int ratio_[SPLIT_MAX_SLICE_NUM]; int extend_top_[SPLIT_MAX_SLICE_NUM]; int extend_bottom_[SPLIT_MAX_SLICE_NUM]; + + // other parameter + int element_bytes_; + int split_dim_size_; + int outer_total_dim_; + int inner_stride_; } SplitWithOverlapParameter; #endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 09f421326f9..662a6a37592 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1125,14 +1125,11 @@ table Custom { } table SplitWithOverlap { + split_dim: long; number_split: long; ratio: [long]; extend_top: [long]; extend_bottom: [long]; - split_dim: long; - split_stride: long; - pad_top: long; - trans_format: bool = false; } table GenOP { diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index 1a5e8f76162..633aae02364 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -1125,14 +1125,11 @@ OP_ATTR_ONLY(attr, [Attribute]) OP_SCHEMA_DEF_ONLY_END(Custom) OP_SCHEMA_DEF(SplitWithOverlap) +OP_ATTR(split_dim, long) OP_ATTR(number_split, long) OP_ATTR(ratio, [long]) OP_ATTR(extend_top, [long]) OP_ATTR(extend_bottom, [long]) -OP_ATTR(split_dim, long) -OP_ATTR(split_stride, long) -OP_ATTR(pad_top, long) -OP_ATTR_WITH_VALUE(trans_format, bool, false) OP_SCHEMA_DEF_END(SplitWithOverlap) OP_SCHEMA_DEF_ONLY(GenOP) diff --git a/mindspore/lite/src/ops/populate/split_with_overlap_populate.cc b/mindspore/lite/src/ops/populate/split_with_overlap_populate.cc index fa0f0fdad08..c993718633a 100644 --- a/mindspore/lite/src/ops/populate/split_with_overlap_populate.cc +++ b/mindspore/lite/src/ops/populate/split_with_overlap_populate.cc @@ -36,25 +36,27 @@ OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { memset(param, 0, sizeof(SplitWithOverlapParameter)); param->op_parameter_.type_ = primitive->value_type(); - auto ratio = value->ratio(); - if (ratio == nullptr) { - MS_LOG(ERROR) << "ratio is nullptr"; - free(param); - return nullptr; - } - if (ratio->size() > SPLIT_MAX_SLICE_NUM) { - MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM - << " slices"; - free(param); - return nullptr; - } - param->num_split_ = static_cast(ratio->size()); + param->num_split_ = value->number_split(); param->split_dim_ = value->split_dim(); + if (param->num_split_ > SPLIT_MAX_SLICE_NUM) { + MS_LOG(ERROR) << "SplitWithOverlap num_split_ error."; + free(param); + return nullptr; + } + + auto ratio = value->ratio(); auto extend_top = value->extend_top(); auto extend_bottom = value->extend_bottom(); - if (extend_top->size() != ratio->size() || (extend_bottom != nullptr && extend_bottom->size() != ratio->size())) { - MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical"; + if (ratio == nullptr || extend_top == nullptr || extend_bottom == nullptr) { + MS_LOG(ERROR) << "SplitWithOverlap parameter is nullptr."; + free(param); + return nullptr; + } + if (static_cast(ratio->size()) != param->num_split_ || + static_cast(extend_top->size()) != param->num_split_ || + static_cast(extend_bottom->size()) != param->num_split_) { + MS_LOG(ERROR) << "SplitWithOverlap parameter size error."; free(param); return nullptr; } @@ -65,9 +67,6 @@ OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { param->extend_bottom_[i] = (*extend_bottom)[i]; } - param->split_stride_ = value->split_stride(); - param->pad_top_ = value->pad_top(); - return reinterpret_cast(param); } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc index f7f3791a476..36ab44b9773 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc @@ -40,13 +40,7 @@ void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const std::vectornum_split_ - 1; i++) { visited_block += param_->ratio_[i]; auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); - if (param_->split_stride_ != 0) { - // make sure border align with stride - cur_border = UP_ROUND(cur_border + param_->pad_top_, param_->split_stride_); - borders.emplace_back(cur_border - param_->pad_top_); - } else { - borders.emplace_back(cur_border); - } + borders.emplace_back(cur_border); } borders.emplace_back(split_dim_size); @@ -74,24 +68,26 @@ int SplitWithOverlapBaseCPUKernel::ReSize() { CalculateSplitedShapes(input_shape); - element_bytes_ = static_cast(lite::DataTypeSize(in_tensor->data_type())); + param_->element_bytes_ = static_cast(lite::DataTypeSize(in_tensor->data_type())); - outer_total_dim_ = 1; - inner_stride_ = 1; + param_->outer_total_dim_ = 1; + param_->inner_stride_ = 1; for (int i = 0; i < static_cast(input_shape.size()); i++) { - if (i < param_->split_dim_) outer_total_dim_ *= input_shape[i]; - if (i == param_->split_dim_) split_dim_size_ = input_shape[param_->split_dim_]; - if (i > param_->split_dim_) inner_stride_ *= input_shape[i]; + if (i < param_->split_dim_) param_->outer_total_dim_ *= input_shape[i]; + if (i == param_->split_dim_) param_->split_dim_size_ = input_shape[param_->split_dim_]; + if (i > param_->split_dim_) param_->inner_stride_ *= input_shape[i]; } + thread_count_ = MSMIN(param_->num_split_, op_parameter_->thread_num_); return RET_OK; } int SplitWithOverlapBaseCPUKernel::Split(int task_id) { - DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_, - inner_stride_, start_indices_.data(), end_indices_.data()); - + for (int current_slice_task = task_id; current_slice_task < param_->num_split_; current_slice_task += thread_count_) { + DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), current_slice_task, param_, start_indices_.data(), + end_indices_.data()); + } return RET_OK; } @@ -113,7 +109,7 @@ int SplitWithOverlapBaseCPUKernel::Run() { output_ptr_.push_back(reinterpret_cast(out_tensors_.at(i)->data_c())); } - auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, param_->num_split_); + auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, thread_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h index 4887877d479..530037b9a6c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h @@ -47,11 +47,8 @@ class SplitWithOverlapBaseCPUKernel : public InnerKernel { std::vector end_indices_; SplitWithOverlapParameter *param_ = nullptr; + int thread_count_; - int outer_total_dim_{0}; - int inner_stride_{0}; - int element_bytes_{0}; - int split_dim_size_{0}; char *input_ptr_{nullptr}; std::vector output_ptr_; }; diff --git a/mindspore/lite/test/config/models_mindrt_parallel.cfg b/mindspore/lite/test/config/models_mindrt_parallel.cfg index e65db28397e..d6a33280487 100644 --- a/mindspore/lite/test/config/models_mindrt_parallel.cfg +++ b/mindspore/lite/test/config/models_mindrt_parallel.cfg @@ -3,4 +3,5 @@ # model_file ### accuracy_limit ### enable_fp16(true or false) mtk_model_normalize_object_scene_ps_20200519_f32.tflite;0.5;false hiai_cv_poseEstimation.tflite;0.5;false +hiai_lm_inference_graph.pb;0.5;false # end \ No newline at end of file diff --git a/mindspore/lite/tools/optimizer/fisson/fisson_util.cc b/mindspore/lite/tools/optimizer/fisson/fisson_util.cc index e1a1205a4cc..dd8be3b420e 100644 --- a/mindspore/lite/tools/optimizer/fisson/fisson_util.cc +++ b/mindspore/lite/tools/optimizer/fisson/fisson_util.cc @@ -306,10 +306,7 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo split_prim->set_ratio(split_info->size_splits); split_prim->set_extend_top(split_info->extend_top); split_prim->set_extend_bottom(split_info->extend_bottom); - // default to format khwc or nhwc - split_prim->set_trans_format(false); auto conv_cnode = conv_node->cast(); - split_prim->set_split_stride(0); // the inputs of split is from the inputs of conv std::vector split_inputs = {NewValueNode(split_prim)}; @@ -337,5 +334,28 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo split_cnode->set_abstract(std::make_shared(ptr_list)); } +void UpdateRatioWithPadStride(int64_t *ratio, size_t split_size, int split_dim_size, int pad, int stride) { + if (stride == 0) { + return; + } + + int total_block_count = 0; + for (size_t i = 0; i < split_size; i++) { + total_block_count += ratio[i]; + } + + std::vector new_ratio(split_size); + int visited_block = 0; + for (size_t i = 0; i < split_size - 1; i++) { + visited_block += ratio[i]; + int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); + new_ratio[i + 1] = cur_border; + } + + for (size_t i = 0; i < split_size; i++) { + ratio[i] = new_ratio[i]; + } + return; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fisson/fisson_util.h b/mindspore/lite/tools/optimizer/fisson/fisson_util.h index 59346bea43f..89d830ba829 100644 --- a/mindspore/lite/tools/optimizer/fisson/fisson_util.h +++ b/mindspore/lite/tools/optimizer/fisson/fisson_util.h @@ -66,6 +66,7 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePt void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode, std::vector *split_outputs, SplitInfo *split_info, const std::string &node_name); +void UpdateRatioWithPadStride(int64_t *ratio, size_t split_size, int split_dim_size, int pad, int stride); } // namespace opt } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_ diff --git a/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc b/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc index b8e802e00d2..99935acfe02 100644 --- a/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc +++ b/mindspore/lite/tools/optimizer/parallel/conv2d_info.cc @@ -26,6 +26,8 @@ #include "tools/converter/converter_flags.h" #include "include/errorcode.h" #include "tools/optimizer/parallel/operator_info_register.h" +#include "tools/optimizer/parallel/spliter.h" +#include "tools/optimizer/fisson/fisson_util.h" using mindspore::schema::PrimitiveType_Conv2DFusion; namespace mindspore { @@ -128,28 +130,37 @@ int Conv2DInfo::CheckIfSplit() { AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector *split_outputs, size_t split_dim, size_t split_num, - const std::vector &splits, bool trans_format) { + const std::vector &splits) { + auto graph_node_input_shapes = Spliter::GetInstance()->graph_node_input_shapes(); + auto ori_node_name = orig_node->fullname_with_scope(); + auto input_shape_iter = graph_node_input_shapes.find(ori_node_name); + if (input_shape_iter == graph_node_input_shapes.end()) { + return nullptr; + } + auto input_shapes = input_shape_iter->second; + auto input_shape = input_shapes.front(); + auto conv_prim = GetValueNode>(cnode_->input(kAnfPrimitiveIndex)); + // prim of split auto split_prim = std::make_shared(); - split_prim->set_split_dim(split_dim); - split_prim->set_number_split(split_num); - split_prim->set_ratio(splits); - split_prim->set_trans_format(trans_format); + std::vector new_splits = splits; if (split_mode_ == SplitH) { split_prim->set_extend_top(std::vector(split_num, 0)); auto extend_bottom = conv_prim->get_kernel_size().at(kIndexH) - conv_prim->get_stride().at(kIndexH); auto bottom_vector = std::vector(split_num, extend_bottom); bottom_vector[split_num - 1] = 0; split_prim->set_extend_bottom(bottom_vector); - split_prim->set_split_stride(conv_prim->get_stride().at(kIndexH)); - split_prim->set_pad_top(conv_prim->get_pad_list().at(kPadUp)); + UpdateRatioWithPadStride(new_splits.data(), split_num, input_shape[split_dim], conv_prim->get_pad_list().at(kPadUp), + conv_prim->get_stride().at(kIndexH)); } else { split_prim->set_extend_top(std::vector(split_num, 0)); split_prim->set_extend_bottom(std::vector(split_num, 0)); - split_prim->set_split_stride(0); - split_prim->set_pad_top(0); } + split_prim->set_split_dim(split_dim); + split_prim->set_number_split(split_num); + split_prim->set_ratio(new_splits); + std::vector split_inputs = {NewValueNode(split_prim)}; // ori_conv_node must only have one input split_inputs.push_back(orig_node->input(input_index + 1)); @@ -197,8 +208,7 @@ int Conv2DInfo::InferParallelCNodes() { case SplitN: case SplitH: { name_ = orig_name + "_input"; - auto feature_split_cnode = - CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisH, dev_num, splits_, true); + auto feature_split_cnode = CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisH, dev_num, splits_); if (CheckSplitResult(feature_split_cnode, feature_split_outputs, dev_num) != RET_OK) { return RET_ERROR; } diff --git a/mindspore/lite/tools/optimizer/parallel/conv2d_info.h b/mindspore/lite/tools/optimizer/parallel/conv2d_info.h index 603a187f23e..80ec5f2925c 100644 --- a/mindspore/lite/tools/optimizer/parallel/conv2d_info.h +++ b/mindspore/lite/tools/optimizer/parallel/conv2d_info.h @@ -42,8 +42,7 @@ class Conv2DInfo : public OperatorInfo { const std::vector &kernel_split_outputs, const std::vector &bias_split_outputs); AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector *split_outputs, - size_t split_dim, size_t split_num, const std::vector &splits, - bool trans_format) override; + size_t split_dim, size_t split_num, const std::vector &splits) override; protected: SplitMode split_mode_ = NoSplit; diff --git a/mindspore/lite/tools/optimizer/parallel/depthwise_conv2d_info.cc b/mindspore/lite/tools/optimizer/parallel/depthwise_conv2d_info.cc index 078eabfeb07..b3e066dd9cf 100644 --- a/mindspore/lite/tools/optimizer/parallel/depthwise_conv2d_info.cc +++ b/mindspore/lite/tools/optimizer/parallel/depthwise_conv2d_info.cc @@ -311,12 +311,10 @@ AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &ori_node, s auto pad_list = GetSplitPadList(depth_wise_conv_prim, input_h, input_w); depth_wise_conv_prim->set_pad_list(pad_list); depth_wise_conv_prim->set_pad_mode(PAD); + // prim of split auto split_prim = std::make_shared(); - split_prim->set_split_dim(split_dim_); - split_prim->set_number_split(split_num); - split_prim->set_ratio(splits); - split_prim->set_trans_format(false); + std::vector new_splits = splits; if (split_mode_ == SplitH) { split_prim->set_extend_top(std::vector(split_num, 0)); auto extend_bottom = @@ -324,14 +322,17 @@ AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &ori_node, s auto bottom_vector = std::vector(split_num, extend_bottom); bottom_vector[split_num - 1] = 0; split_prim->set_extend_bottom(bottom_vector); - split_prim->set_split_stride(depth_wise_conv_prim->get_stride().at(kIndexH)); - split_prim->set_pad_top(depth_wise_conv_prim->get_pad_list().at(kPadUp)); + UpdateRatioWithPadStride(new_splits.data(), split_num, input_shape[split_dim_], + depth_wise_conv_prim->get_pad_list().at(kPadUp), + depth_wise_conv_prim->get_stride().at(kIndexH)); } else { split_prim->set_extend_top(std::vector(split_num, 0)); split_prim->set_extend_bottom(std::vector(split_num, 0)); - split_prim->set_split_stride(0); - split_prim->set_pad_top(0); } + split_prim->set_split_dim(split_dim_); + split_prim->set_number_split(split_num); + split_prim->set_ratio(new_splits); + std::vector split_inputs = {NewValueNode(split_prim)}; // ori_conv_node must only have one feature input split_inputs.push_back(ori_node->input(input_index + 1)); diff --git a/mindspore/lite/tools/optimizer/parallel/operator_info.h b/mindspore/lite/tools/optimizer/parallel/operator_info.h index 4c14cec305a..dacc4bfa870 100644 --- a/mindspore/lite/tools/optimizer/parallel/operator_info.h +++ b/mindspore/lite/tools/optimizer/parallel/operator_info.h @@ -73,7 +73,7 @@ class OperatorInfo { virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index, std::vector *split_outputs, size_t split_dim, size_t split_num, - const std::vector &splits, bool trans_format) = 0; + const std::vector &splits) = 0; virtual int InferReplaceOp() = 0; virtual int InferParallelCNodes() = 0; virtual int CheckStrategy(const SplitStrategy &strategy) = 0;