[MSLITE][DEVELOP] clean code check warnings

yangruoqi713 2021-10-25 16:43:40 +08:00
parent f64c2b8f38
commit d071a7e91f
25 changed files with 108 additions and 83 deletions

View File

@ -33,9 +33,11 @@ class MixPrecisionCfg {
this->num_of_not_nan_iter_th_ = 1000;
}
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
float loss_scale_; /**< Initial loss scale factor */
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
~MixPrecisionCfg() = default;
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
float loss_scale_; /**< Initial loss scale factor */
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore */
};
@ -43,6 +45,8 @@ class TrainCfg {
public:
TrainCfg() { this->loss_name_ = "_loss_fn"; }
~TrainCfg() = default;
OptimizationLevel optimization_level_ = kO0;
std::string loss_name_; /**< Set part of the name that identify a loss kernel */
MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */

View File

@ -54,15 +54,15 @@
#define MSVALID(left, x, right) (MSMIN((MSMAX(left, x)), right))
#define SIZE_MUL_OVERFLOW(x, y) (((x) == 0) ? false : (SIZE_MAX / (x)) < (y))
#define INT_MUL_OVERFLOW(x, y) \
((x == 0) ? false \
: ((x) > 0 ? ((y >= 0) ? (INT_MAX / (x)) < (y) : (INT_MAX / (x)) < (-1 * (y))) \
: ((y >= 0) ? (INT_MAX / (x)) > (-1 * (y)) : (INT_MAX / (x)) > (y))))
#define INT_MUL_OVERFLOW(x, y) \
(((x) == 0) ? false \
: ((x) > 0 ? (((y) >= 0) ? (INT_MAX / (x)) < (y) : (INT_MAX / (x)) < (-1 * (y))) \
: (((y) >= 0) ? (INT_MAX / (x)) > (-1 * (y)) : (INT_MAX / (x)) > (y))))
#define INT_MUL_OVERFLOW_THRESHOLD(x, y, threshold) \
((x == 0) ? false \
: ((x) > 0 ? ((y >= 0) ? ((threshold) / (x)) < (y) : ((threshold) / (x)) < (-1 * (y))) \
: ((y >= 0) ? ((threshold) / (x)) > (-1 * (y)) : ((threshold) / (x)) > (y))))
#define INT_MUL_OVERFLOW_THRESHOLD(x, y, threshold) \
(((x) == 0) ? false \
: ((x) > 0 ? (((y) >= 0) ? ((threshold) / (x)) < (y) : ((threshold) / (x)) < (-1 * (y))) \
: (((y) >= 0) ? ((threshold) / (x)) > (-1 * (y)) : ((threshold) / (x)) > (y))))
#define INT_ADD_OVERFLOW(x, y) (INT_MAX - (x)) < (y)
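The point of the rewritten macros above is parenthesisation: the old form left `x == 0` and `y >= 0` unwrapped, so a composite macro argument could change the parse. A minimal usage sketch (variable names and values are illustrative, not from this commit) of guarding a size computation with the fixed macro:

#include <climits>
#include <cstdio>
#include <cstdlib>

#define INT_MUL_OVERFLOW(x, y)                                                                  \
  (((x) == 0) ? false                                                                           \
              : ((x) > 0 ? (((y) >= 0) ? (INT_MAX / (x)) < (y) : (INT_MAX / (x)) < (-1 * (y)))  \
                         : (((y) >= 0) ? (INT_MAX / (x)) > (-1 * (y)) : (INT_MAX / (x)) > (y))))

int main() {
  bool transpose = false;
  int h = 60000, w = 50000, channels = 4;
  // With the old macro, `x == 0` expands for x = `transpose ? w : h` into
  // `transpose ? w : h == 0`, which parses as `transpose ? w : (h == 0)`.
  if (INT_MUL_OVERFLOW(transpose ? w : h, channels)) {
    fprintf(stderr, "shape overflows int, refusing to allocate\n");
    return EXIT_FAILURE;
  }
  int element_count = (transpose ? w : h) * channels;
  printf("element_count = %d\n", element_count);
  return EXIT_SUCCESS;
}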

View File

@ -69,6 +69,9 @@ class String {
}
String(size_t count, char ch) {
if (count > SIZE_MAX / sizeof(char) - 1) {
MS_C_EXCEPTION("Invalid string size");
}
buffer_ = reinterpret_cast<char *>(malloc(sizeof(char) * (count + 1)));
if (buffer_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
@ -98,7 +101,7 @@ class String {
if (buffer_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
}
strncpy(buffer_, s, size_);
memcpy(buffer_, s, size_);
buffer_[size_] = '\0';
}
@ -150,12 +153,17 @@ class String {
if (buffer_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
}
strncpy(buffer_, other.buffer_ + pos, size_);
memcpy(buffer_, other.buffer_ + pos, size_);
buffer_[size_] = '\0';
}
}
~String() { free(buffer_); }
~String() {
if (buffer_ != nullptr) {
free(buffer_);
buffer_ = nullptr;
}
}
String &operator=(const String &str) {
if (this == &str) {
@ -241,7 +249,9 @@ class String {
}
String &append(size_t count, const char ch) {
(*this) += ch;
for (size_t i = 0; i < count; i++) {
(*this) += ch;
}
return *this;
}
@ -264,6 +274,9 @@ class String {
}
String &operator+=(const String &str) {
if (size_ > SIZE_MAX / sizeof(char) - str.size_ - 1) {
MS_C_EXCEPTION("Invalid string size");
}
size_t new_size = size_ + str.size_;
char *tmp = reinterpret_cast<char *>(malloc(sizeof(char) * (new_size + 1)));
if (tmp == nullptr) {
@ -283,6 +296,9 @@ class String {
return *this;
}
size_t str_size = strlen(str);
if (size_ > SIZE_MAX / sizeof(char) - str_size - 1) {
MS_C_EXCEPTION("Invalid string size");
}
size_t new_size = size_ + str_size;
char *tmp = reinterpret_cast<char *>(malloc(sizeof(char) * (new_size + 1)));
if (tmp == nullptr) {
@ -298,6 +314,9 @@ class String {
}
String &operator+=(const char ch) {
if (size_ > SIZE_MAX / sizeof(char) - 2) {
MS_C_EXCEPTION("Invalid string size");
}
char *tmp = reinterpret_cast<char *>(malloc(sizeof(char) * (size_ + 2)));
if (tmp == nullptr) {
MS_C_EXCEPTION("malloc data failed");
@ -585,6 +604,10 @@ class Vector {
size_ = vec.size_;
elem_size_ = sizeof(T);
capacity_ = vec.capacity_;
if (data_ != nullptr) {
delete[] data_;
data_ = nullptr;
}
data_ = new (std::nothrow) T[capacity_];
if (data_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
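The String changes above follow one pattern: guard the size arithmetic against size_t overflow before calling malloc, copy with memcpy (the length is already known, so strncpy adds nothing), terminate explicitly, and release any previously owned buffer before taking a new one. A condensed, self-contained sketch of that pattern (this is not the real mindspore String class, just an illustration):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <new>
#include <stdexcept>

class TinyString {
 public:
  TinyString() = default;
  TinyString(const TinyString &) = delete;
  TinyString &operator=(const TinyString &) = delete;
  ~TinyString() {
    free(buffer_);      // free(nullptr) is a no-op; nulling afterwards guards double-free
    buffer_ = nullptr;
  }
  void Append(const char *s) {
    size_t str_size = strlen(s);
    if (size_ > SIZE_MAX - str_size - 1) {  // overflow guard, as added by the commit
      throw std::length_error("Invalid string size");
    }
    size_t new_size = size_ + str_size;
    char *tmp = static_cast<char *>(malloc(new_size + 1));
    if (tmp == nullptr) {
      throw std::bad_alloc();
    }
    if (buffer_ != nullptr) {
      memcpy(tmp, buffer_, size_);          // lengths are known, so memcpy replaces strncpy
    }
    memcpy(tmp + size_, s, str_size);
    tmp[new_size] = '\0';
    free(buffer_);                          // release the old buffer before taking the new one
    buffer_ = tmp;
    size_ = new_size;
  }
  const char *c_str() const { return buffer_ == nullptr ? "" : buffer_; }

 private:
  char *buffer_ = nullptr;
  size_t size_ = 0;
};

int main() {
  TinyString s;
  s.Append("mix ");
  s.Append("precision");
  printf("%s\n", s.c_str());
  return 0;
}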

View File

@ -17,17 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_
#define MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <unordered_map>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/cell.h"
#include "include/lite_session.h"
#include "include/train/train_loop_callback.h"
namespace mindspore {

View File

@ -18,8 +18,6 @@
#include "src/common/log_adapter.h"
namespace mindspore {
class GraphImpl {};
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const {
std::vector<Output> empty;
MS_LOG(ERROR) << "Unsupported feature.";

View File

@ -171,7 +171,7 @@ std::shared_ptr<Delegate> Context::GetDelegate() const {
}
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
static std::vector<std::shared_ptr<DeviceInfoContext>> empty;
static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return empty;

View File

@ -27,7 +27,7 @@
namespace mindspore {
constexpr static int kMaxNumOfDevices = 3;
Status AddCpuDevice(Context *a_context, lite::InnerContext *l_context, DeviceInfoContext *device) {
Status AddCpuDevice(const Context *a_context, lite::InnerContext *l_context, DeviceInfoContext *device) {
auto cpu_context = device->Cast<CPUDeviceInfo>();
l_context->allocator = cpu_context->GetAllocator();
if (l_context->allocator == nullptr) {
@ -54,7 +54,7 @@ Status AddCpuDevice(Context *a_context, lite::InnerContext *l_context, DeviceInf
return kSuccess;
}
Status AddGpuDevice(Context *a_context, lite::InnerContext *l_context, DeviceInfoContext *device) {
Status AddGpuDevice(lite::InnerContext *l_context, DeviceInfoContext *device) {
lite::DeviceInfo device_info = {0};
auto gpu_context = device->Cast<GPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16(), gpu_context->GetDeviceID()};
@ -63,7 +63,7 @@ Status AddGpuDevice(Context *a_context, lite::InnerContext *l_context, DeviceInf
return kSuccess;
}
Status AddNpuDevice(Context *a_context, lite::InnerContext *l_context, DeviceInfoContext *device) {
Status AddNpuDevice(lite::InnerContext *l_context, DeviceInfoContext *device) {
lite::DeviceInfo device_info = {0};
auto npu_context = device->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};
@ -71,7 +71,7 @@ Status AddNpuDevice(Context *a_context, lite::InnerContext *l_context, DeviceInf
return kSuccess;
}
Status AddAscend310Device(Context *a_context, lite::InnerContext *l_context, DeviceInfoContext *device) {
Status AddAscend310Device(lite::InnerContext *l_context, DeviceInfoContext *device) {
lite::DeviceInfo device_info = {0};
auto ascend310_context = device->Cast<Ascend310DeviceInfo>();
device_info.ascend310_device_info_ = {ascend310_context->GetDeviceID()};
@ -105,11 +105,11 @@ Status A2L_ConvertContext(Context *a_context, lite::InnerContext *l_context) {
if (device->GetDeviceType() == kCPU) {
error_code = AddCpuDevice(a_context, l_context, device.get());
} else if (device->GetDeviceType() == kGPU) {
error_code = AddGpuDevice(a_context, l_context, device.get());
error_code = AddGpuDevice(l_context, device.get());
} else if (device->GetDeviceType() == kKirinNPU) {
error_code = AddNpuDevice(a_context, l_context, device.get());
error_code = AddNpuDevice(l_context, device.get());
} else if (device->GetDeviceType() == kAscend310) {
error_code = AddAscend310Device(a_context, l_context, device.get());
error_code = AddAscend310Device(l_context, device.get());
} else {
MS_LOG(ERROR) << "Invalid device.";
return kLiteInputParamInvalid;

View File

@ -255,7 +255,7 @@ Status Model::SetOptimizerParams(const std::vector<MSTensor> &params) {
return impl_->SetOptimizerParams(params);
}
Status Model::InitMetrics(std::vector<Metrics *> metrics) {
Status Model::InitMetrics(const std::vector<Metrics *> metrics) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteUninitializedObj;

View File

@ -51,11 +51,13 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
MS_CHECK_TRUE_MSG(lite_context != nullptr, kLiteNullptr, "inner context failed");
auto status = A2L_ConvertContext(ms_context.get(), lite_context);
if (status != kSuccess) {
delete lite_context;
return status;
}
auto session = std::shared_ptr<session::LiteSession>(CreateLiteSession(lite_context));
if (session == nullptr) {
delete lite_context;
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
@ -77,11 +79,13 @@ Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
MS_CHECK_TRUE_MSG(lite_context != nullptr, kLiteNullptr, "inner context failed");
auto status = A2L_ConvertContext(ms_context.get(), lite_context);
if (status != kSuccess) {
delete lite_context;
return status;
}
auto session = std::shared_ptr<session::LiteSession>(CreateLiteSession(lite_context));
if (session == nullptr) {
delete lite_context;
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
@ -113,6 +117,7 @@ Status ModelImpl::Build() {
MS_CHECK_TRUE_MSG(lite_context != nullptr, kLiteNullptr, "inner context failed");
auto status = A2L_ConvertContext(context_.get(), lite_context);
if (status != kSuccess) {
delete lite_context;
MS_LOG(ERROR) << "Failed to convert Context to Lite Context";
return status;
}
@ -129,12 +134,14 @@ Status ModelImpl::Build() {
auto model = graph_->graph_data_->lite_model();
if (model == nullptr || model->buf == nullptr) {
delete lite_context;
MS_LOG(ERROR) << "Lite model has been freed.";
return kLiteError;
}
auto session = std::shared_ptr<session::LiteSession>(CreateLiteSession(lite_context));
if (session == nullptr) {
delete lite_context;
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
@ -149,7 +156,7 @@ Status ModelImpl::Build() {
return kSuccess;
}
static void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MSTensor *> tensors) {
static void ResetTensorData(std::vector<void *> old_data, const std::vector<tensor::MSTensor *> &tensors) {
for (size_t j = 0; j < old_data.size(); j++) {
tensors.at(j)->set_data(old_data.at(j));
}
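The `delete lite_context` calls added on each early-return path above plug a leak when A2L_ConvertContext or CreateLiteSession fails. A hedged alternative sketch (stand-in types and functions, not the real MindSpore Lite API) that gets the same effect by letting std::unique_ptr own the context on every exit path:

#include <memory>

struct FakeInnerContext { int device_type = 0; };                              // stand-in for lite::InnerContext

static bool ConvertContext(FakeInnerContext *ctx) { return ctx != nullptr; }   // stand-in for A2L_ConvertContext
static void *CreateSession(FakeInnerContext *ctx) { return ctx; }              // stand-in for CreateLiteSession

static int Build() {
  auto lite_context = std::make_unique<FakeInnerContext>();
  if (!ConvertContext(lite_context.get())) {
    return -1;   // early return: the unique_ptr frees the context, no manual delete needed
  }
  void *session = CreateSession(lite_context.get());
  if (session == nullptr) {
    return -1;   // same on this path
  }
  // In the real code the created session takes ownership of the context, so this is the
  // point where lite_context.release() would hand the raw pointer over.
  return 0;
}

int main() { return Build() == 0 ? 0 : 1; }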

View File

@ -84,7 +84,7 @@ class ModelImpl {
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
bool IsTrainModel();
Status InitMetrics(std::vector<Metrics *> metrics) {
Status InitMetrics(const std::vector<Metrics *> metrics) {
metrics_ = metrics;
return kSuccess;
}

View File

@ -88,6 +88,9 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
}
std::string filename(file.data(), file.size());
if (filename.find_last_of(".") == std::string::npos) {
return kLiteInputParamInvalid;
}
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
filename = filename + ".ms";
}

View File

@ -41,10 +41,7 @@ class Buffer::Impl {
void *MutableData() { return data_.data(); }
size_t DataSize() const { return data_.size(); }
bool ResizeData(size_t data_len) {
data_.resize(data_len);
return true;
}
void ResizeData(size_t data_len) { data_.resize(data_len); }
bool SetData(const void *data, size_t data_len) {
ResizeData(data_len);
@ -89,7 +86,7 @@ bool MSTensor::operator==(const MSTensor &tensor) const {
MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
if (data_len < 0 || data_len > MAX_MALLOC_SIZE) {
if (data_len > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "data_len is error.";
return nullptr;
}
@ -182,7 +179,7 @@ MSTensor *MSTensor::Clone() const {
return nullptr;
}
auto data_len = this->DataSize();
if (data_len <= 0 || data_len > MAX_MALLOC_SIZE) {
if (data_len > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "Illegal data size of tensor.";
return nullptr;
}
@ -235,7 +232,7 @@ enum DataType MSTensor::DataType() const {
}
const std::vector<int64_t> &MSTensor::Shape() const {
static std::vector<int64_t> empty;
static const std::vector<int64_t> empty{};
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement.";
return empty;
@ -409,7 +406,8 @@ bool Buffer::ResizeData(size_t data_len) {
MS_LOG(ERROR) << "impl is nullptr.";
return false;
}
return impl_->ResizeData(data_len);
impl_->ResizeData(data_len);
return true;
}
bool Buffer::SetData(const void *data, size_t data_len) {
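The dropped `data_len < 0` check above is tautological: data_len is a size_t, so it can never be negative (GCC flags such comparisons under -Wtype-limits), and the Clone path additionally stops rejecting zero-length tensors. A tiny illustrative snippet (the constant is a stand-in for MAX_MALLOC_SIZE):

#include <cstddef>
#include <cstdio>

int main() {
  size_t data_len = 0;
  if (data_len < 0) {   // always false for an unsigned type; this is the branch the commit removes
    printf("unreachable\n");
  }
  const size_t kMaxMallocSize = 1024 * 1024 * 1024;   // stand-in for MAX_MALLOC_SIZE
  if (data_len > kMaxMallocSize) {                    // the upper-bound check is the one that matters
    printf("data_len is error.\n");
    return 1;
  }
  return 0;
}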

View File

@ -82,6 +82,8 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_
return RET_ERROR;
}
CHECK_NULL_RETURN(src_tensor->data());
CHECK_NULL_RETURN(dst_tensor->data());
// need replace with increase data ref count
memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
return RET_OK;

View File

@ -32,7 +32,7 @@ void *ConvolutionBaseCPUKernel::MallocAlignedData(size_t alignment, size_t size)
MS_LOG(ERROR) << "MallocAlignedData failed!";
return nullptr;
}
auto aligned_ptr = (reinterpret_cast<uintptr_t>(ptr) + alignment - 1) & (~(alignment - 1));
uintptr_t aligned_ptr = (reinterpret_cast<uintptr_t>(ptr) + alignment - 1) & (~(alignment - 1));
addr_map[aligned_ptr] = ptr;
return reinterpret_cast<void *>(aligned_ptr);
}
@ -220,8 +220,7 @@ int ConvolutionBaseCPUKernel::SetIfPerChannel() {
}
int ConvolutionBaseCPUKernel::MallocQuantParam() {
conv_quant_arg_ = &conv_param_->conv_quant_arg_;
CHECK_NULL_RETURN(conv_quant_arg_);
conv_quant_arg_ = &(conv_param_->conv_quant_arg_);
auto input_tensor = in_tensors_.at(kInputIndex);
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto output_tensor = out_tensors_.at(kOutputIndex);
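The aligned_ptr change above only tightens the type to uintptr_t; the underlying technique is the usual over-allocate-and-round-up scheme, with a map from the aligned address back to the pointer that must eventually be freed. A standalone sketch of that scheme (illustrative, not the kernel code itself):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>

static std::map<uintptr_t, void *> addr_map;   // aligned address -> pointer that malloc returned

void *MallocAligned(size_t alignment, size_t size) {
  // alignment must be a power of two for the mask arithmetic below to hold
  void *ptr = malloc(size + alignment);
  if (ptr == nullptr) {
    return nullptr;
  }
  uintptr_t aligned =
      (reinterpret_cast<uintptr_t>(ptr) + alignment - 1) & ~(static_cast<uintptr_t>(alignment) - 1);
  addr_map[aligned] = ptr;
  return reinterpret_cast<void *>(aligned);
}

void FreeAligned(void *aligned_ptr) {
  auto it = addr_map.find(reinterpret_cast<uintptr_t>(aligned_ptr));
  if (it != addr_map.end()) {
    free(it->second);   // free the original pointer, not the rounded-up one
    addr_map.erase(it);
  }
}

int main() {
  void *buf = MallocAligned(64, 1024);
  printf("64-byte aligned: %s\n", reinterpret_cast<uintptr_t>(buf) % 64 == 0 ? "yes" : "no");
  FreeAligned(buf);
  return 0;
}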

View File

@ -17,13 +17,13 @@
#include "src/runtime/kernel/arm/base/group_convolution_creator.h"
namespace mindspore::kernel {
void CopyTensorQuantParam(lite::Tensor *dst, lite::Tensor *src) {
void CopyTensorQuantParam(lite::Tensor *dst, const lite::Tensor *src) {
for (size_t i = 0; i < src->quant_params().size(); i++) {
dst->AddQuantParam(src->quant_params().at(i));
}
}
ConvParameter *CreateNewConvParameter(ConvParameter *parameter) {
ConvParameter *CreateNewConvParameter(const ConvParameter *parameter) {
auto conv_parameter = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (conv_parameter == nullptr) {
MS_LOG(ERROR) << "Malloc new conv parameter failed.";
@ -37,6 +37,7 @@ void FreeCurrentConv(ConvParameter *conv_param, std::vector<lite::Tensor *> *new
std::vector<lite::Tensor *> *new_outputs) {
if (conv_param != nullptr) {
free(conv_param);
conv_param = nullptr;
}
if (new_inputs != nullptr) {
for (auto &in_tensor : *new_inputs) {
@ -61,7 +62,7 @@ static inline lite::Tensor *TensorMalloc(lite::Tensor *tensor) {
return tensor;
}
lite::Tensor *CreateConstTensor(lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
auto new_tensor =
new (std::nothrow) lite::Tensor(tensor->data_type(), shape, mindspore::NHWC, lite::Tensor::Category::CONST_TENSOR);
if (new_tensor == nullptr) {
@ -76,7 +77,7 @@ lite::Tensor *CreateConstTensor(lite::Tensor *tensor, const std::vector<int> &sh
}
uint8_t *new_tensor_data = reinterpret_cast<uint8_t *>(tensor->data()) + index * new_tensor->Size();
memcpy(new_tensor->data(), new_tensor_data, new_tensor->Size());
memcpy(new_tensor->data(), reinterpret_cast<void *>(new_tensor_data), new_tensor->Size());
return new_tensor;
}
@ -104,7 +105,7 @@ lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
}
/* Class GroupConv Creator Implement Part */
void GroupConvCreator::CopyQuantParam(std::vector<lite::Tensor *> *tensors) {
void GroupConvCreator::CopyQuantParam(const std::vector<lite::Tensor *> *tensors) {
for (size_t j = 0; j < origin_inputs_.size(); ++j) {
CopyTensorQuantParam(tensors->at(j), origin_inputs_.at(j));
}
@ -112,11 +113,13 @@ void GroupConvCreator::CopyQuantParam(std::vector<lite::Tensor *> *tensors) {
void GroupConvCreator::FreeGroupConvs() {
for (auto &sub_conv : group_convs_) {
for (auto &in_tensor : sub_conv->in_tensors()) {
for (auto in_tensor : sub_conv->in_tensors()) {
delete in_tensor;
in_tensor = nullptr;
}
for (auto &out_tensor : sub_conv->out_tensors()) {
for (auto out_tensor : sub_conv->out_tensors()) {
delete out_tensor;
out_tensor = nullptr;
}
delete sub_conv;
sub_conv = nullptr;

View File

@ -51,7 +51,7 @@ class GroupConvCreator {
void SetShapeOfTensors();
int CreateConvs(std::vector<kernel::InnerKernel *> *group_convs);
std::vector<kernel::InnerKernel *> *get_group_conv() { return &group_convs_; }
void CopyQuantParam(std::vector<lite::Tensor *> *tensors);
void CopyQuantParam(const std::vector<lite::Tensor *> *tensors);
int GetSingleConvParam(ConvParameter *conv_param, std::vector<lite::Tensor *> *new_inputs,
std::vector<lite::Tensor *> *new_outputs, int group_id);
@ -80,7 +80,7 @@ class GroupConvCreator {
const lite::InnerContext *ctx_ = nullptr;
};
ConvParameter *CreateNewConvParameter(ConvParameter *parameter);
ConvParameter *CreateNewConvParameter(const ConvParameter *parameter);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GROUP_CONVOLUTION_CREATOR_H_

View File

@ -50,12 +50,11 @@ int OneHotCPUKernel::Prepare() {
}
thread_num_ = op_parameter_->thread_num_;
auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
if (param == nullptr) {
if (one_hot_param_ == nullptr) {
MS_LOG(ERROR) << "OneHot op_parameter_ nullptr";
return RET_NULL_PTR;
}
axis_ = param->axis_;
axis_ = one_hot_param_->axis_;
if (!InferShapeDone()) {
return RET_OK;
@ -115,16 +114,15 @@ int OneHotCPUKernel::OneHotImpl(int task_id) {
if (output_data == nullptr) {
return RET_NULL_PTR;
}
auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);
if (output->data_type() == kNumberTypeFloat32) {
auto ret = OneHotToFp32(indices_data, on_value_, off_value_, reinterpret_cast<float *>(output_data), one_hot_param,
auto ret = OneHotToFp32(indices_data, on_value_, off_value_, reinterpret_cast<float *>(output_data), one_hot_param_,
task_id, thread_num_);
return ret;
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
} else if (output->data_type() == kNumberTypeFloat16) {
auto ret = OneHotToFp16(indices_data, on_value_, off_value_, reinterpret_cast<float16_t *>(output_data),
one_hot_param, task_id, thread_num_);
one_hot_param_, task_id, thread_num_);
return ret;
#endif
} else {
@ -134,12 +132,6 @@ int OneHotCPUKernel::OneHotImpl(int task_id) {
}
int OneHotCPUKernel::InitParamsAndOnOffValue() {
auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);
if (one_hot_param == nullptr) {
MS_LOG(ERROR) << "cast OneHotParameter nullptr";
return RET_NULL_PTR;
}
auto depth_tensor = in_tensors_.at(1);
if (depth_tensor == nullptr) {
MS_LOG(ERROR) << "OneHot inputs[1] depth nullptr";
@ -149,11 +141,11 @@ int OneHotCPUKernel::InitParamsAndOnOffValue() {
if (depth == nullptr) {
return RET_NULL_PTR;
}
one_hot_param->depth_ = *depth;
one_hot_param_->depth_ = *depth;
if (in_tensors_.size() == kInputNum) {
// 4 inputs: indices, depth, on_value, off_value
one_hot_param->support_neg_index_ = false;
one_hot_param_->support_neg_index_ = false;
auto ret = InitOnOffValueForFourInputs();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init on off value failed";
@ -161,7 +153,7 @@ int OneHotCPUKernel::InitParamsAndOnOffValue() {
}
} else {
// 3 inputs: indices, depth, off_on_value
one_hot_param->support_neg_index_ = true;
one_hot_param_->support_neg_index_ = true;
auto ret = InitOnOffValueForThreeInputs();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init on off value failed";
@ -169,8 +161,8 @@ int OneHotCPUKernel::InitParamsAndOnOffValue() {
}
}
one_hot_param->outer_size_ = outer_size_;
one_hot_param->inner_size_ = inner_size_;
one_hot_param_->outer_size_ = outer_size_;
one_hot_param_->inner_size_ = inner_size_;
return RET_OK;
}

View File

@ -18,13 +18,16 @@
#include <vector>
#include "src/inner_kernel.h"
#include "nnacl/one_hot_parameter.h"
namespace mindspore::kernel {
class OneHotCPUKernel : public InnerKernel {
public:
OneHotCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
: InnerKernel(parameter, inputs, outputs, ctx) {
one_hot_param_ = reinterpret_cast<OneHotParameter *>(parameter);
}
~OneHotCPUKernel() override = default;
@ -50,6 +53,7 @@ class OneHotCPUKernel : public InnerKernel {
float on_value_ = 0.;
float off_value_ = 0.;
#endif
OneHotParameter *one_hot_param_;
};
} // namespace mindspore::kernel
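Across the two OneHot files above, the refactor is: cast op_parameter_ to OneHotParameter once in the constructor, keep the typed pointer as a member, and validate it once in Prepare() instead of re-casting and re-checking in every method. A minimal sketch of that shape, with illustrative stand-in types rather than the real kernel classes:

#include <cstdio>

struct OpParameter { int thread_num_ = 1; };                            // stand-in for the real OpParameter
struct OneHotParameter { OpParameter op_parameter_; int axis_ = -1; };  // OpParameter must be the first member

class OneHotKernelSketch {
 public:
  explicit OneHotKernelSketch(OpParameter *parameter)
      : op_parameter_(parameter),
        one_hot_param_(reinterpret_cast<OneHotParameter *>(parameter)) {}   // cast once, here

  int Prepare() {
    if (op_parameter_ == nullptr || one_hot_param_ == nullptr) {   // validated once
      fprintf(stderr, "OneHot op_parameter_ nullptr\n");
      return -1;
    }
    axis_ = one_hot_param_->axis_;   // later methods reuse one_hot_param_ directly
    return 0;
  }

 private:
  OpParameter *op_parameter_ = nullptr;
  OneHotParameter *one_hot_param_ = nullptr;
  int axis_ = 0;
};

int main() {
  OneHotParameter param;
  OneHotKernelSketch kernel(&param.op_parameter_);
  return kernel.Prepare() == 0 ? 0 : 1;
}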

View File

@ -32,7 +32,7 @@ class PartialFusionKernel : public InnerKernel {
int ReSize() override;
int Run() override;
void set_subgraph_kernel(LiteKernel *subgraph_kernel) { subgraph_kernel_ = subgraph_kernel; }
LiteKernel *subgraph_kernel() { return subgraph_kernel_; }
LiteKernel *subgraph_kernel() const { return subgraph_kernel_; }
private:
LiteKernel *subgraph_kernel_ = nullptr;

View File

@ -208,12 +208,14 @@ int QuantDTypeCastCPUKernel::Run() {
if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
delete (float32_ptr_);
float32_ptr_ = nullptr;
}
return RET_ERROR;
}
if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
delete (float32_ptr_);
float32_ptr_ = nullptr;
}
return RET_OK;
}

View File

@ -103,6 +103,8 @@ int ReduceBaseCPUKernel::Prepare() {
MS_CHECK_FALSE_MSG(op_parameter_->thread_num_ == 0, RET_ERROR, "thread_num_ should not be 0");
if (in_tensors_.size() > 1) {
auto axes_tensor = in_tensors_.at(1);
MS_CHECK_FALSE_MSG((axes_tensor->data_type() != kNumberTypeInt && axes_tensor->data_type() != kNumberTypeInt32),
RET_ERROR, "The data type of axes tensor should be int32");
num_axes_ = axes_tensor->ElementsNum();
if (axes_tensor->ElementsNum() > MAX_SHAPE_SIZE) {
MS_LOG(ERROR) << "input axes invalid.";

View File

@ -47,7 +47,7 @@ class SplitWithOverlapBaseCPUKernel : public InnerKernel {
std::vector<int> end_indices_;
SplitWithOverlapParameter *param_ = nullptr;
int thread_count_;
int thread_count_ = 0;
char *input_ptr_{nullptr};
std::vector<char *> output_ptr_;

View File

@ -179,9 +179,6 @@ int StridedSliceCPUKernel::FastRun() {
CHECK_NULL_RETURN(input_ptr_);
output_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.front()->data());
CHECK_NULL_RETURN(output_ptr_);
if (input_ptr_ == nullptr || output_ptr_ == nullptr) {
return RET_NULL_PTR;
}
auto ret = ParallelLaunch(this->ms_context_, StrideRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Stride run error error_code[" << ret << "]";

View File

@ -39,7 +39,6 @@ int TileCPUKernel::Prepare() {
}
int TileCPUKernel::ReSize() {
tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
CHECK_NULL_RETURN(tile_parameter_);
if (in_tensors_.size() == kDoubleInputsSize) {
if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) {

View File

@ -25,7 +25,9 @@ class TileCPUKernel : public InnerKernel {
public:
TileCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
: InnerKernel(parameter, inputs, outputs, ctx) {
tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
}
~TileCPUKernel() override = default;
int Prepare() override;