split-with-over-lap op bug
This commit is contained in:
parent
0cd459fb47
commit
74b6edfdba
|
@ -15,32 +15,22 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/base/split_with_over_lap_base.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
#include <string.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices) {
|
||||
int input_stride = split_dim_size * inner_stride * element_bytes;
|
||||
for (int slice_idx = 0; slice_idx < num_split; slice_idx++) {
|
||||
int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes;
|
||||
char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes;
|
||||
for (int out_idx = 0; out_idx < outer_total_dim; out_idx++) {
|
||||
(void)(memcpy(out_data[slice_idx] + out_idx * out_stride, src_ptr, out_stride));
|
||||
src_ptr += input_stride;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, SplitWithOverlapParameter *param,
|
||||
const int *start_indices, const int *end_indices) {
|
||||
int start_index = start_indices[slice_idx];
|
||||
int end_index = end_indices[slice_idx];
|
||||
|
||||
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, const int *start_indices,
|
||||
const int *end_indices) {
|
||||
int input_stride = split_dim_size * inner_stride * element_bytes;
|
||||
int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes;
|
||||
char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes;
|
||||
for (int i = 0; i < outer_total_dim; i++) {
|
||||
(void)memcpy(out_data[slice_idx] + i * out_stride, src_ptr, out_stride);
|
||||
int input_stride = param->split_dim_size_ * param->inner_stride_ * param->element_bytes_;
|
||||
int out_stride = (end_index - start_index) * param->inner_stride_ * param->element_bytes_;
|
||||
|
||||
char *src_ptr = in_data + start_index * param->inner_stride_ * param->element_bytes_;
|
||||
char *dst_ptr = out_data[slice_idx];
|
||||
|
||||
for (int i = 0; i < param->outer_total_dim_; i++) {
|
||||
(void)memcpy(dst_ptr + i * out_stride, src_ptr, out_stride);
|
||||
src_ptr += input_stride;
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
|
|
@ -23,12 +23,8 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices);
|
||||
|
||||
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices);
|
||||
|
||||
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, SplitWithOverlapParameter *param,
|
||||
const int *start_indices, const int *end_indices);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -29,13 +29,12 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size,
|
|||
}
|
||||
const TensorC *input = inputs[0];
|
||||
SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter;
|
||||
|
||||
int split_dim = param->split_dim_;
|
||||
int number_split = param->num_split_;
|
||||
if (outputs_size != number_split) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int stride = param->split_stride_;
|
||||
int pad_top = param->pad_top_;
|
||||
int split_dim = param->split_dim_;
|
||||
|
||||
int ratio[SPLIT_MAX_SLICE_NUM];
|
||||
int extend_top[SPLIT_MAX_SLICE_NUM];
|
||||
|
@ -58,15 +57,8 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size,
|
|||
int visited_block = 0;
|
||||
for (int i = 0; i < number_split - 1; i++) {
|
||||
visited_block += ratio[i];
|
||||
|
||||
int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
|
||||
if (stride != 0) {
|
||||
// make sure border align with stride
|
||||
cur_border = UP_ROUND(cur_border + pad_top, stride);
|
||||
borders[i + 1] = cur_border - pad_top;
|
||||
} else {
|
||||
borders[i + 1] = cur_border;
|
||||
}
|
||||
borders[i + 1] = cur_border;
|
||||
}
|
||||
borders[number_split] = split_dim_size;
|
||||
|
||||
|
|
|
@ -49,11 +49,15 @@ typedef struct SplitWithOverlapParameter {
|
|||
OpParameter op_parameter_;
|
||||
int num_split_;
|
||||
int split_dim_;
|
||||
int split_stride_;
|
||||
int pad_top_;
|
||||
int ratio_[SPLIT_MAX_SLICE_NUM];
|
||||
int extend_top_[SPLIT_MAX_SLICE_NUM];
|
||||
int extend_bottom_[SPLIT_MAX_SLICE_NUM];
|
||||
|
||||
// other parameter
|
||||
int element_bytes_;
|
||||
int split_dim_size_;
|
||||
int outer_total_dim_;
|
||||
int inner_stride_;
|
||||
} SplitWithOverlapParameter;
|
||||
|
||||
#endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_
|
||||
|
|
|
@ -1125,14 +1125,11 @@ table Custom {
|
|||
}
|
||||
|
||||
table SplitWithOverlap {
|
||||
split_dim: long;
|
||||
number_split: long;
|
||||
ratio: [long];
|
||||
extend_top: [long];
|
||||
extend_bottom: [long];
|
||||
split_dim: long;
|
||||
split_stride: long;
|
||||
pad_top: long;
|
||||
trans_format: bool = false;
|
||||
}
|
||||
|
||||
table GenOP {
|
||||
|
|
|
@ -1125,14 +1125,11 @@ OP_ATTR_ONLY(attr, [Attribute])
|
|||
OP_SCHEMA_DEF_ONLY_END(Custom)
|
||||
|
||||
OP_SCHEMA_DEF(SplitWithOverlap)
|
||||
OP_ATTR(split_dim, long)
|
||||
OP_ATTR(number_split, long)
|
||||
OP_ATTR(ratio, [long])
|
||||
OP_ATTR(extend_top, [long])
|
||||
OP_ATTR(extend_bottom, [long])
|
||||
OP_ATTR(split_dim, long)
|
||||
OP_ATTR(split_stride, long)
|
||||
OP_ATTR(pad_top, long)
|
||||
OP_ATTR_WITH_VALUE(trans_format, bool, false)
|
||||
OP_SCHEMA_DEF_END(SplitWithOverlap)
|
||||
|
||||
OP_SCHEMA_DEF_ONLY(GenOP)
|
||||
|
|
|
@ -36,25 +36,27 @@ OpParameter *PopulateSplitWithOverlapParameter(const void *prim) {
|
|||
memset(param, 0, sizeof(SplitWithOverlapParameter));
|
||||
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
auto ratio = value->ratio();
|
||||
if (ratio == nullptr) {
|
||||
MS_LOG(ERROR) << "ratio is nullptr";
|
||||
free(param);
|
||||
return nullptr;
|
||||
}
|
||||
if (ratio->size() > SPLIT_MAX_SLICE_NUM) {
|
||||
MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM
|
||||
<< " slices";
|
||||
free(param);
|
||||
return nullptr;
|
||||
}
|
||||
param->num_split_ = static_cast<int>(ratio->size());
|
||||
param->num_split_ = value->number_split();
|
||||
param->split_dim_ = value->split_dim();
|
||||
|
||||
if (param->num_split_ > SPLIT_MAX_SLICE_NUM) {
|
||||
MS_LOG(ERROR) << "SplitWithOverlap num_split_ error.";
|
||||
free(param);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto ratio = value->ratio();
|
||||
auto extend_top = value->extend_top();
|
||||
auto extend_bottom = value->extend_bottom();
|
||||
if (extend_top->size() != ratio->size() || (extend_bottom != nullptr && extend_bottom->size() != ratio->size())) {
|
||||
MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical";
|
||||
if (ratio == nullptr || extend_top == nullptr || extend_bottom == nullptr) {
|
||||
MS_LOG(ERROR) << "SplitWithOverlap parameter is nullptr.";
|
||||
free(param);
|
||||
return nullptr;
|
||||
}
|
||||
if (static_cast<int>(ratio->size()) != param->num_split_ ||
|
||||
static_cast<int>(extend_top->size()) != param->num_split_ ||
|
||||
static_cast<int>(extend_bottom->size()) != param->num_split_) {
|
||||
MS_LOG(ERROR) << "SplitWithOverlap parameter size error.";
|
||||
free(param);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -65,9 +67,6 @@ OpParameter *PopulateSplitWithOverlapParameter(const void *prim) {
|
|||
param->extend_bottom_[i] = (*extend_bottom)[i];
|
||||
}
|
||||
|
||||
param->split_stride_ = value->split_stride();
|
||||
param->pad_top_ = value->pad_top();
|
||||
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
||||
|
|
|
@ -40,13 +40,7 @@ void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const std::vector<int
|
|||
for (auto i = 0; i < param_->num_split_ - 1; i++) {
|
||||
visited_block += param_->ratio_[i];
|
||||
auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
|
||||
if (param_->split_stride_ != 0) {
|
||||
// make sure border align with stride
|
||||
cur_border = UP_ROUND(cur_border + param_->pad_top_, param_->split_stride_);
|
||||
borders.emplace_back(cur_border - param_->pad_top_);
|
||||
} else {
|
||||
borders.emplace_back(cur_border);
|
||||
}
|
||||
borders.emplace_back(cur_border);
|
||||
}
|
||||
borders.emplace_back(split_dim_size);
|
||||
|
||||
|
@ -74,24 +68,26 @@ int SplitWithOverlapBaseCPUKernel::ReSize() {
|
|||
|
||||
CalculateSplitedShapes(input_shape);
|
||||
|
||||
element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
|
||||
param_->element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
|
||||
|
||||
outer_total_dim_ = 1;
|
||||
inner_stride_ = 1;
|
||||
param_->outer_total_dim_ = 1;
|
||||
param_->inner_stride_ = 1;
|
||||
|
||||
for (int i = 0; i < static_cast<int>(input_shape.size()); i++) {
|
||||
if (i < param_->split_dim_) outer_total_dim_ *= input_shape[i];
|
||||
if (i == param_->split_dim_) split_dim_size_ = input_shape[param_->split_dim_];
|
||||
if (i > param_->split_dim_) inner_stride_ *= input_shape[i];
|
||||
if (i < param_->split_dim_) param_->outer_total_dim_ *= input_shape[i];
|
||||
if (i == param_->split_dim_) param_->split_dim_size_ = input_shape[param_->split_dim_];
|
||||
if (i > param_->split_dim_) param_->inner_stride_ *= input_shape[i];
|
||||
}
|
||||
|
||||
thread_count_ = MSMIN(param_->num_split_, op_parameter_->thread_num_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
|
||||
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
|
||||
inner_stride_, start_indices_.data(), end_indices_.data());
|
||||
|
||||
for (int current_slice_task = task_id; current_slice_task < param_->num_split_; current_slice_task += thread_count_) {
|
||||
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), current_slice_task, param_, start_indices_.data(),
|
||||
end_indices_.data());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -113,7 +109,7 @@ int SplitWithOverlapBaseCPUKernel::Run() {
|
|||
output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c()));
|
||||
}
|
||||
|
||||
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, param_->num_split_);
|
||||
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, thread_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -47,11 +47,8 @@ class SplitWithOverlapBaseCPUKernel : public InnerKernel {
|
|||
std::vector<int> end_indices_;
|
||||
|
||||
SplitWithOverlapParameter *param_ = nullptr;
|
||||
int thread_count_;
|
||||
|
||||
int outer_total_dim_{0};
|
||||
int inner_stride_{0};
|
||||
int element_bytes_{0};
|
||||
int split_dim_size_{0};
|
||||
char *input_ptr_{nullptr};
|
||||
std::vector<char *> output_ptr_;
|
||||
};
|
||||
|
|
|
@ -3,4 +3,5 @@
|
|||
# model_file ### accuracy_limit ### enable_fp16(true or false)
|
||||
mtk_model_normalize_object_scene_ps_20200519_f32.tflite;0.5;false
|
||||
hiai_cv_poseEstimation.tflite;0.5;false
|
||||
hiai_lm_inference_graph.pb;0.5;false
|
||||
# end
|
|
@ -306,10 +306,7 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo
|
|||
split_prim->set_ratio(split_info->size_splits);
|
||||
split_prim->set_extend_top(split_info->extend_top);
|
||||
split_prim->set_extend_bottom(split_info->extend_bottom);
|
||||
// default to format khwc or nhwc
|
||||
split_prim->set_trans_format(false);
|
||||
auto conv_cnode = conv_node->cast<CNodePtr>();
|
||||
split_prim->set_split_stride(0);
|
||||
|
||||
// the inputs of split is from the inputs of conv
|
||||
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
|
||||
|
@ -337,5 +334,28 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo
|
|||
split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
|
||||
}
|
||||
|
||||
void UpdateRatioWithPadStride(int64_t *ratio, size_t split_size, int split_dim_size, int pad, int stride) {
|
||||
if (stride == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int total_block_count = 0;
|
||||
for (size_t i = 0; i < split_size; i++) {
|
||||
total_block_count += ratio[i];
|
||||
}
|
||||
|
||||
std::vector<int64_t> new_ratio(split_size);
|
||||
int visited_block = 0;
|
||||
for (size_t i = 0; i < split_size - 1; i++) {
|
||||
visited_block += ratio[i];
|
||||
int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
|
||||
new_ratio[i + 1] = cur_border;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < split_size; i++) {
|
||||
ratio[i] = new_ratio[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,6 +66,7 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
|
||||
std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
|
||||
const std::string &node_name);
|
||||
void UpdateRatioWithPadStride(int64_t *ratio, size_t split_size, int split_dim_size, int pad, int stride);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "tools/optimizer/parallel/operator_info_register.h"
|
||||
#include "tools/optimizer/parallel/spliter.h"
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Conv2DFusion;
|
||||
namespace mindspore {
|
||||
|
@ -128,28 +130,37 @@ int Conv2DInfo::CheckIfSplit() {
|
|||
|
||||
AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
|
||||
const std::vector<int64_t> &splits, bool trans_format) {
|
||||
const std::vector<int64_t> &splits) {
|
||||
auto graph_node_input_shapes = Spliter::GetInstance()->graph_node_input_shapes();
|
||||
auto ori_node_name = orig_node->fullname_with_scope();
|
||||
auto input_shape_iter = graph_node_input_shapes.find(ori_node_name);
|
||||
if (input_shape_iter == graph_node_input_shapes.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto input_shapes = input_shape_iter->second;
|
||||
auto input_shape = input_shapes.front();
|
||||
|
||||
auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
|
||||
|
||||
// prim of split
|
||||
auto split_prim = std::make_shared<ops::SplitWithOverlap>();
|
||||
split_prim->set_split_dim(split_dim);
|
||||
split_prim->set_number_split(split_num);
|
||||
split_prim->set_ratio(splits);
|
||||
split_prim->set_trans_format(trans_format);
|
||||
std::vector<int64_t> new_splits = splits;
|
||||
if (split_mode_ == SplitH) {
|
||||
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
|
||||
auto extend_bottom = conv_prim->get_kernel_size().at(kIndexH) - conv_prim->get_stride().at(kIndexH);
|
||||
auto bottom_vector = std::vector<int64_t>(split_num, extend_bottom);
|
||||
bottom_vector[split_num - 1] = 0;
|
||||
split_prim->set_extend_bottom(bottom_vector);
|
||||
split_prim->set_split_stride(conv_prim->get_stride().at(kIndexH));
|
||||
split_prim->set_pad_top(conv_prim->get_pad_list().at(kPadUp));
|
||||
UpdateRatioWithPadStride(new_splits.data(), split_num, input_shape[split_dim], conv_prim->get_pad_list().at(kPadUp),
|
||||
conv_prim->get_stride().at(kIndexH));
|
||||
} else {
|
||||
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
|
||||
split_prim->set_extend_bottom(std::vector<int64_t>(split_num, 0));
|
||||
split_prim->set_split_stride(0);
|
||||
split_prim->set_pad_top(0);
|
||||
}
|
||||
split_prim->set_split_dim(split_dim);
|
||||
split_prim->set_number_split(split_num);
|
||||
split_prim->set_ratio(new_splits);
|
||||
|
||||
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
|
||||
// ori_conv_node must only have one input
|
||||
split_inputs.push_back(orig_node->input(input_index + 1));
|
||||
|
@ -197,8 +208,7 @@ int Conv2DInfo::InferParallelCNodes() {
|
|||
case SplitN:
|
||||
case SplitH: {
|
||||
name_ = orig_name + "_input";
|
||||
auto feature_split_cnode =
|
||||
CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisH, dev_num, splits_, true);
|
||||
auto feature_split_cnode = CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisH, dev_num, splits_);
|
||||
if (CheckSplitResult(feature_split_cnode, feature_split_outputs, dev_num) != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -42,8 +42,7 @@ class Conv2DInfo : public OperatorInfo {
|
|||
const std::vector<AnfNodePtr> &kernel_split_outputs,
|
||||
const std::vector<AnfNodePtr> &bias_split_outputs);
|
||||
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
|
||||
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
|
||||
bool trans_format) override;
|
||||
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits) override;
|
||||
|
||||
protected:
|
||||
SplitMode split_mode_ = NoSplit;
|
||||
|
|
|
@ -311,12 +311,10 @@ AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &ori_node, s
|
|||
auto pad_list = GetSplitPadList(depth_wise_conv_prim, input_h, input_w);
|
||||
depth_wise_conv_prim->set_pad_list(pad_list);
|
||||
depth_wise_conv_prim->set_pad_mode(PAD);
|
||||
|
||||
// prim of split
|
||||
auto split_prim = std::make_shared<ops::SplitWithOverlap>();
|
||||
split_prim->set_split_dim(split_dim_);
|
||||
split_prim->set_number_split(split_num);
|
||||
split_prim->set_ratio(splits);
|
||||
split_prim->set_trans_format(false);
|
||||
std::vector<int64_t> new_splits = splits;
|
||||
if (split_mode_ == SplitH) {
|
||||
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
|
||||
auto extend_bottom =
|
||||
|
@ -324,14 +322,17 @@ AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &ori_node, s
|
|||
auto bottom_vector = std::vector<int64_t>(split_num, extend_bottom);
|
||||
bottom_vector[split_num - 1] = 0;
|
||||
split_prim->set_extend_bottom(bottom_vector);
|
||||
split_prim->set_split_stride(depth_wise_conv_prim->get_stride().at(kIndexH));
|
||||
split_prim->set_pad_top(depth_wise_conv_prim->get_pad_list().at(kPadUp));
|
||||
UpdateRatioWithPadStride(new_splits.data(), split_num, input_shape[split_dim_],
|
||||
depth_wise_conv_prim->get_pad_list().at(kPadUp),
|
||||
depth_wise_conv_prim->get_stride().at(kIndexH));
|
||||
} else {
|
||||
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
|
||||
split_prim->set_extend_bottom(std::vector<int64_t>(split_num, 0));
|
||||
split_prim->set_split_stride(0);
|
||||
split_prim->set_pad_top(0);
|
||||
}
|
||||
split_prim->set_split_dim(split_dim_);
|
||||
split_prim->set_number_split(split_num);
|
||||
split_prim->set_ratio(new_splits);
|
||||
|
||||
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
|
||||
// ori_conv_node must only have one feature input
|
||||
split_inputs.push_back(ori_node->input(input_index + 1));
|
||||
|
|
|
@ -73,7 +73,7 @@ class OperatorInfo {
|
|||
|
||||
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
|
||||
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
|
||||
const std::vector<int64_t> &splits, bool trans_format) = 0;
|
||||
const std::vector<int64_t> &splits) = 0;
|
||||
virtual int InferReplaceOp() = 0;
|
||||
virtual int InferParallelCNodes() = 0;
|
||||
virtual int CheckStrategy(const SplitStrategy &strategy) = 0;
|
||||
|
|
Loading…
Reference in New Issue