!21141 make some symbols const

Merge pull request !21141 from zhaozhenlong/lite/issue/pclint
This commit is contained in:
i-robot 2021-08-02 02:05:45 +00:00 committed by Gitee
commit 9a86dbecd5
50 changed files with 98 additions and 84 deletions

View File

@ -19,7 +19,7 @@
#include <arm_neon.h>
#endif
void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, TanhQuantParameter *quant) {
void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant) {
for (int i = 0; i < size; ++i) {
float fp32_src = (input_ptr[i] - quant->in_zp_) * quant->in_scale_;
float fp32_dst = TanhOpt(fp32_src);

View File

@ -34,7 +34,7 @@ typedef struct TanhQuantParameter {
extern "C" {
#endif
void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, TanhQuantParameter *quant);
void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant);
#ifdef __cplusplus
}

View File

@ -88,8 +88,8 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t
// only support op_type from current schema
bool IsPackedOp(int op_type) {
static std::vector<int> packed_ops = {schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion,
schema::PrimitiveType_MatMul};
static const std::vector<int> packed_ops = {
schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion, schema::PrimitiveType_MatMul};
return IsContain(packed_ops, op_type);
}
} // namespace lite

View File

@ -50,8 +50,8 @@ int StrToInt(const char *env) {
}
bool IsPrint(int level) {
static const char *env = std::getenv("GLOG_v");
static int ms_level = StrToInt(env);
static const char *const env = std::getenv("GLOG_v");
static const int ms_level = StrToInt(env);
if (level < 0) {
level = 2;
}

View File

@ -160,9 +160,9 @@ std::vector<std::string> MSTensorToStrings(const tensor::MSTensor *tensor) {
// Some primes between 2^63 and 2^64
namespace {
static uint64_t k0 = 0xc3a5c85c97cb3127ULL;
static uint64_t k1 = 0xb492b66fbe98f273ULL;
static uint64_t k2 = 0x9ae16a3b2f90404fULL;
static const uint64_t k0 = 0xc3a5c85c97cb3127ULL;
static const uint64_t k1 = 0xb492b66fbe98f273ULL;
static const uint64_t k2 = 0x9ae16a3b2f90404fULL;
uint64_t Fetch64Bit(const char *p) {
uint64_t result = 0;

View File

@ -123,7 +123,9 @@ int TensorList2TensorListC(TensorList *src, TensorListC *dst) {
dst->format_ = src->format();
dst->shape_value_ = src->shape().empty() ? 0 : src->shape().front();
dst->element_num_ = src->shape().empty() ? 0 : src->tensors().size();
if (dst->element_num_ * sizeof(TensorC) < 0 || dst->element_num_ * sizeof(TensorC) > MAX_MALLOC_SIZE) {
if ((dst->element_num_ != 0 && SIZE_MAX / dst->element_num_ < sizeof(TensorC)) ||
dst->element_num_ * sizeof(TensorC) > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "data size error.";
return RET_ERROR;
}

View File

@ -100,7 +100,6 @@ void DefaultAllocator::Free(void *buf) {
}
UnLock();
free(buf);
buf = nullptr;
}
int DefaultAllocator::RefCount(void *buf) {

View File

@ -22,10 +22,10 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
std::vector<lite::Tensor *>::iterator dst_end,
std::vector<lite::Tensor *>::iterator src_begin,
std::vector<lite::Tensor *>::iterator src_limit) {
int CarryDataKernel::MoveData(const std::vector<lite::Tensor *>::iterator &dst_begin,
const std::vector<lite::Tensor *>::iterator &dst_end,
const std::vector<lite::Tensor *>::iterator &src_begin,
const std::vector<lite::Tensor *>::iterator &src_limit) {
for (auto dst_iter = dst_begin, src_iter = src_begin; dst_iter != dst_end; dst_iter++, src_iter++) {
if (src_iter == src_limit) {
MS_LOG(ERROR) << "out of range of input tensor";

View File

@ -30,8 +30,10 @@ class CarryDataKernel : public InnerKernel {
~CarryDataKernel() override = default;
protected:
int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
int MoveData(const std::vector<lite::Tensor *>::iterator &dst_begin,
const std::vector<lite::Tensor *>::iterator &dst_end,
const std::vector<lite::Tensor *>::iterator &src_begin,
const std::vector<lite::Tensor *>::iterator &src_limit);
int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
int MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist);
};

View File

@ -43,7 +43,7 @@ int CropBaseCPUKernel::ReSize() {
return RET_OK;
}
void CropBaseCPUKernel::PadOffset(int input_dim, CropParameter *crop_para) {
void CropBaseCPUKernel::PadOffset(int input_dim, CropParameter *crop_para) const {
auto axis = crop_para->axis_;
auto offsets_size = crop_para->offset_size_;
MS_ASSERT(axis <= input_dim);

View File

@ -40,7 +40,7 @@ class CropBaseCPUKernel : public InnerKernel {
std::vector<int> input_shape_;
std::vector<int> output_shape_;
CropParameter *crop_para_;
void PadOffset(int input_dim, CropParameter *crop_para);
void PadOffset(int input_dim, CropParameter *crop_para) const;
};
} // namespace mindspore::kernel

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/detection_post_process_base.h"
#include <cfloat>
#include <cmath>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -29,7 +31,7 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess;
namespace mindspore::kernel {
void PartialArgSort(const float *scores, int *indexes, int num_to_sort, int num_values) {
std::partial_sort(indexes, indexes + num_to_sort, indexes + num_values, [&scores](const int i, const int j) {
if (scores[i] == scores[j]) {
if (std::abs(scores[i] - scores[j]) < FLT_EPSILON) {
return i < j;
}
return scores[i] > scores[j];

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/pooling_base.h"
#include <cfloat>
#include <cmath>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"
@ -59,7 +61,7 @@ int PoolingBaseCPUKernel::SetQuantParam() {
pooling_quant_arg_[1][0].scale_ = out_quant_arg.front().scale;
pooling_quant_arg_[1][0].zp_ = out_quant_arg.front().zeroPoint;
pooling_param_->quant_args_ = pooling_quant_arg_;
if (pooling_quant_arg_[0][0].scale_ == pooling_quant_arg_[1][0].scale_ &&
if (std::abs(pooling_quant_arg_[0][0].scale_ - pooling_quant_arg_[1][0].scale_) < FLT_EPSILON &&
pooling_quant_arg_[0][0].zp_ == pooling_quant_arg_[1][0].zp_) {
pooling_param_->quantize_ = false;
} else {

View File

@ -106,8 +106,8 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
ret = UInt8ToInt8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
} else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
auto input_quant_arg = in_tensors_.front()->quant_params().front();
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, num_unit_thread,
input_quant_arg.scale, input_quant_arg.zeroPoint);
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, input_quant_arg.scale,
input_quant_arg.zeroPoint, num_unit_thread);
if (ret) {
auto output_quant_arg = out_tensors_.front()->quant_params().front();
bool from_uint8_src = false;

View File

@ -34,7 +34,7 @@ int TileCPUKernel::Init() {
return ReSize();
}
void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) {
void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) const {
int stride = 1;
for (int i = ndim - 1; i >= 0; i--) {
strides[i] = stride;

View File

@ -35,7 +35,7 @@ class TileCPUKernel : public InnerKernel {
private:
int RunSimpleTile();
void ComputeStrides(const int *shape, int *strides, int ndim);
void ComputeStrides(const int *shape, int *strides, int ndim) const;
void FillOneDimTileParam();
bool one_dim_tile_ = false;
uint8_t *input_addr_ = nullptr;

View File

@ -53,7 +53,7 @@ int ArithmeticCPUKernel::ReSize() {
CalcMultiplesAndStrides(param_);
if (param_->broadcasting_) {
outside_ = 1;
for (auto i = param_->ndim_ - 1; i >= 0; --i) {
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0; --i) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_pos_ = i;
break;
@ -113,7 +113,7 @@ bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
return true;
}
bool ArithmeticCPUKernel::IsBiasCalc() { // 2 240 240 32, 1 1 1 32
bool ArithmeticCPUKernel::IsBiasCalc() const { // 2 240 240 32, 1 1 1 32
int last_shape0 = param_->in_shape0_[param_->ndim_ - 1];
int last_shape1 = param_->in_shape1_[param_->ndim_ - 1];
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
@ -129,9 +129,7 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
if (!param_->broadcasting_) {
return RET_OK;
}
if (out_tensors_[0]->Size() < 0) {
return RET_OK;
}
/* [1, 1, 2] + [1, 2, 1] -> [1, 2, 2], need broadcast both input */
if (param_->in_elements_num0_ != param_->out_elements_num_ &&
param_->in_elements_num1_ != param_->out_elements_num_) {

View File

@ -97,7 +97,7 @@ class ArithmeticCPUKernel : public InnerKernel {
int BatchScalarCalc(int task_id);
int BiasCalc(int task_id);
void FreeConstTileBuff();
bool IsBiasCalc();
bool IsBiasCalc() const;
ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticIntRun arithmetic_run_int_ = nullptr;

View File

@ -27,10 +27,9 @@ struct TYPE_FUNC_INFO {
int primitive_type_ = 0;
ArithmeticSelfFunc func_ = nullptr;
};
using TYPE_FUNC_INFO = TYPE_FUNC_INFO;
} // namespace
ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_type) {
ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_type) const {
TYPE_FUNC_INFO type_func_table[] = {{mindspore::schema::PrimitiveType_Abs, ElementAbs},
{mindspore::schema::PrimitiveType_Cos, ElementCos},
{mindspore::schema::PrimitiveType_Log, ElementLog},
@ -53,7 +52,7 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t
return nullptr;
}
ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) {
ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) const {
if (primitive_type == mindspore::schema::PrimitiveType_LogicalNot) {
return ElementLogicalNotBool;
}

View File

@ -53,8 +53,8 @@ class ArithmeticSelfCPUKernel : public InnerKernel {
virtual int DoExecute(int task_id);
private:
ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type);
ArithmeticSelfBoolFunc GetArithmeticSelfBoolFun(int primitive_type);
ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type) const;
ArithmeticSelfBoolFunc GetArithmeticSelfBoolFun(int primitive_type) const;
ArithmeticSelfFunc func_;
ArithmeticSelfBoolFunc func_bool_;
};

View File

@ -170,7 +170,7 @@ int Convolution1x1CPUKernel::Init() {
return RET_OK;
}
void Convolution1x1CPUKernel::PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col) {
void Convolution1x1CPUKernel::PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col) const {
#ifdef ENABLE_AVX
RowMajor2Col6Major(src_ptr, dst_ptr, row, col);
#elif defined(ENABLE_SSE)

View File

@ -53,7 +53,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
int InitConv1x1BiasWeight();
void InitConv1x1MatmulParam();
void FreeTmpBuffer();
void PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col);
void PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col) const;
void PackWeight();
private:

View File

@ -116,43 +116,43 @@ int LshProjectionCPUKernel::DoExecute(int task_id) {
return RET_OK;
}
int LshProjectionCPUKernel::GetSignBit(int32_t *feature_, float *weight_, float seed, LshProjectionParameter *para,
int LshProjectionCPUKernel::GetSignBit(int32_t *feature, float *weight, float seed, LshProjectionParameter *para,
char *hash_buff) {
double score = 0.0;
for (int i = 0; i < para->feature_num_; i++) {
memcpy(hash_buff, &seed, sizeof(float));
memcpy(hash_buff + sizeof(float), &(feature_[i]), sizeof(int32_t));
memcpy(hash_buff + sizeof(float), &(feature[i]), sizeof(int32_t));
int64_t hash_i = static_cast<int64_t>(lite::StringHash64(hash_buff, para->hash_buff_size_));
double hash_d = static_cast<double>(hash_i);
if (weight_ == nullptr) {
if (weight == nullptr) {
score += hash_d;
} else {
score += weight_[i] * hash_d;
score += weight[i] * hash_d;
}
}
return (score > 0) ? 1 : 0;
}
void LshProjectionCPUKernel::LshProjectionSparse(float *hash_seed_, int32_t *feature_, float *weight_, int32_t *output_,
void LshProjectionCPUKernel::LshProjectionSparse(float *hashSeed, int32_t *feature, float *weight, int32_t *output,
LshProjectionParameter *para, int32_t start, int32_t end,
char *hash_buff) {
for (int i = start; i < end; i++) {
int32_t hash_sign = 0;
for (int j = 0; j < para->hash_shape_[1]; j++) {
int bit = GetSignBit(feature_, weight_, hash_seed_[i * para->hash_shape_[1] + j], para, hash_buff);
int bit = GetSignBit(feature, weight, hashSeed[i * para->hash_shape_[1] + j], para, hash_buff);
hash_sign = (hash_sign << 1) | bit;
}
output_[i] = hash_sign + i * (1 << para->hash_shape_[1]);
output[i] = hash_sign + i * (1 << para->hash_shape_[1]);
}
}
void LshProjectionCPUKernel::LshProjectionDense(float *hash_seed_, int32_t *feature_, float *weight_, int32_t *output_,
void LshProjectionCPUKernel::LshProjectionDense(float *hashSeed, int32_t *feature, float *weight, int32_t *output,
LshProjectionParameter *para, int32_t start, int32_t end,
char *hash_buff) {
for (int i = start; i < end; i++) {
for (int j = 0; j < para->hash_shape_[1]; j++) {
output_[i * para->hash_shape_[1] + j] =
GetSignBit(feature_, weight_, hash_seed_[i * para->hash_shape_[1] + j], para, hash_buff);
output[i * para->hash_shape_[1] + j] =
GetSignBit(feature, weight, hashSeed[i * para->hash_shape_[1] + j], para, hash_buff);
}
}
}

View File

@ -40,10 +40,10 @@ class LshProjectionCPUKernel : public InnerKernel {
private:
int MallocKeys();
void FreeKeys();
int GetSignBit(int32_t *feature_, float *weight_, float seed, LshProjectionParameter *para, char *hash_buff);
void LshProjectionSparse(float *hash_seed_, int32_t *feature_, float *weight_, int32_t *output_,
int GetSignBit(int32_t *feature, float *weight, float seed, LshProjectionParameter *para, char *hash_buff);
void LshProjectionSparse(float *hashSeed, int32_t *feature, float *weight, int32_t *output,
LshProjectionParameter *param, int32_t start, int32_t end, char *hash_buff);
void LshProjectionDense(float *hash_seed_, int32_t *feature_, float *weight_, int32_t *output_,
void LshProjectionDense(float *hashSeed, int32_t *feature, float *weight, int32_t *output,
LshProjectionParameter *param, int32_t start, int32_t end, char *hash_buff);
LshProjectionParameter *param_ = nullptr;
float *hash_seed_ = nullptr;

View File

@ -234,7 +234,7 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
}
}
int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
int MatmulFp32BaseCPUKernel::FloatRun(int task_id) const {
int current_start_oc = task_id * thread_stride_ * col_tile_;
int current_rest_oc = 0;
#if defined(ENABLE_AVX)

View File

@ -42,7 +42,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
int Run() override;
public:
int FloatRun(int task_id);
int FloatRun(int task_id) const;
protected:
int InitBufferA();

View File

@ -16,6 +16,8 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_
#include <cfloat>
#include <cmath>
#include <vector>
#include <algorithm>
#include "src/inner_kernel.h"
@ -78,7 +80,7 @@ class NMSBox {
area_ = (y2_ - y1_) * (x2_ - x1_);
}
inline bool operator<(const NMSBox &box) const {
return score_ < box.score_ || (score_ == box.score_ && index_ > box.index_);
return score_ < box.score_ || (std::abs(score_ - box.score_) < FLT_EPSILON && index_ > box.index_);
}
public:

View File

@ -185,7 +185,7 @@ void PadCPUKernel::InitMirrorPadBlock() {
return;
}
int PadCPUKernel::ExtendShape(int *shape, int length, const int *ori_shape, int rank) {
int PadCPUKernel::ExtendShape(int *shape, int length, const int *ori_shape, int rank) const {
if (shape == nullptr || ori_shape == nullptr) {
return RET_NULL_PTR;
}
@ -198,7 +198,7 @@ int PadCPUKernel::ExtendShape(int *shape, int length, const int *ori_shape, int
return RET_OK;
}
int PadCPUKernel::ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length) {
int PadCPUKernel::ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length) const {
if (paddings == nullptr || ori_paddings == nullptr) {
return RET_NULL_PTR;
}

View File

@ -47,8 +47,8 @@ class PadCPUKernel : public InnerKernel {
private:
int CheckPaddings(int *paddings, int length, int *input_shape, int mode);
void CalculateStrides();
int ExtendShape(int *shape, int length, const int *ori_shape, int rank);
int ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length);
int ExtendShape(int *shape, int length, const int *ori_shape, int rank) const;
int ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length) const;
void InitMirrorPadBlock();
protected:

View File

@ -85,7 +85,7 @@ int ResizeCPUKernel::ReSize() {
// Bicubic goes one step beyond bilinear by considering the closest 4x4 neighborhood of known pixels --- for a total of
// 16 pixels. Since these are at various distances from the unknown pixel, closer pixels are given a higher weighting in
// the calculation.
void ResizeCPUKernel::CalTmpBufferLen(int *x_len, int *y_len, int *x_weight_len, int *y_weight_len) {
void ResizeCPUKernel::CalTmpBufferLen(int *x_len, int *y_len, int *x_weight_len, int *y_weight_len) const {
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
*x_len = new_width_;
*y_len = new_height_;

View File

@ -71,7 +71,7 @@ class ResizeCPUKernel : public ResizeBaseCPUKernel {
virtual int RunImpl(int task_id);
int SelectCalculatorFunc();
int ResizePrepare();
void CalTmpBufferLen(int *x_len, int *y_len, int *x_weight_len, int *y_weight_len);
void CalTmpBufferLen(int *x_len, int *y_len, int *x_weight_len, int *y_weight_len) const;
int MallocTmpBuffer();
void FreeTmpBuffer();

View File

@ -30,20 +30,20 @@ int ReverseSequenceCPUKernel::Init() {
return ReSize();
}
void ReverseSequenceCPUKernel::ConvertAxisToPositive(const std::vector<int> shape, int *axis) {
void ReverseSequenceCPUKernel::ConvertAxisToPositive(const std::vector<int> shape, int *axis) const {
if (axis != nullptr && *axis < 0) {
*axis += static_cast<int>(shape.size());
}
}
int ReverseSequenceCPUKernel::CalcCountPreAxis(const std::vector<int> shape, int axis) {
int ReverseSequenceCPUKernel::CalcCountPreAxis(const std::vector<int> shape, int axis) const {
int count = 1;
for (int i = 0; i < axis; ++i) {
count *= shape.at(i);
}
return count;
}
int ReverseSequenceCPUKernel::CalcCountAfterAxis(const std::vector<int> shape, int axis) {
int ReverseSequenceCPUKernel::CalcCountAfterAxis(const std::vector<int> shape, int axis) const {
int count = 1;
for (size_t i = axis + 1; i < shape.size(); ++i) {
count *= shape.at(i);

View File

@ -33,9 +33,9 @@ class ReverseSequenceCPUKernel : public InnerKernel {
int Run() override;
private:
void ConvertAxisToPositive(const std::vector<int> shape, int *axis);
int CalcCountPreAxis(const std::vector<int> shape, int axis);
int CalcCountAfterAxis(const std::vector<int> shape, int axis);
void ConvertAxisToPositive(const std::vector<int> shape, int *axis) const;
int CalcCountPreAxis(const std::vector<int> shape, int axis) const;
int CalcCountAfterAxis(const std::vector<int> shape, int axis) const;
};
} // namespace mindspore::kernel

View File

@ -147,7 +147,7 @@ int SparseToDenseCPUKernel::GenerateIndices() {
return RET_OK;
}
int SparseToDenseCPUKernel::IndicesValidCheck() {
int SparseToDenseCPUKernel::IndicesValidCheck() const {
int d1 = output_shape[1] * output_shape[2] * output_shape[3];
int d2 = output_shape[2] * output_shape[3];
int d3 = output_shape[3];

View File

@ -41,7 +41,7 @@ class SparseToDenseCPUKernel : public InnerKernel {
int Run() override;
int DoExcute(int task_id);
int GenerateIndices();
int IndicesValidCheck();
int IndicesValidCheck() const;
protected:
const InnerContext *ctx_;

View File

@ -133,7 +133,7 @@ int QuantizedAddCPUKernel::ReSize() {
if (arith_para_->broadcasting_) {
size_t break_pos_ = 0;
for (auto i = arith_para_->ndim_ - 1; i >= 0; --i) {
for (int i = static_cast<int>(arith_para_->ndim_) - 1; i >= 0; --i) {
if (arith_para_->in_shape0_[i] != arith_para_->in_shape1_[i]) {
break_pos_ = i;
break;

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/batch_to_space_int8.h"
#include <cfloat>
#include <cmath>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -76,7 +78,8 @@ int BatchToSpaceInt8CPUKernel::Run() {
auto out_shape = output->shape();
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
if (in_quant_arg_->scale_ == out_quant_arg_->scale_ && in_quant_arg_->zp_ == out_quant_arg_->zp_) {
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
if (param->no_crop_) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
sizeof(int8_t));

View File

@ -513,7 +513,7 @@ int Convolution1x1Int8CPUKernel::RunArmOc(int task_id) {
return RET_OK;
}
int Convolution1x1Int8CPUKernel::OcOptPre(int task_id) {
int Convolution1x1Int8CPUKernel::OcOptPre(int task_id) const {
int cur_stride = thread_stride_hw_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);

View File

@ -47,7 +47,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
public:
int OcRun(int task_id);
int HwRun(int task_id);
int OcOptPre(int task_id);
int OcOptPre(int task_id) const;
private:
int RunArmOc(int task_id);

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/depth_to_space_int8.h"
#include <cfloat>
#include <cmath>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -72,7 +74,8 @@ int DepthToSpaceInt8CPUKernel::Run() {
const int8_t *input_data = reinterpret_cast<const int8_t *>(input->data_c());
int8_t *output_data = reinterpret_cast<int8_t *>(output->data_c());
auto in_shape = input->shape();
if (in_quant_arg_->scale_ == out_quant_arg_->scale_ && in_quant_arg_->zp_ == out_quant_arg_->zp_) {
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), param_);
} else {
DepthToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), param_, in_quant_arg_, out_quant_arg_);

View File

@ -54,7 +54,7 @@ int HswishInt8CPUKernel::Init() {
return RET_OK;
}
void HswishInt8CPUKernel::MultiplierInt32ToInt16(int32_t input, int16_t *output) {
void HswishInt8CPUKernel::MultiplierInt32ToInt16(int32_t input, int16_t *output) const {
MS_ASSERT(input >= 0);
if (input >= std::numeric_limits<int32_t>::max() - (1 << 15)) {
*output = std::numeric_limits<int16_t>::max();

View File

@ -38,7 +38,7 @@ class HswishInt8CPUKernel : public InnerKernel {
private:
int thread_count_;
HswishQuantArg quant_arg_;
void MultiplierInt32ToInt16(int32_t input, int16_t *output);
void MultiplierInt32ToInt16(int32_t input, int16_t *output) const;
};
} // namespace mindspore::kernel

View File

@ -15,6 +15,8 @@
*/
#include "src/runtime/kernel/arm/int8/pad_int8.h"
#include <cfloat>
#include <cmath>
#include "src/kernel_registry.h"
using mindspore::lite::RET_ERROR;
@ -69,7 +71,7 @@ int PadInt8CPUKernel::SetQuantParam() {
pad_quant_args->out_quanr_args_->zp_ = out_quant_arg.front().zeroPoint;
pad_quant_args->out_quanr_args_->scale_ = out_quant_arg.front().scale;
if (pad_quant_args->in_quant_args_->scale_ != pad_quant_args->out_quanr_args_->scale_ ||
if (std::abs(pad_quant_args->in_quant_args_->scale_ - pad_quant_args->out_quanr_args_->scale_) > FLT_EPSILON ||
pad_quant_args->in_quant_args_->zp_ != pad_quant_args->out_quanr_args_->zp_) {
MS_LOG(ERROR) << "Pad int8 op : scale & zp of output and input must be equal.";
return RET_ERROR;
@ -168,7 +170,7 @@ void PadInt8CPUKernel::CalculateStrides() {
}
}
int PadInt8CPUKernel::ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length) {
int PadInt8CPUKernel::ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length) const {
if (paddings == nullptr || ori_paddings == nullptr) {
return RET_NULL_PTR;
}

View File

@ -52,7 +52,7 @@ class PadInt8CPUKernel : public InnerKernel {
int CheckPaddings(const int *paddings, int length, const int *input_shape, int mode);
int CopyPaddingFromInput();
void CalculateStrides();
int ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length);
int ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length) const;
PadParameter *pad_param_ = nullptr;
int8_t *in_data_ = nullptr;

View File

@ -252,7 +252,7 @@ int ScaleInt8CPUKernel::ReSize() {
return RET_OK;
}
int ScaleInt8CPUKernel::Scale(int task_id) {
int ScaleInt8CPUKernel::Scale(int task_id) const {
int real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
if (real_dst_count <= 0) {
return lite::RET_OK;

View File

@ -40,7 +40,7 @@ class ScaleInt8CPUKernel : public InnerKernel {
int Run() override;
int InitParameter();
int InitScaleOffset();
int Scale(int task_id);
int Scale(int task_id) const;
private:
int8_t *input0_data_ = nullptr;

View File

@ -64,7 +64,7 @@ int SoftmaxInt8CPUKernel::Init() {
quant_param_->output_activation_max_ = std::numeric_limits<int8_t>::max();
const double input_real_multiplier =
MSMIN(quant_param_->in_quant_args_.scale_ * (1 << (unsigned int)(31 - 5)), (1ll << 31) - 1.0);
MSMIN(quant_param_->in_quant_args_.scale_ * (1 << (unsigned int)(31 - 5)), (1LL << 31) - 1.0);
int right_shift = 0;
QuantizeMultiplierSmallerThanOne(input_real_multiplier, &quant_param_->output_multiplier_, &right_shift);
quant_param_->shift_left_ = right_shift < 0 ? -right_shift : 0;

View File

@ -42,7 +42,7 @@ int TanhInt8CPUKernel::ReSize() {
return RET_OK;
}
int TanhInt8CPUKernel::DoActivation(int task_id) {
int TanhInt8CPUKernel::DoActivation(int task_id) const {
int current_size = element_size_ - task_id * thread_stride_;
current_size = MSMIN(thread_stride_, current_size);
if (current_size <= 0) {

View File

@ -38,7 +38,7 @@ class TanhInt8CPUKernel : public InnerKernel {
int Run() override;
public:
int DoActivation(int task_id);
int DoActivation(int task_id) const;
private:
int8_t *in_ptr_{nullptr};

View File

@ -75,7 +75,7 @@ std::string NormalizeCPUKernel::Normalize(const std::string &str) {
result = GlobalReplace(result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "$1");
// transform shortening to full
MS_ASSERT(kRegexTransforms != nullptr);
for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); iter++) {
for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); ++iter) {
result = GlobalReplace(result, iter->first, iter->second);
}
result = GlobalReplace(result, "([?])+", "$1");