!15760 add split_with_overlap

From: @zoloft
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-30 09:28:11 +08:00 committed by Gitee
commit ef0431e5c9
15 changed files with 669 additions and 1 deletions

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/base/split_with_over_lap_base.h"
#include "nnacl/split_parameter.h"
#include <string.h>
#include "nnacl/errorcode.h"
// Copies all `num_split` overlapped slices of `in_data` into their output buffers.
// The input is walked as `outer_total_dim` rows of `split_dim_size * inner_stride` elements of `element_bytes` bytes;
// slice `s` receives elements [start_indices[s], end_indices[s]) of the split dimension from every row, packed
// contiguously into out_data[s]. Always returns NNACL_OK.
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
                       int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) {
  // Bytes occupied by one index of the split dimension.
  const int row_bytes = inner_stride * element_bytes;
  // Distance in bytes between the same split index of two consecutive outer rows.
  const int src_step = split_dim_size * row_bytes;

  for (int s = 0; s < num_split; ++s) {
    const int first = start_indices[s];
    const int slice_bytes = (end_indices[s] - first) * row_bytes;
    const char *src = in_data + first * row_bytes;
    for (int outer = 0; outer < outer_total_dim; ++outer, src += src_step) {
      (void)memcpy(out_data[s] + outer * slice_bytes, src, slice_bytes);
    }
  }
  return NNACL_OK;
}
// Single-slice variant of DoSplitWithOverlap: copies only slice `slice_idx`, so that each slice can be handled by a
// separate task. For every one of the `outer_total_dim` rows it copies elements
// [start_indices[slice_idx], end_indices[slice_idx]) of the split dimension into out_data[slice_idx].
// Always returns NNACL_OK.
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
                               int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) {
  // Bytes occupied by one index of the split dimension.
  const int row_bytes = inner_stride * element_bytes;
  const int src_step = split_dim_size * row_bytes;
  const int first = start_indices[slice_idx];
  const int slice_bytes = (end_indices[slice_idx] - first) * row_bytes;

  const char *src = in_data + first * row_bytes;
  int row = 0;
  while (row < outer_total_dim) {
    (void)memcpy(out_data[slice_idx] + row * slice_bytes, src, slice_bytes);
    src += src_step;
    ++row;
  }
  return NNACL_OK;
}

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
#define MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
#include "nnacl/op_base.h"
#include "nnacl/split_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
// Splits `in_data` into `num_split` slices along one dimension in a single call.
//   in_data         : input buffer, addressed in bytes.
//   out_data        : one destination buffer per slice.
//   split_dim_size  : extent of the split dimension in the input.
//   element_bytes   : size of one element in bytes.
//   outer_total_dim : number of outer rows iterated over (one copy per row and slice).
//   inner_stride    : number of elements covered by one index of the split dimension.
//   start_indices / end_indices : per-slice [start, end) range on the split dimension.
// The implementation returns NNACL_OK.
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
int outer_total_dim, int inner_stride, int *start_indices, int *end_indices);
// Same copy as DoSplitWithOverlap, but only for the single slice `slice_idx`; the remaining parameters have the same
// meaning. Only out_data[slice_idx], start_indices[slice_idx] and end_indices[slice_idx] are accessed.
// The implementation returns NNACL_OK.
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
int outer_total_dim, int inner_stride, int *start_indices, int *end_indices);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_

View File

@ -217,8 +217,9 @@ enum PrimType {
PrimType_Call = 190,
PrimType_Custom = 191,
PrimType_CumSum = 192,
PrimType_SplitWithOverlap = 193,
PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_CumSum + 1
PrimType_MAX = PrimType_SplitWithOverlap + 1
};
void RegInfer(int prim_type, InferShape func);

View File

@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/infer/split_with_over_lap_infer.h"
#include "nnacl/infer/infer_register.h"
/**
 * Infers the output shapes of a SplitWithOverlap node.
 *
 * The split dimension of the input is cut into `num_split_` consecutive slices whose sizes are proportional to
 * `ratio_`. When `stride_` is non-zero, every inner border is rounded up so that (border + pad_top_) is a multiple of
 * `stride_`. Each output keeps the input shape except on the split dimension, where its extent is the slice size
 * plus `extend_top_[i] + extend_bottom_[i]`. Data type and format are copied from the input.
 *
 * Returns NNACL_INFER_INVALID when `infer_flag_` is not set, NNACL_ERR when the arguments are inconsistent
 * (tensor counts, slice count, split dimension, ratio sum), and NNACL_OK on success.
 */
int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                               OpParameter *parameter) {
#ifdef Debug
  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }
#endif

  if (!parameter->infer_flag_) {
    return NNACL_INFER_INVALID;
  }

  // Validate the tensor counts before inputs[0] is touched.
  if (inputs_size < 1 || outputs_size == 0) {
    return NNACL_ERR;
  }
  const TensorC *input = inputs[0];

  const SplitWithOverlapParameter *param = (const SplitWithOverlapParameter *)parameter;
  const int number_split = param->num_split_;
  // The per-slice arrays of the parameter hold at most SPLIT_MAX_SLICE_NUM entries.
  if (number_split <= 0 || number_split > SPLIT_MAX_SLICE_NUM || outputs_size != (size_t)number_split) {
    return NNACL_ERR;
  }

  const int split_dim = param->split_dim_;
  // output_shape below is a MAX_SHAPE_SIZE array and split_dim indexes the input shape.
  if (input->shape_size_ > MAX_SHAPE_SIZE || split_dim < 0 || (size_t)split_dim >= input->shape_size_) {
    return NNACL_ERR;
  }

  int total_block_count = 0;
  for (int i = 0; i < number_split; ++i) {
    total_block_count += param->ratio_[i];
  }
  // total_block_count is used as a divisor below.
  if (total_block_count <= 0) {
    return NNACL_ERR;
  }

  const int stride = param->stride_;
  const int pad_top = param->pad_top_;
  const int *input_shape = input->shape_;
  const int split_dim_size = input_shape[split_dim];

  // borders[i] .. borders[i + 1] is the un-extended range of slice i, hence number_split + 1 entries.
  int borders[SPLIT_MAX_SLICE_NUM + 1];
  borders[0] = 0;
  int visited_block = 0;
  for (int i = 0; i < number_split - 1; ++i) {
    visited_block += param->ratio_[i];
    int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
    if (stride != 0) {
      // make sure border align with stride
      cur_border = UP_ROUND(cur_border + pad_top, stride) - pad_top;
    }
    borders[i + 1] = cur_border;
  }
  // The last slice always ends at the end of the split dimension.
  borders[number_split] = split_dim_size;

  for (int i = 0; i < number_split; ++i) {
    int output_shape[MAX_SHAPE_SIZE];
    for (size_t dim = 0; dim < input->shape_size_; ++dim) {
      output_shape[dim] = input_shape[dim];
    }
    output_shape[split_dim] = borders[i + 1] - borders[i] + param->extend_top_[i] + param->extend_bottom_[i];
    SetShapeArray(outputs[i], output_shape, input->shape_size_);
    SetDataTypeFormat(outputs[i], input);
  }
  return NNACL_OK;
}
// Associates SplitWithOverlapInferShape with PrimType_SplitWithOverlap via the infer-register macro.
REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape)

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H
#define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H
#include "nnacl/infer/common_infer.h"
#include "nnacl/split_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
// Shape inference for SplitWithOverlap. `parameter` is read as a SplitWithOverlapParameter; inputs[0] is the tensor
// to split and `outputs_size` must equal the parameter's `num_split_`. Each output gets the input's data type and
// format and the input's shape, with the split dimension replaced by that slice's (extended) size.
// Returns NNACL_OK on success, NNACL_INFER_INVALID when infer_flag_ is unset, NNACL_ERR on bad tensor counts.
int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H

View File

@ -20,6 +20,7 @@
#include "nnacl/op_base.h"
#define SPLIT_STRIDES_SIZE 32
#define SPLIT_MAX_SLICE_NUM 10
typedef struct SplitQuantArg {
QuantArg in_args_;
@ -44,4 +45,15 @@ typedef struct SplitParameter {
int split_count_;
} SplitParameter;
// Runtime parameters of the SplitWithOverlap operator. The per-slice arrays hold at most SPLIT_MAX_SLICE_NUM entries,
// of which the first num_split_ are used.
typedef struct SplitWithOverlapParameter {
OpParameter op_parameter_;
// Number of output slices.
int num_split_;
// Index of the input dimension that is split.
int split_dim_;
// When non-zero, inner slice borders are rounded up so that (border + pad_top_) is a multiple of stride_.
int stride_;
// Offset added to a border before the stride alignment and subtracted again afterwards.
int pad_top_;
// Relative size of each slice; borders are placed proportionally to the running sum of these values.
int ratio_[SPLIT_MAX_SLICE_NUM];
// Per slice: number of extra indices taken before the slice's start border.
int extend_top_[SPLIT_MAX_SLICE_NUM];
// Per slice: number of extra indices taken after the slice's end border.
int extend_bottom_[SPLIT_MAX_SLICE_NUM];
} SplitWithOverlapParameter;
#endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_

View File

@ -234,6 +234,13 @@ constexpr auto kSideEffectIO = "side_effect_io";
constexpr auto kDeviceType = "device_type";
constexpr auto kExclusive = "exclusive";
constexpr auto kReverse = "reverse";
// Attribute keys used by the SplitWithOverlap primitive's setters and getters.
constexpr auto kSplitStride = "split_stride";
constexpr auto kExtendTop = "extend_top";
constexpr auto kExtendBottom = "extend_bottom";
constexpr auto kNumberSplit = "number_split";
constexpr auto kSplitDim = "split_dim";
constexpr auto kPadTop = "pad_top";
constexpr auto kTransFormat = "trans_format";
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};

View File

@ -0,0 +1,96 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/split_with_overlap.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
void SplitWithOverlap::Init(int64_t number_split, const std::vector<int64_t> &ratio,
const std::vector<int64_t> &extend_top, const std::vector<int64_t> &extend_bottom,
int64_t split_dim, int64_t stride, int64_t pad_top, bool trans_format) {
this->set_number_split(number_split);
this->set_ratio(ratio);
this->set_extend_top(extend_top);
this->set_extend_bottom(extend_bottom);
this->set_split_dim(split_dim);
this->set_stride(stride);
this->set_pad_top(pad_top);
this->set_trans_format(trans_format);
}
void SplitWithOverlap::set_ratio(const std::vector<int64_t> &ratio) { this->AddAttr(kRatio, MakeValue(ratio)); }
void SplitWithOverlap::set_extend_top(const std::vector<int64_t> &extend_top) {
this->AddAttr(kExtendTop, MakeValue(extend_top));
}
void SplitWithOverlap::set_extend_bottom(const std::vector<int64_t> &extend_bottom) {
this->AddAttr(kExtendBottom, MakeValue(extend_bottom));
}
void SplitWithOverlap::set_number_split(int64_t number_split) { this->AddAttr(kNumberSplit, MakeValue(number_split)); }
void SplitWithOverlap::set_split_dim(int64_t split_dim) { this->AddAttr(kSplitDim, MakeValue(split_dim)); }
void SplitWithOverlap::set_stride(int64_t stride) { this->AddAttr(kSplitStride, MakeValue(stride)); }
void SplitWithOverlap::set_pad_top(int64_t pad_top) { this->AddAttr(kPadTop, MakeValue(pad_top)); }
void SplitWithOverlap::set_trans_format(bool trans_format) { this->AddAttr(kTransFormat, MakeValue(trans_format)); }
std::vector<int64_t> SplitWithOverlap::get_ratio() const {
auto value_ptr = GetAttr(kRatio);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> SplitWithOverlap::get_extend_top() const {
auto value_ptr = GetAttr(kExtendTop);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> SplitWithOverlap::get_extend_bottom() const {
auto value_ptr = GetAttr(kExtendBottom);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t SplitWithOverlap::get_number_split() const {
auto value_ptr = GetAttr(kNumberSplit);
return GetValue<int64_t>(value_ptr);
}
int64_t SplitWithOverlap::get_split_dim() const {
auto value_ptr = GetAttr(kSplitDim);
return GetValue<int64_t>(value_ptr);
}
int64_t SplitWithOverlap::get_stride() const {
auto value_ptr = GetAttr(kSplitStride);
return GetValue<int64_t>(value_ptr);
}
int64_t SplitWithOverlap::get_pad_top() const {
auto value_ptr = GetAttr(kPadTop);
return GetValue<int64_t>(value_ptr);
}
bool SplitWithOverlap::get_trans_format() const {
auto value_ptr = GetAttr(kTransFormat);
return GetValue<bool>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameSplitWithOverlap, SplitWithOverlap);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_
#define MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace ops {
// Name under which the primitive is constructed and registered.
constexpr auto kNameSplitWithOverlap = "SplitWithOverlap";
// Primitive describing a split along one dimension into several slices that may overlap their neighbours.
// All configuration is stored as primitive attributes; the accessors below only read or write those attributes.
class SplitWithOverlap : public PrimitiveC {
public:
SplitWithOverlap() : PrimitiveC(kNameSplitWithOverlap) {}
~SplitWithOverlap() = default;
MS_DECLARE_PARENT(SplitWithOverlap, PrimitiveC);
// Sets all attributes at once by calling each setter below.
void Init(int64_t number_split, const std::vector<int64_t> &ratio, const std::vector<int64_t> &extend_top,
const std::vector<int64_t> &extend_bottom, int64_t split_dim, int64_t stride, int64_t pad_top,
bool trans_format);
// Attribute setters (one attribute each).
void set_ratio(const std::vector<int64_t> &ratio);
void set_extend_top(const std::vector<int64_t> &extend_top);
void set_extend_bottom(const std::vector<int64_t> &extend_bottom);
void set_number_split(int64_t number_split);
void set_split_dim(int64_t split_dim);
void set_stride(int64_t stride);
void set_pad_top(int64_t pad_top);
void set_trans_format(bool trans_format);
// Attribute getters; each reads back the value stored by the setter of the same name.
std::vector<int64_t> get_ratio() const;
std::vector<int64_t> get_extend_top() const;
std::vector<int64_t> get_extend_bottom() const;
int64_t get_number_split() const;
int64_t get_split_dim() const;
int64_t get_stride() const;
int64_t get_pad_top() const;
bool get_trans_format() const;
};
// NOTE(review): declared here, but no definition is visible in the accompanying split_with_overlap.cc — confirm it is
// defined elsewhere before calling it.
AbstractBasePtr SplitWithOverlapInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimSplitWithOverlap = std::shared_ptr<SplitWithOverlap>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_

View File

@ -210,6 +210,7 @@ union PrimitiveType {
Call,
Custom,
CumSum,
SplitWithOverlap,
}
table Abs {
@ -1115,3 +1116,14 @@ table Custom {
type: string;
attr: [Attribute];
}
// Serialized attributes of the SplitWithOverlap operator.
// ratio, extend_top and extend_bottom carry one entry per slice; the lite populate code requires the three vectors
// to have the same length and derives the slice count from ratio's length.
table SplitWithOverlap {
number_split: long;
ratio: [long];
extend_top: [long];
extend_bottom: [long];
split_dim: long;
stride: long;
pad_top: long;
trans_format: bool = false;
}

View File

@ -209,6 +209,7 @@ OP_TYPE(LogSoftmax)
OP_TYPE(Call)
OP_TYPE(Custom)
OP_TYPE(CumSum)
OP_TYPE(SplitWithOverlap)
OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs)
@ -1114,3 +1115,14 @@ OP_SCHEMA_DEF_ONLY(Custom)
OP_ATTR_ONLY(type, string)
OP_ATTR_ONLY(attr, [Attribute])
OP_SCHEMA_DEF_ONLY_END(Custom)
// Schema definition of SplitWithOverlap; the attribute list mirrors `table SplitWithOverlap` in the .fbs schema
// (same names, types, order and the `false` default of trans_format).
OP_SCHEMA_DEF(SplitWithOverlap)
OP_ATTR(number_split, long)
OP_ATTR(ratio, [long])
OP_ATTR(extend_top, [long])
OP_ATTR(extend_bottom, [long])
OP_ATTR(split_dim, long)
OP_ATTR(stride, long)
OP_ATTR(pad_top, long)
OP_ATTR_WITH_VALUE(trans_format, bool, false)
OP_SCHEMA_DEF_END(SplitWithOverlap)

View File

@ -237,6 +237,7 @@
#include "ops/log_softmax.h"
#include "ops/call.h"
#include "ops/cumsum.h"
#include "ops/split_with_overlap.h"
#define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \
namespace mindspore::lite::ops { \
@ -448,5 +449,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Splice);
FUNC_MSOP2SCHEMAOP_DECLARE(LogSoftmax);
FUNC_MSOP2SCHEMAOP_DECLARE(Call);
FUNC_MSOP2SCHEMAOP_DECLARE(CumSum);
FUNC_MSOP2SCHEMAOP_DECLARE(SplitWithOverlap);
#endif
#endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_

View File

@ -0,0 +1,67 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
#include "nnacl/split_parameter.h"
using mindspore::schema::PrimitiveType_SplitWithOverlap;
namespace mindspore {
namespace lite {
// Builds a malloc-allocated SplitWithOverlapParameter from a serialized schema::Primitive.
//
// `prim` is interpreted as a `const schema::Primitive *` holding a SplitWithOverlap value. The slice count is taken
// from the length of `ratio`; `extend_top` and `extend_bottom` must have the same length, which may not exceed
// SPLIT_MAX_SLICE_NUM. Returns nullptr (after logging) when the primitive is malformed or the allocation fails.
// The returned block is allocated with malloc.
OpParameter *PopulateSplitWithOverlapParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SplitWithOverlap();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr";
    return nullptr;
  }

  // Validate everything before allocating so that no error path has to release memory.
  auto ratio = value->ratio();
  auto extend_top = value->extend_top();
  auto extend_bottom = value->extend_bottom();
  if (ratio == nullptr || extend_top == nullptr || extend_bottom == nullptr) {
    MS_LOG(ERROR) << "ratio, extend_top or extend_bottom is nullptr";
    return nullptr;
  }
  if (ratio->size() > SPLIT_MAX_SLICE_NUM) {
    MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM
                  << " slices";
    return nullptr;
  }
  if (extend_top->size() != ratio->size() || extend_bottom->size() != ratio->size()) {
    MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical";
    return nullptr;
  }

  auto *split_with_over_lap_param =
    reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter)));
  if (split_with_over_lap_param == nullptr) {
    MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed.";
    return nullptr;
  }
  memset(split_with_over_lap_param, 0, sizeof(SplitWithOverlapParameter));

  split_with_over_lap_param->op_parameter_.type_ = primitive->value_type();
  split_with_over_lap_param->num_split_ = static_cast<int>(ratio->size());
  split_with_over_lap_param->split_dim_ = static_cast<int>(value->split_dim());
  split_with_over_lap_param->stride_ = static_cast<int>(value->stride());
  split_with_over_lap_param->pad_top_ = static_cast<int>(value->pad_top());
  for (size_t i = 0; i < ratio->size(); ++i) {
    split_with_over_lap_param->ratio_[i] = static_cast<int>((*ratio)[i]);
    split_with_over_lap_param->extend_top_[i] = static_cast<int>((*extend_top)[i]);
    split_with_over_lap_param->extend_bottom_[i] = static_cast<int>((*extend_bottom)[i]);
  }
  return reinterpret_cast<OpParameter *>(split_with_over_lap_param);
}
// Registers PopulateSplitWithOverlapParameter as the parameter builder for PrimitiveType_SplitWithOverlap.
REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,133 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/split_with_over_lap_base.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SplitWithOverlap;
namespace mindspore::kernel {
// Computes the [start, end) range of every slice on the split dimension and appends it to start_indices_ /
// end_indices_.
// Borders are placed proportionally to the running sum of `ratio_`; when `stride_` is non-zero each inner border is
// rounded up so that (border + pad_top_) is a multiple of `stride_`. Each slice is then widened by `extend_top_` at
// its start and `extend_bottom_` at its end.
void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const SplitWithOverlapParameter *param,
                                                           const std::vector<int> &shape) {
  const int slice_count = param->num_split_;

  int ratio_sum = 0;
  for (int k = 0; k < slice_count; ++k) {
    ratio_sum += param->ratio_[k];
  }

  const int dim_len = shape[param->split_dim_];

  // edges[k] .. edges[k + 1] is the un-extended range of slice k.
  std::vector<int> edges(1, 0);
  int ratio_prefix = 0;
  for (int k = 0; k + 1 < slice_count; ++k) {
    ratio_prefix += param->ratio_[k];
    int edge = UP_DIV(dim_len * ratio_prefix, ratio_sum);
    if (param->stride_ != 0) {
      // make sure border align with stride
      edge = UP_ROUND(edge + param->pad_top_, param->stride_) - param->pad_top_;
    }
    edges.push_back(edge);
  }
  edges.push_back(dim_len);

  for (int k = 0; k < slice_count; ++k) {
    start_indices_.push_back(edges[k]);
    end_indices_.push_back(edges[k + 1]);

    // overlap: widen the slice on both sides by its configured extends
    start_indices_[k] -= param->extend_top_[k];
    end_indices_[k] += param->extend_bottom_[k];
  }
}
// Nothing is prepared at init time; Run() recomputes all slice bookkeeping on every invocation.
int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; }
// No shape-dependent state is cached, so a resize requires no work.
int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; }
int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
inner_stride_, start_indices_.data(), end_indices_.data());
return RET_OK;
}
int SplitWithOverlapRun(void *cdata, int task_id) {
auto g_kernel = reinterpret_cast<SplitWithOverlapBaseCPUKernel *>(cdata);
auto ret = g_kernel->Split(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SplitWithOverlapRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int SplitWithOverlapBaseCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail! ret: " << prepare_ret;
return prepare_ret;
}
auto in_tensor = in_tensors_.front();
input_ptr_ = reinterpret_cast<char *>(in_tensor->data_c());
auto input_shape = in_tensor->shape();
start_indices_.clear();
end_indices_.clear();
output_ptr_.clear();
for (int i = 0; i < param->num_split_; i++) {
output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c()));
}
CalculateSplitedShapes(param, input_shape);
outer_total_dim_ = 1;
inner_stride_ = 1;
split_dim_size_ = input_shape[param->split_dim_];
element_bytes_ = in_tensor->Size();
for (auto i = 0; i < param->split_dim_; i++) {
outer_total_dim_ *= input_shape[i];
}
for (int i = static_cast<int>(input_shape.size()) - 1; i > param->split_dim_; i--) {
inner_stride_ *= input_shape[i];
}
auto ret = ParallelLaunch(static_cast<const lite::InnerContext *>(this->context_)->thread_pool_, SplitWithOverlapRun,
this, param->num_split_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
// The same kernel class is registered for float32, float16 and int8 on CPU.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
} // namespace mindspore::kernel

View File

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_
#include <vector>
#include "include/errorcode.h"
#include "include/context.h"
#include "src/lite_kernel.h"
#include "nnacl/split_parameter.h"
#include "nnacl/base/split_with_over_lap_base.h"
namespace mindspore::kernel {
// CPU kernel for SplitWithOverlap: splits the first input tensor along one dimension into `num_split_` output
// tensors whose ranges may overlap, launching one task per output slice.
class SplitWithOverlapBaseCPUKernel : public LiteKernel {
public:
// Keeps a typed view of the operator parameter; no other state is set up at construction.
SplitWithOverlapBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {
param = reinterpret_cast<SplitWithOverlapParameter *>(op_parameter_);
}
~SplitWithOverlapBaseCPUKernel() override = default;
// Appends each slice's [start, end) range on the split dimension of `shape` to start_indices_ / end_indices_.
// Note that the `param` argument shadows the member of the same name.
void CalculateSplitedShapes(const SplitWithOverlapParameter *param, const std::vector<int> &shape);
int Init() override;
int ReSize() override;
// Recomputes all members below from the current tensors and runs Split() once per slice.
int Run() override;
// Copies the slice with index `task_id`; called from the parallel task entry point.
int Split(int task_id);
protected:
// range: [start, end)
std::vector<int> start_indices_;
std::vector<int> end_indices_;
// Product of the input dimensions before the split dimension.
int outer_total_dim_{0};
// Product of the input dimensions after the split dimension.
int inner_stride_{0};
// Passed to the nnacl routine as the size of one element in bytes.
int element_bytes_{0};
// Extent of the input along the split dimension.
int split_dim_size_{0};
// Typed alias of op_parameter_.
SplitWithOverlapParameter *param = nullptr;
// Raw data pointer of the input tensor.
char *input_ptr_{nullptr};
// Raw data pointers of the output tensors, one per slice.
std::vector<char *> output_ptr_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_