From 1bdf4a08d24a1f64ba354a278676b7165105901d Mon Sep 17 00:00:00 2001 From: z00512249 Date: Thu, 29 Apr 2021 10:43:38 +0800 Subject: [PATCH] add split_with_overlap --- .../cpu/nnacl/base/split_with_over_lap_base.c | 46 ++++++ .../cpu/nnacl/base/split_with_over_lap_base.h | 36 +++++ .../cpu/nnacl/infer/infer_register.h | 3 +- .../nnacl/infer/split_with_over_lap_infer.c | 97 +++++++++++++ .../nnacl/infer/split_with_over_lap_infer.h | 32 +++++ .../cpu/nnacl/split_parameter.h | 12 ++ mindspore/core/ops/op_utils.h | 7 + mindspore/core/ops/split_with_overlap.cc | 96 +++++++++++++ mindspore/core/ops/split_with_overlap.h | 58 ++++++++ mindspore/lite/schema/ops.fbs | 12 ++ mindspore/lite/src/ops/ops_def.cc | 12 ++ mindspore/lite/src/ops/ops_func_declare.h | 2 + .../populate/split_with_overlap_populate.cc | 67 +++++++++ .../arm/base/split_with_over_lap_base.cc | 133 ++++++++++++++++++ .../arm/base/split_with_over_lap_base.h | 57 ++++++++ 15 files changed, 669 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.h create mode 100644 mindspore/core/ops/split_with_overlap.cc create mode 100644 mindspore/core/ops/split_with_overlap.h create mode 100644 mindspore/lite/src/ops/populate/split_with_overlap_populate.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c new file mode 100644 index 00000000000..70698ac376d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.c @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/split_with_over_lap_base.h" +#include "nnacl/split_parameter.h" +#include +#include "nnacl/errorcode.h" + +int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes, + int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) { + int input_stride = split_dim_size * inner_stride * element_bytes; + for (int slice_idx = 0; slice_idx < num_split; slice_idx++) { + int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes; + char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes; + for (int out_idx = 0; out_idx < outer_total_dim; out_idx++) { + (void)(memcpy(out_data[slice_idx] + out_idx * out_stride, src_ptr, out_stride)); + src_ptr += input_stride; + } + } + return NNACL_OK; +} + +int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes, + int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) { + int input_stride = split_dim_size * inner_stride * element_bytes; + int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes; + char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes; + for (int i = 0; i < outer_total_dim; i++) { + (void)memcpy(out_data[slice_idx] + i * out_stride, src_ptr, out_stride); + src_ptr += input_stride; + } + return NNACL_OK; +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h new file mode 100644 index 00000000000..d1ac30e7610 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/split_with_over_lap_base.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ +#define MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes, + int outer_total_dim, int inner_stride, int *start_indices, int *end_indices); + +int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes, + int outer_total_dim, int inner_stride, int *start_indices, int *end_indices); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h index c57a02dfa7a..3a0e00c6c36 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h @@ -217,8 +217,9 @@ enum PrimType { PrimType_Call = 190, PrimType_Custom = 191, PrimType_CumSum = 192, + PrimType_SplitWithOverlap = 193, PrimType_MIN = PrimType_NONE, - PrimType_MAX = PrimType_CumSum + 1 + PrimType_MAX = PrimType_SplitWithOverlap + 1 }; void RegInfer(int prim_type, InferShape func); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c new file mode 100644 index 00000000000..61c009127dd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.c @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/split_with_over_lap_infer.h" +#include "nnacl/infer/infer_register.h" + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { +#ifdef Debug + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } +#endif + if (!parameter->infer_flag_) { + return NNACL_INFER_INVALID; + } + const TensorC *input = inputs[0]; + if (inputs_size < 1) { + return NNACL_ERR; + } + if (outputs_size == 0) { + return NNACL_ERR; + } + SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter; + int number_split = param->num_split_; + if (outputs_size != number_split) { + return NNACL_ERR; + } + int stride = param->stride_; + int pad_top = param->pad_top_; + int split_dim = param->split_dim_; + + int ratio[SPLIT_MAX_SLICE_NUM]; + int extend_top[SPLIT_MAX_SLICE_NUM]; + int extend_bottom[SPLIT_MAX_SLICE_NUM]; + for (int i = 0; i < number_split; ++i) { + ratio[i] = param->ratio_[i]; + extend_top[i] = param->extend_top_[i]; + extend_bottom[i] = param->extend_bottom_[i]; + } + + const int *input_shape = input->shape_; + int split_dim_size = input_shape[split_dim]; + int total_block_count = 0; + for (int i = 0; i < number_split; i++) { + total_block_count += ratio[i]; + } + + int borders[MAX_SHAPE_SIZE]; + borders[0] = 0; + int visited_block = 0; + for (int i = 0; i < number_split - 1; i++) { + visited_block += ratio[i]; + + int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); + if (stride != 0) { + // make sure border align with stride + cur_border = UP_ROUND(cur_border + pad_top, stride); + borders[i + 1] = cur_border - pad_top; + } else { + borders[i + 1] = cur_border; + } + } + borders[number_split - 1] = split_dim_size; + + for (int i = 0; i < number_split; ++i) { + int output_shape[MAX_SHAPE_SIZE]; + for (int dim = 0; dim < input->shape_size_; dim++) { + if (dim == split_dim) { + int splited_size = borders[i + 1] - borders[i]; + splited_size += extend_top[i] + extend_bottom[i]; + output_shape[dim] = splited_size; + } else { + output_shape[dim] = input_shape[dim]; + } + } + SetShapeArray(outputs[i], output_shape, input->shape_size_); + SetDataTypeFormat(outputs[i], input); + } + return NNACL_OK; +} + +REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.h new file mode 100644 index 00000000000..9e0793d4580 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/split_with_over_lap_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H +#define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h index 5346bb87f9a..0bf6649d300 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/split_parameter.h @@ -20,6 +20,7 @@ #include "nnacl/op_base.h" #define SPLIT_STRIDES_SIZE 32 +#define SPLIT_MAX_SLICE_NUM 10 typedef struct SplitQuantArg { QuantArg in_args_; @@ -44,4 +45,15 @@ typedef struct SplitParameter { int split_count_; } SplitParameter; +typedef struct SplitWithOverlapParameter { + OpParameter op_parameter_; + int num_split_; + int split_dim_; + int stride_; + int pad_top_; + int ratio_[SPLIT_MAX_SLICE_NUM]; + int extend_top_[SPLIT_MAX_SLICE_NUM]; + int extend_bottom_[SPLIT_MAX_SLICE_NUM]; +} SplitWithOverlapParameter; + #endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_ diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 23910244ce9..5b539b75280 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -234,6 +234,13 @@ constexpr auto kSideEffectIO = "side_effect_io"; constexpr auto kDeviceType = "device_type"; constexpr auto kExclusive = "exclusive"; constexpr auto kReverse = "reverse"; +constexpr auto kSplitStride = "split_stride"; +constexpr auto kExtendTop = "extend_top"; +constexpr auto kExtendBottom = "extend_bottom"; +constexpr auto kNumberSplit = "number_split"; +constexpr auto kSplitDim = "split_dim"; +constexpr auto kPadTop = "pad_top"; +constexpr auto kTransFormat = "trans_format"; const std::set common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; diff --git a/mindspore/core/ops/split_with_overlap.cc b/mindspore/core/ops/split_with_overlap.cc new file mode 100644 index 00000000000..6cb14473f16 --- /dev/null +++ b/mindspore/core/ops/split_with_overlap.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/split_with_overlap.h" +#include "ops/op_utils.h" +namespace mindspore { +namespace ops { +void SplitWithOverlap::Init(int64_t number_split, const std::vector &ratio, + const std::vector &extend_top, const std::vector &extend_bottom, + int64_t split_dim, int64_t stride, int64_t pad_top, bool trans_format) { + this->set_number_split(number_split); + this->set_ratio(ratio); + this->set_extend_top(extend_top); + this->set_extend_bottom(extend_bottom); + this->set_split_dim(split_dim); + this->set_stride(stride); + this->set_pad_top(pad_top); + this->set_trans_format(trans_format); +} + +void SplitWithOverlap::set_ratio(const std::vector &ratio) { this->AddAttr(kRatio, MakeValue(ratio)); } + +void SplitWithOverlap::set_extend_top(const std::vector &extend_top) { + this->AddAttr(kExtendTop, MakeValue(extend_top)); +} + +void SplitWithOverlap::set_extend_bottom(const std::vector &extend_bottom) { + this->AddAttr(kExtendBottom, MakeValue(extend_bottom)); +} + +void SplitWithOverlap::set_number_split(int64_t number_split) { this->AddAttr(kNumberSplit, MakeValue(number_split)); } + +void SplitWithOverlap::set_split_dim(int64_t split_dim) { this->AddAttr(kSplitDim, MakeValue(split_dim)); } + +void SplitWithOverlap::set_stride(int64_t stride) { this->AddAttr(kSplitStride, MakeValue(stride)); } + +void SplitWithOverlap::set_pad_top(int64_t pad_top) { this->AddAttr(kPadTop, MakeValue(pad_top)); } + +void SplitWithOverlap::set_trans_format(bool trans_format) { this->AddAttr(kTransFormat, MakeValue(trans_format)); } + +std::vector SplitWithOverlap::get_ratio() const { + auto value_ptr = GetAttr(kRatio); + return GetValue>(value_ptr); +} + +std::vector SplitWithOverlap::get_extend_top() const { + auto value_ptr = GetAttr(kExtendTop); + return GetValue>(value_ptr); +} + +std::vector SplitWithOverlap::get_extend_bottom() const { + auto value_ptr = GetAttr(kExtendBottom); + return GetValue>(value_ptr); +} + +int64_t SplitWithOverlap::get_number_split() const { + auto value_ptr = GetAttr(kNumberSplit); + return GetValue(value_ptr); +} + +int64_t SplitWithOverlap::get_split_dim() const { + auto value_ptr = GetAttr(kSplitDim); + return GetValue(value_ptr); +} + +int64_t SplitWithOverlap::get_stride() const { + auto value_ptr = GetAttr(kSplitStride); + return GetValue(value_ptr); +} + +int64_t SplitWithOverlap::get_pad_top() const { + auto value_ptr = GetAttr(kPadTop); + return GetValue(value_ptr); +} + +bool SplitWithOverlap::get_trans_format() const { + auto value_ptr = GetAttr(kTransFormat); + return GetValue(value_ptr); +} + +REGISTER_PRIMITIVE_C(kNameSplitWithOverlap, SplitWithOverlap); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/split_with_overlap.h b/mindspore/core/ops/split_with_overlap.h new file mode 100644 index 00000000000..ac7cab1f0d1 --- /dev/null +++ b/mindspore/core/ops/split_with_overlap.h @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ +#define MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +namespace mindspore { +namespace ops { +constexpr auto kNameSplitWithOverlap = "SplitWithOverlap"; +class SplitWithOverlap : public PrimitiveC { + public: + SplitWithOverlap() : PrimitiveC(kNameSplitWithOverlap) {} + ~SplitWithOverlap() = default; + MS_DECLARE_PARENT(SplitWithOverlap, PrimitiveC); + void Init(int64_t number_split, const std::vector &ratio, const std::vector &extend_top, + const std::vector &extend_bottom, int64_t split_dim, int64_t stride, int64_t pad_top, + bool trans_format); + + void set_ratio(const std::vector &ratio); + void set_extend_top(const std::vector &extend_top); + void set_extend_bottom(const std::vector &extend_bottom); + void set_number_split(int64_t number_split); + void set_split_dim(int64_t split_dim); + void set_stride(int64_t stride); + void set_pad_top(int64_t pad_top); + void set_trans_format(bool trans_format); + + std::vector get_ratio() const; + std::vector get_extend_top() const; + std::vector get_extend_bottom() const; + int64_t get_number_split() const; + int64_t get_split_dim() const; + int64_t get_stride() const; + int64_t get_pad_top() const; + bool get_trans_format() const; +}; +AbstractBasePtr SplitWithOverlapInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSplitWithOverlap = std::shared_ptr; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index da0bd88273c..2e9d57ebb9d 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -210,6 +210,7 @@ union PrimitiveType { Call, Custom, CumSum, + SplitWithOverlap, } table Abs { @@ -1115,3 +1116,14 @@ table Custom { type: string; attr: [Attribute]; } + +table SplitWithOverlap { + number_split: long; + ratio: [long]; + extend_top: [long]; + extend_bottom: [long]; + split_dim: long; + stride: long; + pad_top: long; + trans_format: bool = false; +} diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index 43888a54060..ab1b89595c4 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -209,6 +209,7 @@ OP_TYPE(LogSoftmax) OP_TYPE(Call) OP_TYPE(Custom) OP_TYPE(CumSum) +OP_TYPE(SplitWithOverlap) OP_TYPE_DEF_END(PrimitiveType) OP_SCHEMA_DEF(Abs) @@ -1114,3 +1115,14 @@ OP_SCHEMA_DEF_ONLY(Custom) OP_ATTR_ONLY(type, string) OP_ATTR_ONLY(attr, [Attribute]) OP_SCHEMA_DEF_ONLY_END(Custom) + +OP_SCHEMA_DEF(SplitWithOverlap) +OP_ATTR(number_split, long) +OP_ATTR(ratio, [long]) +OP_ATTR(extend_top, [long]) +OP_ATTR(extend_bottom, [long]) +OP_ATTR(split_dim, long) +OP_ATTR(stride, long) +OP_ATTR(pad_top, long) +OP_ATTR_WITH_VALUE(trans_format, bool, false) +OP_SCHEMA_DEF_END(SplitWithOverlap) diff --git a/mindspore/lite/src/ops/ops_func_declare.h b/mindspore/lite/src/ops/ops_func_declare.h index 5df5a173f51..d9dbee29e85 100644 --- a/mindspore/lite/src/ops/ops_func_declare.h +++ b/mindspore/lite/src/ops/ops_func_declare.h @@ -237,6 +237,7 @@ #include "ops/log_softmax.h" #include "ops/call.h" #include "ops/cumsum.h" +#include "ops/split_with_overlap.h" #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ namespace mindspore::lite::ops { \ @@ -448,5 +449,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Splice); FUNC_MSOP2SCHEMAOP_DECLARE(LogSoftmax); FUNC_MSOP2SCHEMAOP_DECLARE(Call); FUNC_MSOP2SCHEMAOP_DECLARE(CumSum); +FUNC_MSOP2SCHEMAOP_DECLARE(SplitWithOverlap); #endif #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ diff --git a/mindspore/lite/src/ops/populate/split_with_overlap_populate.cc b/mindspore/lite/src/ops/populate/split_with_overlap_populate.cc new file mode 100644 index 00000000000..fa5f390dcb8 --- /dev/null +++ b/mindspore/lite/src/ops/populate/split_with_overlap_populate.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/populate/populate_register.h" +#include "nnacl/split_parameter.h" +using mindspore::schema::PrimitiveType_SplitWithOverlap; + +namespace mindspore { +namespace lite { +OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { + auto *split_with_over_lap_param = + reinterpret_cast(malloc(sizeof(SplitWithOverlapParameter))); + if (split_with_over_lap_param == nullptr) { + MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed."; + return nullptr; + } + memset(split_with_over_lap_param, 0, sizeof(SplitWithOverlapParameter)); + + auto primitive = static_cast(prim); + auto value = primitive->value_as_SplitWithOverlap(); + split_with_over_lap_param->op_parameter_.type_ = primitive->value_type(); + + auto ratio = value->ratio(); + if (ratio->size() > SPLIT_MAX_SLICE_NUM) { + MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM + << " slices"; + delete split_with_over_lap_param; + return nullptr; + } + + split_with_over_lap_param->num_split_ = static_cast(ratio->size()); + split_with_over_lap_param->split_dim_ = value->split_dim(); + + auto extend_top = value->extend_top(); + auto extend_bottom = value->extend_bottom(); + if (extend_top->size() != ratio->size() || extend_bottom->size() != ratio->size()) { + MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical"; + delete split_with_over_lap_param; + return nullptr; + } + + for (size_t i = 0; i < ratio->size(); ++i) { + split_with_over_lap_param->ratio_[i] = (*ratio)[i]; + split_with_over_lap_param->extend_top_[i] = (*extend_top)[i]; + split_with_over_lap_param->extend_bottom_[i] = (*extend_bottom)[i]; + } + + split_with_over_lap_param->stride_ = value->stride(); + split_with_over_lap_param->pad_top_ = value->pad_top(); + + return reinterpret_cast(split_with_over_lap_param); +} +REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR) +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc new file mode 100644 index 00000000000..610566f2e3e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/split_with_over_lap_base.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SplitWithOverlap; + +namespace mindspore::kernel { + +void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const SplitWithOverlapParameter *param, + const std::vector &shape) { + int total_block_count = 0; + for (auto i = 0; i < param->num_split_; i++) { + total_block_count += param->ratio_[i]; + } + + auto split_dim_size = shape[param->split_dim_]; + + std::vector borders; + borders.emplace_back(0); + int visited_block = 0; + for (auto i = 0; i < param->num_split_ - 1; i++) { + visited_block += param->ratio_[i]; + auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); + if (param->stride_ != 0) { + // make sure border align with stride + cur_border = UP_ROUND(cur_border + param->pad_top_, param->stride_); + borders.emplace_back(cur_border - param->pad_top_); + } else { + borders.emplace_back(cur_border); + } + } + borders.emplace_back(split_dim_size); + + for (auto i = 0; i < param->num_split_; i++) { + start_indices_.emplace_back(borders[i]); + end_indices_.emplace_back(borders[i + 1]); + + // overlap: calibrate start_indices and end_indices by adding extends + start_indices_[i] -= param->extend_top_[i]; + end_indices_[i] += param->extend_bottom_[i]; + } +} + +int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; } + +int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; } + +int SplitWithOverlapBaseCPUKernel::Split(int task_id) { + DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_, + inner_stride_, start_indices_.data(), end_indices_.data()); + + return RET_OK; +} + +int SplitWithOverlapRun(void *cdata, int task_id) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->Split(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SplitWithOverlapRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +int SplitWithOverlapBaseCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail! ret: " << prepare_ret; + return prepare_ret; + } + + auto in_tensor = in_tensors_.front(); + input_ptr_ = reinterpret_cast(in_tensor->data_c()); + auto input_shape = in_tensor->shape(); + + start_indices_.clear(); + end_indices_.clear(); + output_ptr_.clear(); + + for (int i = 0; i < param->num_split_; i++) { + output_ptr_.push_back(reinterpret_cast(out_tensors_.at(i)->data_c())); + } + + CalculateSplitedShapes(param, input_shape); + + outer_total_dim_ = 1; + inner_stride_ = 1; + split_dim_size_ = input_shape[param->split_dim_]; + element_bytes_ = in_tensor->Size(); + + for (auto i = 0; i < param->split_dim_; i++) { + outer_total_dim_ *= input_shape[i]; + } + + for (int i = static_cast(input_shape.size()) - 1; i > param->split_dim_; i--) { + inner_stride_ *= input_shape[i]; + } + + auto ret = ParallelLaunch(static_cast(this->context_)->thread_pool_, SplitWithOverlapRun, + this, param->num_split_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SplitWithOverlap, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SplitWithOverlap, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SplitWithOverlap, LiteKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h new file mode 100644 index 00000000000..9de04acf1c9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ + +#include +#include "include/errorcode.h" +#include "include/context.h" +#include "src/lite_kernel.h" +#include "nnacl/split_parameter.h" +#include "nnacl/base/split_with_over_lap_base.h" + +namespace mindspore::kernel { +class SplitWithOverlapBaseCPUKernel : public LiteKernel { + public: + SplitWithOverlapBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : LiteKernel(parameter, inputs, outputs, ctx) { + param = reinterpret_cast(op_parameter_); + } + ~SplitWithOverlapBaseCPUKernel() override = default; + void CalculateSplitedShapes(const SplitWithOverlapParameter *param, const std::vector &shape); + int Init() override; + int ReSize() override; + int Run() override; + int Split(int task_id); + + protected: + // range: [start, end) + std::vector start_indices_; + std::vector end_indices_; + + int outer_total_dim_{0}; + int inner_stride_{0}; + int element_bytes_{0}; + int split_dim_size_{0}; + SplitWithOverlapParameter *param = nullptr; + char *input_ptr_{nullptr}; + std::vector output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_