forked from mindspore-Ecosystem/mindspore
!15760 add split_with_overlap
From: @zoloft Reviewed-by: Signed-off-by:
This commit is contained in:
commit
ef0431e5c9
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/base/split_with_over_lap_base.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
#include <string.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) {
|
||||
int input_stride = split_dim_size * inner_stride * element_bytes;
|
||||
for (int slice_idx = 0; slice_idx < num_split; slice_idx++) {
|
||||
int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes;
|
||||
char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes;
|
||||
for (int out_idx = 0; out_idx < outer_total_dim; out_idx++) {
|
||||
(void)(memcpy(out_data[slice_idx] + out_idx * out_stride, src_ptr, out_stride));
|
||||
src_ptr += input_stride;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) {
|
||||
int input_stride = split_dim_size * inner_stride * element_bytes;
|
||||
int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes;
|
||||
char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes;
|
||||
for (int i = 0; i < outer_total_dim; i++) {
|
||||
(void)memcpy(out_data[slice_idx] + i * out_stride, src_ptr, out_stride);
|
||||
src_ptr += input_stride;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
|
||||
#define MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, int *start_indices, int *end_indices);
|
||||
|
||||
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
|
||||
int outer_total_dim, int inner_stride, int *start_indices, int *end_indices);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
|
|
@ -217,8 +217,9 @@ enum PrimType {
|
|||
PrimType_Call = 190,
|
||||
PrimType_Custom = 191,
|
||||
PrimType_CumSum = 192,
|
||||
PrimType_SplitWithOverlap = 193,
|
||||
PrimType_MIN = PrimType_NONE,
|
||||
PrimType_MAX = PrimType_CumSum + 1
|
||||
PrimType_MAX = PrimType_SplitWithOverlap + 1
|
||||
};
|
||||
|
||||
void RegInfer(int prim_type, InferShape func);
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/infer/split_with_over_lap_infer.h"
|
||||
#include "nnacl/infer/infer_register.h"
|
||||
|
||||
int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
#ifdef Debug
|
||||
int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
|
||||
if (check_ret != NNACL_OK) {
|
||||
return check_ret;
|
||||
}
|
||||
#endif
|
||||
if (!parameter->infer_flag_) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
const TensorC *input = inputs[0];
|
||||
if (inputs_size < 1) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
if (outputs_size == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter;
|
||||
int number_split = param->num_split_;
|
||||
if (outputs_size != number_split) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int stride = param->stride_;
|
||||
int pad_top = param->pad_top_;
|
||||
int split_dim = param->split_dim_;
|
||||
|
||||
int ratio[SPLIT_MAX_SLICE_NUM];
|
||||
int extend_top[SPLIT_MAX_SLICE_NUM];
|
||||
int extend_bottom[SPLIT_MAX_SLICE_NUM];
|
||||
for (int i = 0; i < number_split; ++i) {
|
||||
ratio[i] = param->ratio_[i];
|
||||
extend_top[i] = param->extend_top_[i];
|
||||
extend_bottom[i] = param->extend_bottom_[i];
|
||||
}
|
||||
|
||||
const int *input_shape = input->shape_;
|
||||
int split_dim_size = input_shape[split_dim];
|
||||
int total_block_count = 0;
|
||||
for (int i = 0; i < number_split; i++) {
|
||||
total_block_count += ratio[i];
|
||||
}
|
||||
|
||||
int borders[MAX_SHAPE_SIZE];
|
||||
borders[0] = 0;
|
||||
int visited_block = 0;
|
||||
for (int i = 0; i < number_split - 1; i++) {
|
||||
visited_block += ratio[i];
|
||||
|
||||
int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
|
||||
if (stride != 0) {
|
||||
// make sure border align with stride
|
||||
cur_border = UP_ROUND(cur_border + pad_top, stride);
|
||||
borders[i + 1] = cur_border - pad_top;
|
||||
} else {
|
||||
borders[i + 1] = cur_border;
|
||||
}
|
||||
}
|
||||
borders[number_split - 1] = split_dim_size;
|
||||
|
||||
for (int i = 0; i < number_split; ++i) {
|
||||
int output_shape[MAX_SHAPE_SIZE];
|
||||
for (int dim = 0; dim < input->shape_size_; dim++) {
|
||||
if (dim == split_dim) {
|
||||
int splited_size = borders[i + 1] - borders[i];
|
||||
splited_size += extend_top[i] + extend_bottom[i];
|
||||
output_shape[dim] = splited_size;
|
||||
} else {
|
||||
output_shape[dim] = input_shape[dim];
|
||||
}
|
||||
}
|
||||
SetShapeArray(outputs[i], output_shape, input->shape_size_);
|
||||
SetDataTypeFormat(outputs[i], input);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape)
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H
|
||||
#define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H
|
||||
|
||||
#include "nnacl/infer/common_infer.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H
|
|
@ -20,6 +20,7 @@
|
|||
#include "nnacl/op_base.h"
|
||||
|
||||
#define SPLIT_STRIDES_SIZE 32
|
||||
#define SPLIT_MAX_SLICE_NUM 10
|
||||
|
||||
typedef struct SplitQuantArg {
|
||||
QuantArg in_args_;
|
||||
|
@ -44,4 +45,15 @@ typedef struct SplitParameter {
|
|||
int split_count_;
|
||||
} SplitParameter;
|
||||
|
||||
typedef struct SplitWithOverlapParameter {
|
||||
OpParameter op_parameter_;
|
||||
int num_split_;
|
||||
int split_dim_;
|
||||
int stride_;
|
||||
int pad_top_;
|
||||
int ratio_[SPLIT_MAX_SLICE_NUM];
|
||||
int extend_top_[SPLIT_MAX_SLICE_NUM];
|
||||
int extend_bottom_[SPLIT_MAX_SLICE_NUM];
|
||||
} SplitWithOverlapParameter;
|
||||
|
||||
#endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_
|
||||
|
|
|
@ -234,6 +234,13 @@ constexpr auto kSideEffectIO = "side_effect_io";
|
|||
constexpr auto kDeviceType = "device_type";
|
||||
constexpr auto kExclusive = "exclusive";
|
||||
constexpr auto kReverse = "reverse";
|
||||
constexpr auto kSplitStride = "split_stride";
|
||||
constexpr auto kExtendTop = "extend_top";
|
||||
constexpr auto kExtendBottom = "extend_bottom";
|
||||
constexpr auto kNumberSplit = "number_split";
|
||||
constexpr auto kSplitDim = "split_dim";
|
||||
constexpr auto kPadTop = "pad_top";
|
||||
constexpr auto kTransFormat = "trans_format";
|
||||
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
|
||||
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
|
||||
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/split_with_overlap.h"
|
||||
#include "ops/op_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void SplitWithOverlap::Init(int64_t number_split, const std::vector<int64_t> &ratio,
|
||||
const std::vector<int64_t> &extend_top, const std::vector<int64_t> &extend_bottom,
|
||||
int64_t split_dim, int64_t stride, int64_t pad_top, bool trans_format) {
|
||||
this->set_number_split(number_split);
|
||||
this->set_ratio(ratio);
|
||||
this->set_extend_top(extend_top);
|
||||
this->set_extend_bottom(extend_bottom);
|
||||
this->set_split_dim(split_dim);
|
||||
this->set_stride(stride);
|
||||
this->set_pad_top(pad_top);
|
||||
this->set_trans_format(trans_format);
|
||||
}
|
||||
|
||||
void SplitWithOverlap::set_ratio(const std::vector<int64_t> &ratio) { this->AddAttr(kRatio, MakeValue(ratio)); }
|
||||
|
||||
void SplitWithOverlap::set_extend_top(const std::vector<int64_t> &extend_top) {
|
||||
this->AddAttr(kExtendTop, MakeValue(extend_top));
|
||||
}
|
||||
|
||||
void SplitWithOverlap::set_extend_bottom(const std::vector<int64_t> &extend_bottom) {
|
||||
this->AddAttr(kExtendBottom, MakeValue(extend_bottom));
|
||||
}
|
||||
|
||||
void SplitWithOverlap::set_number_split(int64_t number_split) { this->AddAttr(kNumberSplit, MakeValue(number_split)); }
|
||||
|
||||
void SplitWithOverlap::set_split_dim(int64_t split_dim) { this->AddAttr(kSplitDim, MakeValue(split_dim)); }
|
||||
|
||||
void SplitWithOverlap::set_stride(int64_t stride) { this->AddAttr(kSplitStride, MakeValue(stride)); }
|
||||
|
||||
void SplitWithOverlap::set_pad_top(int64_t pad_top) { this->AddAttr(kPadTop, MakeValue(pad_top)); }
|
||||
|
||||
void SplitWithOverlap::set_trans_format(bool trans_format) { this->AddAttr(kTransFormat, MakeValue(trans_format)); }
|
||||
|
||||
std::vector<int64_t> SplitWithOverlap::get_ratio() const {
|
||||
auto value_ptr = GetAttr(kRatio);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> SplitWithOverlap::get_extend_top() const {
|
||||
auto value_ptr = GetAttr(kExtendTop);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> SplitWithOverlap::get_extend_bottom() const {
|
||||
auto value_ptr = GetAttr(kExtendBottom);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t SplitWithOverlap::get_number_split() const {
|
||||
auto value_ptr = GetAttr(kNumberSplit);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t SplitWithOverlap::get_split_dim() const {
|
||||
auto value_ptr = GetAttr(kSplitDim);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t SplitWithOverlap::get_stride() const {
|
||||
auto value_ptr = GetAttr(kSplitStride);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t SplitWithOverlap::get_pad_top() const {
|
||||
auto value_ptr = GetAttr(kPadTop);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
bool SplitWithOverlap::get_trans_format() const {
|
||||
auto value_ptr = GetAttr(kTransFormat);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameSplitWithOverlap, SplitWithOverlap);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_
|
||||
#define MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSplitWithOverlap = "SplitWithOverlap";
|
||||
class SplitWithOverlap : public PrimitiveC {
|
||||
public:
|
||||
SplitWithOverlap() : PrimitiveC(kNameSplitWithOverlap) {}
|
||||
~SplitWithOverlap() = default;
|
||||
MS_DECLARE_PARENT(SplitWithOverlap, PrimitiveC);
|
||||
void Init(int64_t number_split, const std::vector<int64_t> &ratio, const std::vector<int64_t> &extend_top,
|
||||
const std::vector<int64_t> &extend_bottom, int64_t split_dim, int64_t stride, int64_t pad_top,
|
||||
bool trans_format);
|
||||
|
||||
void set_ratio(const std::vector<int64_t> &ratio);
|
||||
void set_extend_top(const std::vector<int64_t> &extend_top);
|
||||
void set_extend_bottom(const std::vector<int64_t> &extend_bottom);
|
||||
void set_number_split(int64_t number_split);
|
||||
void set_split_dim(int64_t split_dim);
|
||||
void set_stride(int64_t stride);
|
||||
void set_pad_top(int64_t pad_top);
|
||||
void set_trans_format(bool trans_format);
|
||||
|
||||
std::vector<int64_t> get_ratio() const;
|
||||
std::vector<int64_t> get_extend_top() const;
|
||||
std::vector<int64_t> get_extend_bottom() const;
|
||||
int64_t get_number_split() const;
|
||||
int64_t get_split_dim() const;
|
||||
int64_t get_stride() const;
|
||||
int64_t get_pad_top() const;
|
||||
bool get_trans_format() const;
|
||||
};
|
||||
AbstractBasePtr SplitWithOverlapInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimSplitWithOverlap = std::shared_ptr<SplitWithOverlap>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_
|
|
@ -210,6 +210,7 @@ union PrimitiveType {
|
|||
Call,
|
||||
Custom,
|
||||
CumSum,
|
||||
SplitWithOverlap,
|
||||
}
|
||||
|
||||
table Abs {
|
||||
|
@ -1115,3 +1116,14 @@ table Custom {
|
|||
type: string;
|
||||
attr: [Attribute];
|
||||
}
|
||||
|
||||
table SplitWithOverlap {
|
||||
number_split: long;
|
||||
ratio: [long];
|
||||
extend_top: [long];
|
||||
extend_bottom: [long];
|
||||
split_dim: long;
|
||||
stride: long;
|
||||
pad_top: long;
|
||||
trans_format: bool = false;
|
||||
}
|
||||
|
|
|
@ -209,6 +209,7 @@ OP_TYPE(LogSoftmax)
|
|||
OP_TYPE(Call)
|
||||
OP_TYPE(Custom)
|
||||
OP_TYPE(CumSum)
|
||||
OP_TYPE(SplitWithOverlap)
|
||||
OP_TYPE_DEF_END(PrimitiveType)
|
||||
|
||||
OP_SCHEMA_DEF(Abs)
|
||||
|
@ -1114,3 +1115,14 @@ OP_SCHEMA_DEF_ONLY(Custom)
|
|||
OP_ATTR_ONLY(type, string)
|
||||
OP_ATTR_ONLY(attr, [Attribute])
|
||||
OP_SCHEMA_DEF_ONLY_END(Custom)
|
||||
|
||||
OP_SCHEMA_DEF(SplitWithOverlap)
|
||||
OP_ATTR(number_split, long)
|
||||
OP_ATTR(ratio, [long])
|
||||
OP_ATTR(extend_top, [long])
|
||||
OP_ATTR(extend_bottom, [long])
|
||||
OP_ATTR(split_dim, long)
|
||||
OP_ATTR(stride, long)
|
||||
OP_ATTR(pad_top, long)
|
||||
OP_ATTR_WITH_VALUE(trans_format, bool, false)
|
||||
OP_SCHEMA_DEF_END(SplitWithOverlap)
|
||||
|
|
|
@ -237,6 +237,7 @@
|
|||
#include "ops/log_softmax.h"
|
||||
#include "ops/call.h"
|
||||
#include "ops/cumsum.h"
|
||||
#include "ops/split_with_overlap.h"
|
||||
|
||||
#define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \
|
||||
namespace mindspore::lite::ops { \
|
||||
|
@ -448,5 +449,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Splice);
|
|||
FUNC_MSOP2SCHEMAOP_DECLARE(LogSoftmax);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Call);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(CumSum);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(SplitWithOverlap);
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
using mindspore::schema::PrimitiveType_SplitWithOverlap;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateSplitWithOverlapParameter(const void *prim) {
|
||||
auto *split_with_over_lap_param =
|
||||
reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter)));
|
||||
if (split_with_over_lap_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(split_with_over_lap_param, 0, sizeof(SplitWithOverlapParameter));
|
||||
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
auto value = primitive->value_as_SplitWithOverlap();
|
||||
split_with_over_lap_param->op_parameter_.type_ = primitive->value_type();
|
||||
|
||||
auto ratio = value->ratio();
|
||||
if (ratio->size() > SPLIT_MAX_SLICE_NUM) {
|
||||
MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM
|
||||
<< " slices";
|
||||
delete split_with_over_lap_param;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
split_with_over_lap_param->num_split_ = static_cast<int>(ratio->size());
|
||||
split_with_over_lap_param->split_dim_ = value->split_dim();
|
||||
|
||||
auto extend_top = value->extend_top();
|
||||
auto extend_bottom = value->extend_bottom();
|
||||
if (extend_top->size() != ratio->size() || extend_bottom->size() != ratio->size()) {
|
||||
MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical";
|
||||
delete split_with_over_lap_param;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ratio->size(); ++i) {
|
||||
split_with_over_lap_param->ratio_[i] = (*ratio)[i];
|
||||
split_with_over_lap_param->extend_top_[i] = (*extend_top)[i];
|
||||
split_with_over_lap_param->extend_bottom_[i] = (*extend_bottom)[i];
|
||||
}
|
||||
|
||||
split_with_over_lap_param->stride_ = value->stride();
|
||||
split_with_over_lap_param->pad_top_ = value->pad_top();
|
||||
|
||||
return reinterpret_cast<OpParameter *>(split_with_over_lap_param);
|
||||
}
|
||||
REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,133 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/base/split_with_over_lap_base.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_SplitWithOverlap;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const SplitWithOverlapParameter *param,
|
||||
const std::vector<int> &shape) {
|
||||
int total_block_count = 0;
|
||||
for (auto i = 0; i < param->num_split_; i++) {
|
||||
total_block_count += param->ratio_[i];
|
||||
}
|
||||
|
||||
auto split_dim_size = shape[param->split_dim_];
|
||||
|
||||
std::vector<int> borders;
|
||||
borders.emplace_back(0);
|
||||
int visited_block = 0;
|
||||
for (auto i = 0; i < param->num_split_ - 1; i++) {
|
||||
visited_block += param->ratio_[i];
|
||||
auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
|
||||
if (param->stride_ != 0) {
|
||||
// make sure border align with stride
|
||||
cur_border = UP_ROUND(cur_border + param->pad_top_, param->stride_);
|
||||
borders.emplace_back(cur_border - param->pad_top_);
|
||||
} else {
|
||||
borders.emplace_back(cur_border);
|
||||
}
|
||||
}
|
||||
borders.emplace_back(split_dim_size);
|
||||
|
||||
for (auto i = 0; i < param->num_split_; i++) {
|
||||
start_indices_.emplace_back(borders[i]);
|
||||
end_indices_.emplace_back(borders[i + 1]);
|
||||
|
||||
// overlap: calibrate start_indices and end_indices by adding extends
|
||||
start_indices_[i] -= param->extend_top_[i];
|
||||
end_indices_[i] += param->extend_bottom_[i];
|
||||
}
|
||||
}
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; }
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
|
||||
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
|
||||
inner_stride_, start_indices_.data(), end_indices_.data());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SplitWithOverlapRun(void *cdata, int task_id) {
|
||||
auto g_kernel = reinterpret_cast<SplitWithOverlapBaseCPUKernel *>(cdata);
|
||||
auto ret = g_kernel->Split(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SplitWithOverlapRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::Run() {
|
||||
auto prepare_ret = Prepare();
|
||||
if (prepare_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail! ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
|
||||
auto in_tensor = in_tensors_.front();
|
||||
input_ptr_ = reinterpret_cast<char *>(in_tensor->data_c());
|
||||
auto input_shape = in_tensor->shape();
|
||||
|
||||
start_indices_.clear();
|
||||
end_indices_.clear();
|
||||
output_ptr_.clear();
|
||||
|
||||
for (int i = 0; i < param->num_split_; i++) {
|
||||
output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c()));
|
||||
}
|
||||
|
||||
CalculateSplitedShapes(param, input_shape);
|
||||
|
||||
outer_total_dim_ = 1;
|
||||
inner_stride_ = 1;
|
||||
split_dim_size_ = input_shape[param->split_dim_];
|
||||
element_bytes_ = in_tensor->Size();
|
||||
|
||||
for (auto i = 0; i < param->split_dim_; i++) {
|
||||
outer_total_dim_ *= input_shape[i];
|
||||
}
|
||||
|
||||
for (int i = static_cast<int>(input_shape.size()) - 1; i > param->split_dim_; i--) {
|
||||
inner_stride_ *= input_shape[i];
|
||||
}
|
||||
|
||||
auto ret = ParallelLaunch(static_cast<const lite::InnerContext *>(this->context_)->thread_pool_, SplitWithOverlapRun,
|
||||
this, param->num_split_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/context.h"
|
||||
#include "src/lite_kernel.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
#include "nnacl/base/split_with_over_lap_base.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class SplitWithOverlapBaseCPUKernel : public LiteKernel {
|
||||
public:
|
||||
SplitWithOverlapBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx) {
|
||||
param = reinterpret_cast<SplitWithOverlapParameter *>(op_parameter_);
|
||||
}
|
||||
~SplitWithOverlapBaseCPUKernel() override = default;
|
||||
void CalculateSplitedShapes(const SplitWithOverlapParameter *param, const std::vector<int> &shape);
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int Split(int task_id);
|
||||
|
||||
protected:
|
||||
// range: [start, end)
|
||||
std::vector<int> start_indices_;
|
||||
std::vector<int> end_indices_;
|
||||
|
||||
int outer_total_dim_{0};
|
||||
int inner_stride_{0};
|
||||
int element_bytes_{0};
|
||||
int split_dim_size_{0};
|
||||
SplitWithOverlapParameter *param = nullptr;
|
||||
char *input_ptr_{nullptr};
|
||||
std::vector<char *> output_ptr_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_
|
Loading…
Reference in New Issue