split-with-over-lap op bug

This commit is contained in:
ling 2021-06-25 11:09:36 +08:00
parent 0cd459fb47
commit 74b6edfdba
16 changed files with 113 additions and 113 deletions

View File

@ -15,32 +15,22 @@
*/
#include "nnacl/base/split_with_over_lap_base.h"
#include "nnacl/split_parameter.h"
#include <string.h>
#include "nnacl/errorcode.h"
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices) {
int input_stride = split_dim_size * inner_stride * element_bytes;
for (int slice_idx = 0; slice_idx < num_split; slice_idx++) {
int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes;
char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes;
for (int out_idx = 0; out_idx < outer_total_dim; out_idx++) {
(void)(memcpy(out_data[slice_idx] + out_idx * out_stride, src_ptr, out_stride));
src_ptr += input_stride;
}
}
return NNACL_OK;
}
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, SplitWithOverlapParameter *param,
const int *start_indices, const int *end_indices) {
int start_index = start_indices[slice_idx];
int end_index = end_indices[slice_idx];
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
int outer_total_dim, int inner_stride, const int *start_indices,
const int *end_indices) {
int input_stride = split_dim_size * inner_stride * element_bytes;
int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes;
char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes;
for (int i = 0; i < outer_total_dim; i++) {
(void)memcpy(out_data[slice_idx] + i * out_stride, src_ptr, out_stride);
int input_stride = param->split_dim_size_ * param->inner_stride_ * param->element_bytes_;
int out_stride = (end_index - start_index) * param->inner_stride_ * param->element_bytes_;
char *src_ptr = in_data + start_index * param->inner_stride_ * param->element_bytes_;
char *dst_ptr = out_data[slice_idx];
for (int i = 0; i < param->outer_total_dim_; i++) {
(void)memcpy(dst_ptr + i * out_stride, src_ptr, out_stride);
src_ptr += input_stride;
}
return NNACL_OK;

View File

@ -23,12 +23,8 @@
#ifdef __cplusplus
extern "C" {
#endif
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices);
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
int outer_total_dim, int inner_stride, const int *start_indices, const int *end_indices);
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, SplitWithOverlapParameter *param,
const int *start_indices, const int *end_indices);
#ifdef __cplusplus
}
#endif

View File

@ -29,13 +29,12 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size,
}
const TensorC *input = inputs[0];
SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter;
int split_dim = param->split_dim_;
int number_split = param->num_split_;
if (outputs_size != number_split) {
return NNACL_ERR;
}
int stride = param->split_stride_;
int pad_top = param->pad_top_;
int split_dim = param->split_dim_;
int ratio[SPLIT_MAX_SLICE_NUM];
int extend_top[SPLIT_MAX_SLICE_NUM];
@ -58,15 +57,8 @@ int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size,
int visited_block = 0;
for (int i = 0; i < number_split - 1; i++) {
visited_block += ratio[i];
int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
if (stride != 0) {
// make sure border align with stride
cur_border = UP_ROUND(cur_border + pad_top, stride);
borders[i + 1] = cur_border - pad_top;
} else {
borders[i + 1] = cur_border;
}
borders[i + 1] = cur_border;
}
borders[number_split] = split_dim_size;

View File

@ -49,11 +49,15 @@ typedef struct SplitWithOverlapParameter {
OpParameter op_parameter_;
int num_split_;
int split_dim_;
int split_stride_;
int pad_top_;
int ratio_[SPLIT_MAX_SLICE_NUM];
int extend_top_[SPLIT_MAX_SLICE_NUM];
int extend_bottom_[SPLIT_MAX_SLICE_NUM];
// other parameter
int element_bytes_;
int split_dim_size_;
int outer_total_dim_;
int inner_stride_;
} SplitWithOverlapParameter;
#endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_

View File

@ -1125,14 +1125,11 @@ table Custom {
}
table SplitWithOverlap {
split_dim: long;
number_split: long;
ratio: [long];
extend_top: [long];
extend_bottom: [long];
split_dim: long;
split_stride: long;
pad_top: long;
trans_format: bool = false;
}
table GenOP {

View File

@ -1125,14 +1125,11 @@ OP_ATTR_ONLY(attr, [Attribute])
OP_SCHEMA_DEF_ONLY_END(Custom)
OP_SCHEMA_DEF(SplitWithOverlap)
OP_ATTR(split_dim, long)
OP_ATTR(number_split, long)
OP_ATTR(ratio, [long])
OP_ATTR(extend_top, [long])
OP_ATTR(extend_bottom, [long])
OP_ATTR(split_dim, long)
OP_ATTR(split_stride, long)
OP_ATTR(pad_top, long)
OP_ATTR_WITH_VALUE(trans_format, bool, false)
OP_SCHEMA_DEF_END(SplitWithOverlap)
OP_SCHEMA_DEF_ONLY(GenOP)

View File

@ -36,25 +36,27 @@ OpParameter *PopulateSplitWithOverlapParameter(const void *prim) {
memset(param, 0, sizeof(SplitWithOverlapParameter));
param->op_parameter_.type_ = primitive->value_type();
auto ratio = value->ratio();
if (ratio == nullptr) {
MS_LOG(ERROR) << "ratio is nullptr";
free(param);
return nullptr;
}
if (ratio->size() > SPLIT_MAX_SLICE_NUM) {
MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM
<< " slices";
free(param);
return nullptr;
}
param->num_split_ = static_cast<int>(ratio->size());
param->num_split_ = value->number_split();
param->split_dim_ = value->split_dim();
if (param->num_split_ > SPLIT_MAX_SLICE_NUM) {
MS_LOG(ERROR) << "SplitWithOverlap num_split_ error.";
free(param);
return nullptr;
}
auto ratio = value->ratio();
auto extend_top = value->extend_top();
auto extend_bottom = value->extend_bottom();
if (extend_top->size() != ratio->size() || (extend_bottom != nullptr && extend_bottom->size() != ratio->size())) {
MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical";
if (ratio == nullptr || extend_top == nullptr || extend_bottom == nullptr) {
MS_LOG(ERROR) << "SplitWithOverlap parameter is nullptr.";
free(param);
return nullptr;
}
if (static_cast<int>(ratio->size()) != param->num_split_ ||
static_cast<int>(extend_top->size()) != param->num_split_ ||
static_cast<int>(extend_bottom->size()) != param->num_split_) {
MS_LOG(ERROR) << "SplitWithOverlap parameter size error.";
free(param);
return nullptr;
}
@ -65,9 +67,6 @@ OpParameter *PopulateSplitWithOverlapParameter(const void *prim) {
param->extend_bottom_[i] = (*extend_bottom)[i];
}
param->split_stride_ = value->split_stride();
param->pad_top_ = value->pad_top();
return reinterpret_cast<OpParameter *>(param);
}

View File

@ -40,13 +40,7 @@ void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const std::vector<int
for (auto i = 0; i < param_->num_split_ - 1; i++) {
visited_block += param_->ratio_[i];
auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
if (param_->split_stride_ != 0) {
// make sure border align with stride
cur_border = UP_ROUND(cur_border + param_->pad_top_, param_->split_stride_);
borders.emplace_back(cur_border - param_->pad_top_);
} else {
borders.emplace_back(cur_border);
}
borders.emplace_back(cur_border);
}
borders.emplace_back(split_dim_size);
@ -74,24 +68,26 @@ int SplitWithOverlapBaseCPUKernel::ReSize() {
CalculateSplitedShapes(input_shape);
element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
param_->element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
outer_total_dim_ = 1;
inner_stride_ = 1;
param_->outer_total_dim_ = 1;
param_->inner_stride_ = 1;
for (int i = 0; i < static_cast<int>(input_shape.size()); i++) {
if (i < param_->split_dim_) outer_total_dim_ *= input_shape[i];
if (i == param_->split_dim_) split_dim_size_ = input_shape[param_->split_dim_];
if (i > param_->split_dim_) inner_stride_ *= input_shape[i];
if (i < param_->split_dim_) param_->outer_total_dim_ *= input_shape[i];
if (i == param_->split_dim_) param_->split_dim_size_ = input_shape[param_->split_dim_];
if (i > param_->split_dim_) param_->inner_stride_ *= input_shape[i];
}
thread_count_ = MSMIN(param_->num_split_, op_parameter_->thread_num_);
return RET_OK;
}
int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
inner_stride_, start_indices_.data(), end_indices_.data());
for (int current_slice_task = task_id; current_slice_task < param_->num_split_; current_slice_task += thread_count_) {
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), current_slice_task, param_, start_indices_.data(),
end_indices_.data());
}
return RET_OK;
}
@ -113,7 +109,7 @@ int SplitWithOverlapBaseCPUKernel::Run() {
output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c()));
}
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, param_->num_split_);
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
return RET_ERROR;

View File

@ -47,11 +47,8 @@ class SplitWithOverlapBaseCPUKernel : public InnerKernel {
std::vector<int> end_indices_;
SplitWithOverlapParameter *param_ = nullptr;
int thread_count_;
int outer_total_dim_{0};
int inner_stride_{0};
int element_bytes_{0};
int split_dim_size_{0};
char *input_ptr_{nullptr};
std::vector<char *> output_ptr_;
};

View File

@ -3,4 +3,5 @@
# model_file ### accuracy_limit ### enable_fp16(true or false)
mtk_model_normalize_object_scene_ps_20200519_f32.tflite;0.5;false
hiai_cv_poseEstimation.tflite;0.5;false
hiai_lm_inference_graph.pb;0.5;false
# end

View File

@ -306,10 +306,7 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo
split_prim->set_ratio(split_info->size_splits);
split_prim->set_extend_top(split_info->extend_top);
split_prim->set_extend_bottom(split_info->extend_bottom);
// default to format khwc or nhwc
split_prim->set_trans_format(false);
auto conv_cnode = conv_node->cast<CNodePtr>();
split_prim->set_split_stride(0);
// the inputs of split is from the inputs of conv
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
@ -337,5 +334,28 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo
split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
}
void UpdateRatioWithPadStride(int64_t *ratio, size_t split_size, int split_dim_size, int pad, int stride) {
if (stride == 0) {
return;
}
int total_block_count = 0;
for (size_t i = 0; i < split_size; i++) {
total_block_count += ratio[i];
}
std::vector<int64_t> new_ratio(split_size);
int visited_block = 0;
for (size_t i = 0; i < split_size - 1; i++) {
visited_block += ratio[i];
int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
new_ratio[i + 1] = cur_border;
}
for (size_t i = 0; i < split_size; i++) {
ratio[i] = new_ratio[i];
}
return;
}
} // namespace opt
} // namespace mindspore

View File

@ -66,6 +66,7 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePt
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
const std::string &node_name);
void UpdateRatioWithPadStride(int64_t *ratio, size_t split_size, int split_dim_size, int pad, int stride);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_

View File

@ -26,6 +26,8 @@
#include "tools/converter/converter_flags.h"
#include "include/errorcode.h"
#include "tools/optimizer/parallel/operator_info_register.h"
#include "tools/optimizer/parallel/spliter.h"
#include "tools/optimizer/fisson/fisson_util.h"
using mindspore::schema::PrimitiveType_Conv2DFusion;
namespace mindspore {
@ -128,28 +130,37 @@ int Conv2DInfo::CheckIfSplit() {
AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
const std::vector<int64_t> &splits, bool trans_format) {
const std::vector<int64_t> &splits) {
auto graph_node_input_shapes = Spliter::GetInstance()->graph_node_input_shapes();
auto ori_node_name = orig_node->fullname_with_scope();
auto input_shape_iter = graph_node_input_shapes.find(ori_node_name);
if (input_shape_iter == graph_node_input_shapes.end()) {
return nullptr;
}
auto input_shapes = input_shape_iter->second;
auto input_shape = input_shapes.front();
auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
// prim of split
auto split_prim = std::make_shared<ops::SplitWithOverlap>();
split_prim->set_split_dim(split_dim);
split_prim->set_number_split(split_num);
split_prim->set_ratio(splits);
split_prim->set_trans_format(trans_format);
std::vector<int64_t> new_splits = splits;
if (split_mode_ == SplitH) {
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
auto extend_bottom = conv_prim->get_kernel_size().at(kIndexH) - conv_prim->get_stride().at(kIndexH);
auto bottom_vector = std::vector<int64_t>(split_num, extend_bottom);
bottom_vector[split_num - 1] = 0;
split_prim->set_extend_bottom(bottom_vector);
split_prim->set_split_stride(conv_prim->get_stride().at(kIndexH));
split_prim->set_pad_top(conv_prim->get_pad_list().at(kPadUp));
UpdateRatioWithPadStride(new_splits.data(), split_num, input_shape[split_dim], conv_prim->get_pad_list().at(kPadUp),
conv_prim->get_stride().at(kIndexH));
} else {
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
split_prim->set_extend_bottom(std::vector<int64_t>(split_num, 0));
split_prim->set_split_stride(0);
split_prim->set_pad_top(0);
}
split_prim->set_split_dim(split_dim);
split_prim->set_number_split(split_num);
split_prim->set_ratio(new_splits);
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
// ori_conv_node must only have one input
split_inputs.push_back(orig_node->input(input_index + 1));
@ -197,8 +208,7 @@ int Conv2DInfo::InferParallelCNodes() {
case SplitN:
case SplitH: {
name_ = orig_name + "_input";
auto feature_split_cnode =
CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisH, dev_num, splits_, true);
auto feature_split_cnode = CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisH, dev_num, splits_);
if (CheckSplitResult(feature_split_cnode, feature_split_outputs, dev_num) != RET_OK) {
return RET_ERROR;
}

View File

@ -42,8 +42,7 @@ class Conv2DInfo : public OperatorInfo {
const std::vector<AnfNodePtr> &kernel_split_outputs,
const std::vector<AnfNodePtr> &bias_split_outputs);
AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs,
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits,
bool trans_format) override;
size_t split_dim, size_t split_num, const std::vector<int64_t> &splits) override;
protected:
SplitMode split_mode_ = NoSplit;

View File

@ -311,12 +311,10 @@ AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &ori_node, s
auto pad_list = GetSplitPadList(depth_wise_conv_prim, input_h, input_w);
depth_wise_conv_prim->set_pad_list(pad_list);
depth_wise_conv_prim->set_pad_mode(PAD);
// prim of split
auto split_prim = std::make_shared<ops::SplitWithOverlap>();
split_prim->set_split_dim(split_dim_);
split_prim->set_number_split(split_num);
split_prim->set_ratio(splits);
split_prim->set_trans_format(false);
std::vector<int64_t> new_splits = splits;
if (split_mode_ == SplitH) {
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
auto extend_bottom =
@ -324,14 +322,17 @@ AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &ori_node, s
auto bottom_vector = std::vector<int64_t>(split_num, extend_bottom);
bottom_vector[split_num - 1] = 0;
split_prim->set_extend_bottom(bottom_vector);
split_prim->set_split_stride(depth_wise_conv_prim->get_stride().at(kIndexH));
split_prim->set_pad_top(depth_wise_conv_prim->get_pad_list().at(kPadUp));
UpdateRatioWithPadStride(new_splits.data(), split_num, input_shape[split_dim_],
depth_wise_conv_prim->get_pad_list().at(kPadUp),
depth_wise_conv_prim->get_stride().at(kIndexH));
} else {
split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
split_prim->set_extend_bottom(std::vector<int64_t>(split_num, 0));
split_prim->set_split_stride(0);
split_prim->set_pad_top(0);
}
split_prim->set_split_dim(split_dim_);
split_prim->set_number_split(split_num);
split_prim->set_ratio(new_splits);
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
// ori_conv_node must only have one feature input
split_inputs.push_back(ori_node->input(input_index + 1));

View File

@ -73,7 +73,7 @@ class OperatorInfo {
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
const std::vector<int64_t> &splits, bool trans_format) = 0;
const std::vector<int64_t> &splits) = 0;
virtual int InferReplaceOp() = 0;
virtual int InferParallelCNodes() = 0;
virtual int CheckStrategy(const SplitStrategy &strategy) = 0;