forked from mindspore-Ecosystem/mindspore
!20854 [MS][LITE][CPU] code check master
Merge pull request !20854 from liuzhongkai/code_check_100
This commit is contained in:
commit
cabc09b792
|
@ -84,6 +84,7 @@ int ElementFloorDiv(const float *in0, const float *in1, float *out, int size) {
|
|||
|
||||
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
NNACL_ASSERT(in1[i] != 0);
|
||||
out[i] = in0[i] / in1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
|
||||
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
|
||||
|
@ -30,10 +31,10 @@ void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weigh
|
|||
}
|
||||
#endif
|
||||
|
||||
void ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, int task_id) {
|
||||
if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0) {
|
||||
return;
|
||||
int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, int task_id) {
|
||||
if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
|
||||
int h_start = h_step * task_id;
|
||||
|
@ -85,6 +86,7 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
|
|||
}
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
|
||||
|
|
|
@ -29,8 +29,8 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
|
|||
int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6);
|
||||
#endif
|
||||
|
||||
void ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, int task_id);
|
||||
int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, int task_id);
|
||||
|
||||
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/fp32/deconv_winograd_fp32.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
static void TransposeWeight(float *dst, size_t number) {
|
||||
#ifdef ENABLE_AVX
|
||||
|
@ -506,8 +507,11 @@ void DeConvWgCalCommFp32(const float *tile_in, float *tile_out, const float *wei
|
|||
return;
|
||||
}
|
||||
|
||||
void DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count,
|
||||
const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) {
|
||||
int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count,
|
||||
const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) {
|
||||
if (deconv_param->in_tile_w_count_ == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
/* pack tile input */
|
||||
int tile_in_unit_stride = deconv_param->ic_up4_ * DECONV_WINOGRAD_DEFAULT_TILE;
|
||||
#ifdef ENABLE_ARM
|
||||
|
@ -555,7 +559,7 @@ void DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int sta
|
|||
|
||||
/* winograd a buffer */
|
||||
if (unit->winograd_.kh_ >= DECONV_WINOGRAD_BUFFER_COUNT) {
|
||||
return;
|
||||
return NNACL_ERR;
|
||||
}
|
||||
DeConvWgABuffer *wg_buf = &deconv_param->a_buffer_[unit->winograd_.kh_];
|
||||
float *wg_mid_a_buf = (float *)wg_buf->middle_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ *
|
||||
|
@ -574,11 +578,14 @@ void DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int sta
|
|||
unit->h_size_, unit->w_size_, conv_param, deconv_param);
|
||||
}
|
||||
}
|
||||
return;
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param,
|
||||
const DeConvParam *deconv_param, int calculate_count, int tile_index) {
|
||||
int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param,
|
||||
const DeConvParam *deconv_param, int calculate_count, int tile_index) {
|
||||
if (deconv_param->in_tile_w_count_ == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
/* merge */
|
||||
int src_unit_stride = deconv_param->oc_up4_ * DECONV_WINOGRAD_DEFAULT_TILE;
|
||||
|
||||
|
@ -608,5 +615,5 @@ void DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParamet
|
|||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -30,10 +30,10 @@ extern "C" {
|
|||
|
||||
int PackDeConvWgDataFp32(const float *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param,
|
||||
const DeConvParam *deconv_param);
|
||||
void DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count,
|
||||
const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id);
|
||||
void DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param,
|
||||
const DeConvParam *deconv_param, int calculate_count, int tile_index);
|
||||
int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count,
|
||||
const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id);
|
||||
int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param,
|
||||
const DeConvParam *deconv_param, int calculate_count, int tile_index);
|
||||
void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t ic4, size_t cal_num, size_t oc4);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -36,13 +36,12 @@ void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param
|
|||
int kernel_plane = kernel_h * kernel_w;
|
||||
int dilation_h = conv_param->dilation_h_;
|
||||
int dilation_w = conv_param->dilation_w_;
|
||||
if (dilation_h == 0 || dilation_w == 0) {
|
||||
int out_w = conv_param->output_w_;
|
||||
if (dilation_h == 0 || dilation_w == 0 || out_w == 0) {
|
||||
return;
|
||||
}
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int in_w = conv_param->input_w_;
|
||||
int out_w = conv_param->output_w_;
|
||||
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int block_start = block_index + i;
|
||||
int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
|
|
|
@ -30,7 +30,9 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
|
|||
int output_h = pooling_param->output_h_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
|
||||
if (output_w == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
#ifdef ENABLE_AVX
|
||||
int c8 = channel / C8NUM * C8NUM;
|
||||
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
|
||||
|
@ -133,8 +135,8 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf) {
|
||||
int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf) {
|
||||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
|
@ -145,7 +147,9 @@ void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParamete
|
|||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
|
||||
if (output_w == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
#ifdef ENABLE_AVX
|
||||
int c8 = channel / C8NUM * C8NUM;
|
||||
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
|
||||
|
@ -227,4 +231,5 @@ void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParamete
|
|||
} // real_cal_num loop
|
||||
} // out_plane loop
|
||||
} // out_batch loop
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -29,8 +29,8 @@ extern "C" {
|
|||
#endif
|
||||
int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf);
|
||||
void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf);
|
||||
int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -50,6 +50,9 @@ int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_d
|
|||
|
||||
int IntReduceMean(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
|
||||
int thread_num) {
|
||||
if (axis_size == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
if (src_data == NULL || dst_data == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "nnacl/fp32/space_to_batch_fp32.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
void DoSpaceToBatch(const float *input, float *output, const int *in_shape, const int *out_shape, const int *in_stride,
|
||||
const int *out_stride, const int *blocks, const int *paddings, int thread, int task_id) {
|
||||
int DoSpaceToBatch(const float *input, float *output, const int *in_shape, const int *out_shape, const int *in_stride,
|
||||
const int *out_stride, const int *blocks, const int *paddings, int thread, int task_id) {
|
||||
if (thread == 0) {
|
||||
return;
|
||||
return NNACL_ERR;
|
||||
}
|
||||
const int depth = in_shape[3];
|
||||
const int input_width = in_shape[2];
|
||||
|
@ -33,7 +34,9 @@ void DoSpaceToBatch(const float *input, float *output, const int *in_shape, cons
|
|||
const int block_shape_width = blocks[1];
|
||||
const int padding_top = paddings[0];
|
||||
const int padding_left = paddings[2];
|
||||
|
||||
if (input_batch_size == 0 || block_shape_width == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
size_t copy_size = depth * sizeof(float);
|
||||
|
||||
for (int out_b = task_id; out_b < output_batch_size; out_b += thread) {
|
||||
|
@ -57,5 +60,5 @@ void DoSpaceToBatch(const float *input, float *output, const int *in_shape, cons
|
|||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -40,8 +40,8 @@ typedef struct SpaceToBatchParameter {
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
void DoSpaceToBatch(const float *input, float *output, const int *in_shape, const int *out_shape, const int *in_stride,
|
||||
const int *out_stride, const int *blocks, const int *paddings, int thread, int task_id);
|
||||
int DoSpaceToBatch(const float *input, float *output, const int *in_shape, const int *out_shape, const int *in_stride,
|
||||
const int *out_stride, const int *blocks, const int *paddings, int thread, int task_id);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -97,8 +97,10 @@ int AdderCPUKernel::InitWeightBias() {
|
|||
|
||||
int AdderCPUKernel::RunImpl(int task_id) {
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
|
||||
MS_ASSERT(input_tensor != nullptr);
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
|
||||
MS_ASSERT(ori_input_data != nullptr);
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
|
||||
AdderFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_,
|
||||
output_addr, task_id, conv_param_);
|
||||
return RET_OK;
|
||||
|
|
|
@ -58,6 +58,9 @@ int AddNCPUKernel::Run() {
|
|||
auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
|
||||
auto input1_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
|
||||
auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
|
||||
MS_ASSERT(input0_data != nullptr);
|
||||
MS_ASSERT(input1_data != nullptr);
|
||||
MS_ASSERT(output_data != nullptr);
|
||||
if (static_cast<int>(elements_num_) < op_parameter_->thread_num_) {
|
||||
if (in_tensors_[0]->shape() == in_tensors_[1]->shape()) {
|
||||
ElementAdd(input0_data, input1_data, output_data, elements_num_);
|
||||
|
|
|
@ -301,6 +301,10 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
|
|||
if (break_pos_ < 1) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (break_pos_ > 10 || param_->out_strides_[break_pos_ - 1] == 0) {
|
||||
MS_LOG(ERROR) << "param_->out_strides_[break_pos_ - 1] is 0 or break_pos_ is > 10";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int batch = param_->out_elements_num_ / param_->out_strides_[break_pos_ - 1];
|
||||
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
|
||||
|
||||
|
@ -328,6 +332,10 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
|
|||
}
|
||||
|
||||
int ArithmeticCPUKernel::BiasCalc(int task_id) {
|
||||
if (param_->ndim_ > 10 || param_->out_shape_[param_->ndim_ - 1] == 0) {
|
||||
MS_LOG(ERROR) << "BiasCalc param is error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int last_shape = param_->out_shape_[param_->ndim_ - 1];
|
||||
int batch = param_->out_elements_num_ / last_shape;
|
||||
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
|
||||
|
|
|
@ -58,10 +58,10 @@ int BiasCPUKernel::Run() {
|
|||
ms_context_->allocator->Free(tile_bias);
|
||||
return RET_ERROR;
|
||||
}
|
||||
BroadcastAdd(in, bias, tile_in, tile_bias, out, data_size, bias_param_);
|
||||
auto ret = BroadcastAdd(in, bias, tile_in, tile_bias, out, data_size, bias_param_);
|
||||
ms_context_->allocator->Free(tile_in);
|
||||
ms_context_->allocator->Free(tile_bias);
|
||||
return RET_OK;
|
||||
return ret;
|
||||
}
|
||||
|
||||
int BiasCPUKernel::Init() {
|
||||
|
|
|
@ -254,15 +254,21 @@ int Convolution1x1CPUKernel::Run() {
|
|||
} else {
|
||||
input_ptr_ = tmp_in;
|
||||
}
|
||||
|
||||
int ret = 0;
|
||||
if (multi_thread_by_hw_) {
|
||||
ParallelLaunch(this->ms_context_, Convolution1x1RunHw, this, thread_count_);
|
||||
ret = ParallelLaunch(this->ms_context_, Convolution1x1RunHw, this, thread_count_);
|
||||
} else {
|
||||
PackMatmulInput(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
|
||||
ParallelLaunch(this->ms_context_, Convolution1x1Run, this, thread_count_);
|
||||
ret = ParallelLaunch(this->ms_context_, Convolution1x1Run, this, thread_count_);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
if (pack_input_ != nullptr) {
|
||||
ctx_->allocator->Free(pack_input_);
|
||||
pack_input_ = nullptr;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
if (pack_input_ != nullptr) {
|
||||
ctx_->allocator->Free(pack_input_);
|
||||
pack_input_ = nullptr;
|
||||
|
|
|
@ -91,6 +91,10 @@ int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) {
|
|||
int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units
|
||||
int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM);
|
||||
auto buffer = buffer_ + C12NUM * c4 * units * task_id;
|
||||
if (conv_param_->thread_num_ == 0) {
|
||||
MS_LOG(ERROR) << "conv_param_->thread_num_ must be not equal to 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int step_oh = UP_DIV(conv_param_->output_h_, conv_param_->thread_num_);
|
||||
int start_oh = step_oh * task_id;
|
||||
int end_oh = MSMIN(start_oh + step_oh, conv_param_->output_h_);
|
||||
|
|
|
@ -90,8 +90,9 @@ int ConvolutionDepthwiseCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int ConvolutionDepthwiseCPUKernel::Execute(int task_id) {
|
||||
ConvDw(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_, task_id);
|
||||
return RET_OK;
|
||||
auto ret =
|
||||
ConvDw(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_, task_id);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ConvDwRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
|
|
|
@ -215,6 +215,10 @@ int DeConvolutionWinogradCPUKernel::InitComputeParam() {
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
int cur_count = 0;
|
||||
if (conv_param_->stride_h_ == 0 || conv_param_->stride_w_ == 0) {
|
||||
MS_LOG(ERROR) << "conv_param_->stride_w_ or conv_param_->stride_h_ is 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int si_h = 0; si_h < conv_param_->stride_h_; si_h++) {
|
||||
if (si_h >= conv_param_->kernel_h_) {
|
||||
continue;
|
||||
|
@ -344,10 +348,18 @@ int DeConvolutionWinogradCPUKernel::DoDeconv(int task_id) {
|
|||
int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE,
|
||||
deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index);
|
||||
|
||||
DeconvWg(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id);
|
||||
|
||||
auto ret =
|
||||
DeconvWg(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DeconvWg is error";
|
||||
return ret;
|
||||
}
|
||||
std::unique_lock<std::mutex> merge_lock(lock_);
|
||||
DeconvWgPost(tile_out, nc4hw4_output_, conv_param_, deconv_param_, calculate_count, tile_index);
|
||||
ret = DeconvWgPost(tile_out, nc4hw4_output_, conv_param_, deconv_param_, calculate_count, tile_index);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DeconvWgPost is error";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -49,8 +49,9 @@ int ExpCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int ExpCPUKernel::DoExcute(int task_id) {
|
||||
ExpFusionFp32(reinterpret_cast<float *>(input_addr_), reinterpret_cast<float *>(output_addr_), param_, task_id);
|
||||
return RET_OK;
|
||||
auto ret =
|
||||
ExpFusionFp32(reinterpret_cast<float *>(input_addr_), reinterpret_cast<float *>(output_addr_), param_, task_id);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ExpRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
|
|
|
@ -66,6 +66,10 @@ int InstanceNormCPUKernel::Run() {
|
|||
gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
|
||||
beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->data_c());
|
||||
dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
|
||||
MS_ASSERT(src_data_ != nullptr);
|
||||
MS_ASSERT(gamma_data_ != nullptr);
|
||||
MS_ASSERT(beta_data_ != nullptr);
|
||||
MS_ASSERT(dst_data_ != nullptr);
|
||||
auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
|
||||
|
|
|
@ -54,6 +54,8 @@ int InvertPermutationCPUKernel::Run() {
|
|||
}
|
||||
auto input_ptr = reinterpret_cast<int32_t *>(in_tensor->data_c());
|
||||
auto output_ptr = reinterpret_cast<int32_t *>(out_tensor->data_c());
|
||||
MS_ASSERT(input_ptr != nullptr);
|
||||
MS_ASSERT(output_ptr != nullptr);
|
||||
if (input_ptr == nullptr || output_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "null pointer dereferencing.";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -102,6 +102,10 @@ int L2NormCPUKernel::DivSqrtSum(int task_id) {
|
|||
|
||||
int L2NormCPUKernel::CalcL2NormTrailingAxis(int task_id) {
|
||||
auto input = in_tensors_.at(0);
|
||||
if (input->shape().back() == 0) {
|
||||
MS_LOG(ERROR) << "input->shape().back() is 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int outer_size = input->ElementsNum() / input->shape().back();
|
||||
int unit = UP_DIV(outer_size, op_parameter_->thread_num_);
|
||||
int begin = task_id * unit;
|
||||
|
|
|
@ -316,7 +316,7 @@ int MatmulFp32BaseCPUKernel::Init() {
|
|||
if (params_->b_const_) {
|
||||
// only copy weight data
|
||||
// resize or run to pack
|
||||
auto b_tensor = in_tensors_[1];
|
||||
auto b_tensor = in_tensors_.at(1);
|
||||
src_b_ = reinterpret_cast<float *>(malloc(params_->batch * params_->deep_ * params_->col_ * sizeof(float)));
|
||||
if (src_b_ == nullptr) {
|
||||
MS_LOG(ERROR) << "matmul fp16 src_b_ is failed!";
|
||||
|
|
|
@ -143,7 +143,7 @@ void PadCPUKernel::InitMirrorPadBlock() {
|
|||
int dst_offset = dst_basic_offset;
|
||||
|
||||
int value = index;
|
||||
for (size_t i = 0; i < pad_region.size(); ++i) {
|
||||
for (size_t i = 0; i < pad_region.size() && pad_region_stride[i] != 0; ++i) {
|
||||
pad_cord[i] = value / pad_region_stride[i];
|
||||
value = value % pad_region_stride[i];
|
||||
}
|
||||
|
|
|
@ -61,14 +61,15 @@ int PoolingCPUKernel::RunImpl(int task_id) {
|
|||
minf = 0.f;
|
||||
maxf = 6.f;
|
||||
}
|
||||
int ret = 0;
|
||||
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
|
||||
MaxPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
ret = MaxPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
} else {
|
||||
auto ret = AvgPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "AcgPooling run failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = AvgPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "AcgPooling run failed.";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -50,6 +50,10 @@ int PReluCPUKernel::Init() {
|
|||
|
||||
int PReluCPUKernel::DoExcute(int task_id) {
|
||||
int thread_num = prelu_param_->op_parameter_.thread_num_;
|
||||
if (thread_num == 0) {
|
||||
MS_LOG(ERROR) << "thread_num is 0!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (prelu_param_->channelShared) {
|
||||
int step = UP_DIV(prelu_param_->input_num_, thread_num);
|
||||
int start = task_id * step;
|
||||
|
|
|
@ -43,8 +43,12 @@ int ROIPoolingCPUKernel::ReSize() {
|
|||
auto in_shape = in_tensors_.front()->shape();
|
||||
auto out_shape = out_tensors_.front()->shape();
|
||||
int ndims = in_shape.size();
|
||||
if (ndims > 4) {
|
||||
MS_LOG(ERROR) << "ROIPooling ReSzie error ,shape dim greater than 4!";
|
||||
if (ndims < C4NUM) {
|
||||
MS_LOG(ERROR) << "ROIPooling in_shape.size() error ,shape dim greater than or equal to 4!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (out_shape.size() < C4NUM) {
|
||||
MS_LOG(ERROR) << "ROIPooling out_shape.size() error ,shape dim greater than or equal to 4!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
param_->ndim_ = ndims;
|
||||
|
|
|
@ -182,11 +182,13 @@ int ScaleCPUKernel::Run() {
|
|||
if (!scale_param_->const_scale_) {
|
||||
auto scale_tensor = in_tensors_.at(1);
|
||||
scale_ = reinterpret_cast<float *>(scale_tensor->data_c());
|
||||
MS_ASSERT(scale_ != nullptr);
|
||||
}
|
||||
if (!scale_param_->const_offset_) {
|
||||
MS_ASSERT(in_tensors_.size() == 3);
|
||||
auto offset_tensor = in_tensors_.at(2);
|
||||
offset_ = reinterpret_cast<float *>(offset_tensor->data_c());
|
||||
MS_ASSERT(offset_ != nullptr);
|
||||
}
|
||||
auto out_tensor = out_tensors_.front();
|
||||
output_ptr_ = reinterpret_cast<float *>(out_tensor->MutableData());
|
||||
|
|
|
@ -41,6 +41,7 @@ void SpaceToBatchCPUKernel::ProcessInput() {
|
|||
ComputeStrides(param_->output_shape_, param_->out_stride_, DIMENSION_4D);
|
||||
auto block_shape_data = in_tensors_[1]->data_c();
|
||||
auto block_shape = static_cast<int *>(block_shape_data);
|
||||
MS_ASSERT(block_shape != nullptr);
|
||||
for (int i = 0; i < in_tensors_[1]->ElementsNum(); i++) {
|
||||
param_->block_sizes_[i] = block_shape[i];
|
||||
}
|
||||
|
@ -60,8 +61,8 @@ int SpaceToBatchCPUKernel::Init() {
|
|||
|
||||
int SpaceToBatchFp32Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
auto op = reinterpret_cast<SpaceToBatchCPUKernel *>(cdata);
|
||||
op->DoRun(task_id);
|
||||
return RET_OK;
|
||||
auto ret = op->DoRun(task_id);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int SpaceToBatchCPUKernel::ReSize() {
|
||||
|
@ -86,16 +87,19 @@ int SpaceToBatchCPUKernel::ReSize() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void SpaceToBatchCPUKernel::DoRun(int task_id) {
|
||||
DoSpaceToBatch(input_ptr_, output_ptr_, param_->input_shape_, param_->output_shape_, param_->in_stride_,
|
||||
param_->out_stride_, param_->block_sizes_, param_->paddings_, op_parameter_->thread_num_, task_id);
|
||||
return;
|
||||
int SpaceToBatchCPUKernel::DoRun(int task_id) {
|
||||
auto ret =
|
||||
DoSpaceToBatch(input_ptr_, output_ptr_, param_->input_shape_, param_->output_shape_, param_->in_stride_,
|
||||
param_->out_stride_, param_->block_sizes_, param_->paddings_, op_parameter_->thread_num_, task_id);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int SpaceToBatchCPUKernel::Run() {
|
||||
MS_ASSERT(in_tensors_[0] != nullptr);
|
||||
input_ptr_ = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
|
||||
MS_ASSERT(input_ptr_ != nullptr);
|
||||
output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
|
||||
MS_ASSERT(output_ptr_ != nullptr);
|
||||
if (in_tensors_.size() == 3) {
|
||||
if (!in_tensors_[1]->IsConst() || !in_tensors_[2]->IsConst()) {
|
||||
ProcessInput();
|
||||
|
|
|
@ -37,7 +37,7 @@ class SpaceToBatchCPUKernel : public InnerKernel {
|
|||
void ProcessInput();
|
||||
|
||||
public:
|
||||
void DoRun(int task_id);
|
||||
int DoRun(int task_id);
|
||||
|
||||
protected:
|
||||
SpaceToBatchParameter *param_;
|
||||
|
|
|
@ -67,6 +67,10 @@ int SparseToDenseCPUKernel::DoExcute(int task_id) {
|
|||
}
|
||||
int index_start = task_id * count_unit_;
|
||||
int index_end = index_start + real_dst_count;
|
||||
if (index_num == 0) {
|
||||
MS_LOG(ERROR) << "invalid index_num div";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int out_width = output_num / index_num;
|
||||
MS_ASSERT(sparse_indices_vect);
|
||||
MS_ASSERT(output_shape);
|
||||
|
@ -173,6 +177,7 @@ int SparseToDenseCPUKernel::Run() {
|
|||
}
|
||||
}
|
||||
output_data = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
|
||||
MS_ASSERT(output_data != nullptr);
|
||||
count_unit_ = thread_count_ > 1 ? UP_DIV(index_num, thread_count_) : index_num;
|
||||
ret = ParallelLaunch(this->ms_context_, SparseToDenseRun, this, s2d_param->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -31,8 +31,6 @@ class UniqueCPUKernel : public InnerKernel {
|
|||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -74,6 +74,10 @@ int WhereCPUKernel::RunWithSingleInput() {
|
|||
ComputeStrides(in_tensors_.at(0)->shape().data(), strides, where_param_->rank_);
|
||||
|
||||
auto data = ms_context_->allocator->Malloc(where_param_->condition_num_ * where_param_->rank_ * sizeof(int32_t));
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "macllov data is error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int *result = reinterpret_cast<int *>(data);
|
||||
|
||||
int result_index = 0;
|
||||
|
@ -83,6 +87,10 @@ int WhereCPUKernel::RunWithSingleInput() {
|
|||
true_num++;
|
||||
int dim = index;
|
||||
for (int j = 0; j < where_param_->rank_; j++) {
|
||||
if (strides[j] == 0) {
|
||||
MS_LOG(ERROR) << "strides[j] is 0!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
result[result_index++] = dim / strides[j];
|
||||
dim %= strides[j];
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue