!48500 [MSLITE] Backport static warning cleanup code

Merge pull request !48500 from zhuguodong/for_warn_backport
i-robot 2023-02-08 01:47:26 +00:00 committed by Gitee
commit 759d6fea7c
5 changed files with 126 additions and 109 deletions

File 1/5: SpaceToDepthForNHWC

@@ -34,14 +34,14 @@ int SpaceToDepthForNHWC(const void *input, void *output, const int *in_shape, co
   ComputeStrides(in_shape, in_strides, shape_size);
   ComputeStrides(out_shape, out_strides, shape_size);
   for (int i = 0; i < out_shape[0]; ++i) {
-    size_t in_offset_n = i * in_strides[0];
-    size_t out_offset_n = i * out_strides[0];
+    int64_t in_offset_n = i * in_strides[0];
+    int64_t out_offset_n = i * out_strides[0];
     for (int j = h_start; j < h_end; ++j) {
-      size_t in_offset_h = in_offset_n + j * block_size * in_strides[1];
-      size_t out_offset_h = out_offset_n + j * out_strides[1];
+      int64_t in_offset_h = in_offset_n + j * block_size * in_strides[1];
+      int64_t out_offset_h = out_offset_n + j * out_strides[1];
       for (int k = 0; k < out_shape[2]; ++k) {
-        size_t in_offset_w = in_offset_h + k * block_size * in_strides[2];
-        size_t out_offset_w = out_offset_h + k * out_strides[2];
+        int64_t in_offset_w = in_offset_h + k * block_size * in_strides[2];
+        int64_t out_offset_w = out_offset_h + k * out_strides[2];
         for (int l = 0; l < block_size; ++l) {
           memcpy((int8_t *)output + (out_offset_w + l * block_size * in_strides[DIMENSION_2D]) * param->date_type_len,
                  (const int8_t *)input + (in_offset_w + l * in_strides[DIMENSION_1D]) * param->date_type_len,
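The change above is the whole point of this hunk: the offset products are computed in signed int arithmetic, and storing them in an unsigned size_t draws sign-conversion warnings from static checkers (and would wrap silently if an intermediate ever went negative). A minimal standalone sketch of the difference, using hypothetical values rather than the kernel's real strides:

#include <cstdint>
#include <cstdio>

int main() {
  int i = -1;       // signed index; the kernel's loop counters are int too
  int stride = 4;
  size_t off_u = i * stride;    // -4 converts to a huge unsigned value
  int64_t off_s = i * stride;   // stays -4, and no sign-conversion warning
  std::printf("size_t: %zu, int64_t: %lld\n", off_u, static_cast<long long>(off_s));
  return 0;
}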

File 2/5: sse_common.h

@@ -17,6 +17,22 @@
 #ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_
 #define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_
+#define SSE_ROW_NUM_1 1
+#define SSE_ROW_NUM_2 2
+#define SSE_ROW_NUM_3 3
+#define SSE_INDEX_1 1
+#define SSE_INDEX_2 2
+#define SSE_INDEX_3 3
+#define SSE_INDEX_4 4
+#define SSE_INDEX_5 5
+#define SSE_INDEX_6 6
+#define SSE_SHUFFLE_0321 (_MM_SHUFFLE(0, 3, 2, 1))
+#define SSE_ACT_RELU 1
+#define SSE_ACT_RELU6 3
 static inline void ActBlock1(__m128 *v1, size_t relu, size_t relu6) {
   __m128 zero_ma = _mm_setzero_ps();
   if (relu || relu6) {
@@ -98,7 +114,7 @@ static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m
   __m128 relu6 = _mm_set_ps1(6.0);
   __m128 zero = _mm_setzero_ps();
   switch (relu_type) {
-    case 3:
+    case SSE_ACT_RELU6:
       *v1 = _mm_min_ps(*v1, relu6);
       *v2 = _mm_min_ps(*v2, relu6);
       *v3 = _mm_min_ps(*v3, relu6);
@@ -107,7 +123,7 @@ static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m
       *v6 = _mm_min_ps(*v6, relu6);
       *v7 = _mm_min_ps(*v7, relu6);
       *v8 = _mm_min_ps(*v8, relu6);
-    case 1:
+    case SSE_ACT_RELU:
       *v1 = _mm_max_ps(*v1, zero);
       *v2 = _mm_max_ps(*v2, zero);
       *v3 = _mm_max_ps(*v3, zero);
@@ -124,15 +140,15 @@ static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m
 static inline void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
   _mm_store_ss(*dst, *dst1);
-  if (r > 1) {
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_store_ss(*dst, *dst3);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_store_ss(*dst, *dst5);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_store_ss(*dst, *dst7);
     *dst += stride;
@@ -143,24 +159,24 @@ static inline void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *ds
 static inline void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) {
   _mm_store_ss(*dst, *dst1);
-  *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
+  *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321);
   _mm_store_ss(*dst, *dst1);
-  if (r > 1) {
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_store_ss(*dst, *dst3);
-    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
+    *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321);
     _mm_store_ss(*dst, *dst3);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_store_ss(*dst, *dst5);
-    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
+    *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321);
     _mm_store_ss(*dst, *dst5);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_store_ss(*dst, *dst7);
-    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
+    *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321);
     _mm_store_ss(*dst, *dst7);
   }
 }
@@ -168,55 +184,55 @@ static inline void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *ds
 static inline void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                                 __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) {
   _mm_store_ss(*dst, *dst1);
-  *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
-  _mm_store_ss(*dst + 1, *dst1);
-  if (r > 1) {
+  *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321);
+  _mm_store_ss(*dst + SSE_INDEX_1, *dst1);
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_store_ss(*dst, *dst3);
-    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 1, *dst3);
+    *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_1, *dst3);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_store_ss(*dst, *dst5);
-    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 1, *dst5);
+    *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_1, *dst5);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_store_ss(*dst, *dst7);
-    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 1, *dst7);
+    *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_1, *dst7);
     *dst += stride;
-    *dst += 2;
+    *dst += SSE_INDEX_2;
   }
 }
 static inline void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
-  if (r > 1) {
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_store_ss(*dst, *dst3);
-    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 1, *dst3);
-    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 2, *dst3);
+    *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_1, *dst3);
+    *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_2, *dst3);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_store_ss(*dst, *dst5);
-    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 1, *dst5);
-    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 2, *dst5);
+    *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_1, *dst5);
+    *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_2, *dst5);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_store_ss(*dst, *dst7);
-    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 1, *dst7);
-    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 2, *dst7);
+    *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_1, *dst7);
+    *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_2, *dst7);
     *dst += stride;
     *dst += extra_stride;
   }
@@ -225,15 +241,15 @@ static inline void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *ds
 static inline void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
   _mm_storeu_ps(*dst, *dst1);
-  if (r > 1) {
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst3);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst5);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst7);
     *dst += stride;
@@ -244,21 +260,21 @@ static inline void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *ds
 static inline void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
   _mm_storeu_ps(*dst, *dst1);
-  _mm_store_ss(*dst + 4, *dst2);
-  if (r > 1) {
+  _mm_store_ss(*dst + SSE_INDEX_4, *dst2);
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst3);
-    _mm_store_ss(*dst + 4, *dst4);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst4);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst5);
-    _mm_store_ss(*dst + 4, *dst6);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst6);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst7);
-    _mm_store_ss(*dst + 4, *dst8);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst8);
     *dst += stride;
     *dst += extra_stride;
   }
@@ -267,29 +283,29 @@ static inline void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *ds
 static inline void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
   _mm_storeu_ps(*dst, *dst1);
-  _mm_store_ss(*dst + 4, *dst2);
-  *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
-  _mm_store_ss(*dst + 5, *dst2);
-  if (r > 1) {
+  _mm_store_ss(*dst + SSE_INDEX_4, *dst2);
+  *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321);
+  _mm_store_ss(*dst + SSE_INDEX_5, *dst2);
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst3);
-    _mm_store_ss(*dst + 4, *dst4);
-    *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 5, *dst4);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst4);
+    *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_5, *dst4);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst5);
-    _mm_store_ss(*dst + 4, *dst6);
-    *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 5, *dst6);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst6);
+    *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_5, *dst6);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst7);
-    _mm_store_ss(*dst + 4, *dst8);
-    *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 5, *dst8);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst8);
+    *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_5, *dst8);
     *dst += stride;
     *dst += extra_stride;
   }
@@ -298,37 +314,37 @@ static inline void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *ds
 static inline void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
   _mm_storeu_ps(*dst, *dst1);
-  _mm_store_ss(*dst + 4, *dst2);
-  *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
-  _mm_store_ss(*dst + 5, *dst2);
-  *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
-  _mm_store_ss(*dst + 6, *dst2);
-  if (r > 1) {
+  _mm_store_ss(*dst + SSE_INDEX_4, *dst2);
+  *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321);
+  _mm_store_ss(*dst + SSE_INDEX_5, *dst2);
+  *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321);
+  _mm_store_ss(*dst + SSE_INDEX_6, *dst2);
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst3);
-    _mm_store_ss(*dst + 4, *dst4);
-    *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 5, *dst4);
-    *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 6, *dst4);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst4);
+    *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_5, *dst4);
+    *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_6, *dst4);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst5);
-    _mm_store_ss(*dst + 4, *dst6);
-    *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 5, *dst6);
-    *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 6, *dst6);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst6);
+    *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_5, *dst6);
+    *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_6, *dst6);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst7);
-    _mm_store_ss(*dst + 4, *dst8);
-    *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 5, *dst8);
-    *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
-    _mm_store_ss(*dst + 6, *dst8);
+    _mm_store_ss(*dst + SSE_INDEX_4, *dst8);
+    *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_5, *dst8);
+    *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321);
+    _mm_store_ss(*dst + SSE_INDEX_6, *dst8);
     *dst += stride;
     *dst += extra_stride;
   }
@@ -337,21 +353,21 @@ static inline void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *ds
 static inline void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
   _mm_storeu_ps(*dst, *dst1);
-  _mm_storeu_ps(*dst + 4, *dst2);
-  if (r > 1) {
+  _mm_storeu_ps(*dst + SSE_INDEX_4, *dst2);
+  if (r > SSE_ROW_NUM_1) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst3);
-    _mm_storeu_ps(*dst + 4, *dst4);
+    _mm_storeu_ps(*dst + SSE_INDEX_4, *dst4);
   }
-  if (r > 2) {
+  if (r > SSE_ROW_NUM_2) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst5);
-    _mm_storeu_ps(*dst + 4, *dst6);
+    _mm_storeu_ps(*dst + SSE_INDEX_4, *dst6);
   }
-  if (r > 3) {
+  if (r > SSE_ROW_NUM_3) {
     *dst += stride;
     _mm_storeu_ps(*dst, *dst7);
-    _mm_storeu_ps(*dst + 4, *dst8);
+    _mm_storeu_ps(*dst + SSE_INDEX_4, *dst8);
     *dst += stride;
     *dst += extra_stride;
   }
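Every hunk in this header applies the same cleanup: bare literals (row thresholds, lane offsets, activation codes, the shuffle mask) are lifted into named macros such as SSE_ROW_NUM_1 and SSE_ACT_RELU6, which satisfies magic-number checkers and makes each comparison self-describing. A scalar sketch of the pattern, including the deliberate fall-through from the ReLU6 clamp into the ReLU clamp that ActBlock8 relies on; the names here are illustrative, not part of nnacl:

#define ACT_RELU 1   /* mirrors SSE_ACT_RELU */
#define ACT_RELU6 3  /* mirrors SSE_ACT_RELU6 */

static inline float Activate(float v, int act_type) {
  if (act_type == ACT_RELU6 && v > 6.0f) {
    v = 6.0f; /* clamp above, then share the lower clamp with ReLU */
  }
  if (act_type == ACT_RELU || act_type == ACT_RELU6) {
    v = (v < 0.0f) ? 0.0f : v; /* lower clamp common to both activations */
  }
  return v;
}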

File 3/5: ArgMinMaxCPUKernel::ReSize

@@ -43,9 +43,9 @@ int ArgMinMaxCPUKernel::ReSize() {
   auto in_shape = in_tensors_.at(0)->shape();
   auto dims_size = in_shape.size();
   MS_CHECK_TRUE_MSG(dims_size >= 0, RET_ERROR, "The input shape is invalid.");
-  int axis = arg_param_->axis_ < 0 ? arg_param_->axis_ + dims_size : arg_param_->axis_;
-  arg_param_->axis_ = axis;
   arg_param_->dims_size_ = static_cast<int>(dims_size);
+  int axis = arg_param_->axis_ < 0 ? (arg_param_->axis_ + arg_param_->dims_size_) : arg_param_->axis_;
+  arg_param_->axis_ = axis;
   MS_CHECK_TRUE_MSG(axis >= 0 && axis < static_cast<int>(in_shape.size()), RET_ERROR, "The axis is invalid.");
   if (arg_param_->topk_ <= 0 || arg_param_->topk_ > in_shape.at(axis)) {
     MS_LOG(ERROR) << "Invalid topk " << arg_param_->topk_;
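Two things change here: the axis is normalized after dims_size_ has been stored, and the ternary reads the int field arg_param_->dims_size_ instead of the size_t local dims_size, so adding a negative axis_ no longer mixes signed and unsigned operands. A worked example of the normalization itself, with a hypothetical rank-4 shape:

#include <cassert>

// All-int axis normalization, as in the rewritten ternary above.
int NormalizeAxis(int axis, int dims_size) {
  return axis < 0 ? axis + dims_size : axis;
}

int main() {
  assert(NormalizeAxis(-1, 4) == 3);  // -1 means the last axis
  assert(NormalizeAxis(2, 4) == 2);   // non-negative axes pass through
  return 0;
}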

File 4/5: StackBaseCPUKernel::StackExecute

@@ -103,7 +103,7 @@ int StackBaseCPUKernel::StackExecute(int task_id) {
   auto end = MSMIN(start + step, outer_size_);
   auto input_num = in_tensors_.size();
   MS_CHECK_FALSE(INT_MUL_OVERFLOW(input_num * static_cast<size_t>(start), copy_size_), RET_ERROR);
-  auto output = reinterpret_cast<char *>(output_data) + input_num * start * copy_size_;
+  auto output = reinterpret_cast<char *>(output_data) + input_num * static_cast<size_t>(start) * copy_size_;
   Stack(all_inputs_, reinterpret_cast<void *>(output), input_num, copy_size_, start, end);
   return RET_OK;
 }
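input_num is a size_t while start is an int, so the old expression converted start to unsigned implicitly; the explicit static_cast<size_t>(start) makes the conversion visible and matches the INT_MUL_OVERFLOW guard computed the same way on the previous line. A reduced sketch of the pattern with hypothetical names:

#include <cstddef>

// Widen the signed index explicitly so the whole offset is a size_t product.
char *OutputAt(char *base, size_t input_num, int start, size_t copy_size) {
  return base + input_num * static_cast<size_t>(start) * copy_size;
}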

File 5/5: StridedSliceCPUKernel::FastRunImpl

@@ -141,10 +141,11 @@ int StridedSliceCPUKernel::FastRunImpl(int task_id) {
   auto out_shape = out_tensors_.front()->shape();
   int begin_index = param_->begins_[split_axis_];
   int caled_num = task_id * cal_num_per_thread_;
+  int64_t inner_size = static_cast<int64_t>(inner_size_);
   if (parallel_on_outer_) {
-    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * in_shape[split_axis_] + begin_index) * inner_size_;
-    uint8_t *cur_out_ptr = output_ptr_ + caled_num * out_shape[split_axis_] * inner_size_;
-    int cur_outer = outer_ - caled_num;
+    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * in_shape[split_axis_] + begin_index) * inner_size;
+    uint8_t *cur_out_ptr = output_ptr_ + caled_num * out_shape[split_axis_] * inner_size;
+    int cur_outer = static_cast<int>(outer_) - caled_num;
     if (cur_outer <= 0) {
       return RET_OK;
     }
@@ -152,12 +153,12 @@ int StridedSliceCPUKernel::FastRunImpl(int task_id) {
       cur_outer = cal_num_per_thread_;
     }
     FastStride(cur_in_ptr, cur_out_ptr, out_shape[split_axis_], param_->strides_[split_axis_], cur_outer, inner_size_,
-               in_shape[split_axis_] * inner_size_);
+               static_cast<size_t>(in_shape[split_axis_]) * inner_size_);
   } else {
     MS_CHECK_TRUE_MSG(parallel_on_split_axis_ == true, RET_ERROR,
                       "Stride slice op should be parallel on axis or outer size.");
-    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * param_->strides_[split_axis_] + begin_index) * inner_size_;
-    uint8_t *cur_out_ptr = output_ptr_ + caled_num * inner_size_;
+    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * param_->strides_[split_axis_] + begin_index) * inner_size;
+    uint8_t *cur_out_ptr = output_ptr_ + caled_num * inner_size;
     int cal_axis_num = out_shape[split_axis_] - caled_num;
     if (cal_axis_num <= 0) {
       return RET_OK;
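The new local int64_t inner_size = static_cast<int64_t>(inner_size_) is declared once and reused in both branches, so each pointer-offset product finishes in 64-bit signed arithmetic instead of mixing int terms with the size_t field. A reduced sketch of the pattern, with hypothetical parameters standing in for the kernel's fields:

#include <cstdint>

// Widen one factor up front; the final multiply then happens in int64_t
// rather than overflowing in 32-bit int.
uint8_t *InputAt(uint8_t *base, int caled_num, int dim, int begin, size_t raw_inner) {
  int64_t inner_size = static_cast<int64_t>(raw_inner);
  return base + (caled_num * dim + begin) * inner_size;
}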