diff --git a/include/api/callback/callback.h b/include/api/callback/callback.h index d10cffeb7c4..3332f8199be 100644 --- a/include/api/callback/callback.h +++ b/include/api/callback/callback.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "include/api/data_type.h" #include "include/api/dual_abi_helper.h" @@ -28,6 +29,8 @@ class Model; class ModelImpl; class CallbackImpl; +using GraphPoint = std::pair; + struct TrainCallBackData { TrainCallBackData(bool train_mode, int epoch, int step, Model *model): train_mode_(train_mode), epoch_(epoch), step_(step), model_(model) {} diff --git a/include/api/callback/loss_monitor.h b/include/api/callback/loss_monitor.h index 48684f3f1d4..9e0a8247e37 100644 --- a/include/api/callback/loss_monitor.h +++ b/include/api/callback/loss_monitor.h @@ -21,8 +21,6 @@ #include #include "include/api/callback/callback.h" -using GraphPoint = std::pair; - namespace mindspore { class LossMonitor: public TrainCallBack { diff --git a/include/api/callback/train_accuracy.h b/include/api/callback/train_accuracy.h index 0b31cfbc617..5838dfd9c1b 100644 --- a/include/api/callback/train_accuracy.h +++ b/include/api/callback/train_accuracy.h @@ -24,8 +24,6 @@ #include "include/api/callback/callback.h" #include "include/api/metrics/accuracy.h" -using GraphPoint = std::pair; - namespace mindspore { class TrainAccuracy: public TrainCallBack { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/gatherNd_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/gatherNd_fp32.h index cd92f0853b0..d5d591f164a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/gatherNd_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/gatherNd_fp32.h @@ -19,7 +19,7 @@ #include "nnacl/op_base.h" -typedef struct GatherNdParameter { +typedef struct { // Primitive parameter OpParameter op_parameter_; } GatherNdParameter; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.h index ce3581f8902..6ba6422db1f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.h @@ -18,7 +18,7 @@ #include "nnacl/op_base.h" -typedef struct BinaryCrossEntropyParameter { +typedef struct { OpParameter op_parameter_; int reduction; } BinaryCrossEntropyParameter; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.h index 57a7785a140..f3506f4f5ea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.h @@ -18,7 +18,7 @@ #include "nnacl/op_base.h" -typedef struct BinaryCrossEntropyGradParameter { +typedef struct { OpParameter op_parameter_; int reduction; } BinaryCrossEntropyGradParameter; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/dropout_parameter.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/dropout_parameter.h index 29b988c4c8c..9042c1cd60f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/dropout_parameter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/dropout_parameter.h @@ -19,7 +19,7 @@ #include "nnacl/op_base.h" -typedef struct 
DropoutParameter { +typedef struct { OpParameter op_parameter_; float ratio_; } DropoutParameter; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c index 8df87bc4bdb..2f07e1e47c2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c @@ -240,7 +240,7 @@ static void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size } for (; ri < row; ri++) { - for (int i = 0; i < col; i++) { + for (size_t i = 0; i < col; i++) { dst_r[i * C12NUM] = src_r[i]; } src_r += lead; @@ -457,7 +457,7 @@ static void RowMajor2Col8MajorStride(const float *src_ptr, float *dst_ptr, size_ dst_r += C8NUM * col; } for (; ri < row; ri++) { - for (int i = 0; i < col; i++) { + for (size_t i = 0; i < col; i++) { dst_r[i * C8NUM] = src_r[i]; } src_r += lead; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/layernormgrad_parameter.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/layernormgrad_parameter.h index c783ffd6c85..8d5661e67c1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/layernormgrad_parameter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/layernormgrad_parameter.h @@ -18,7 +18,7 @@ #include "nnacl/op_base.h" -typedef struct LayerNormGradParameter { +typedef struct { OpParameter op_parameter_; int begin_norm_axis_; int begin_params_axis_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/optimizer.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/optimizer.h index 91659d60071..871eb62ce56 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/optimizer.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/optimizer.h @@ -19,20 +19,20 @@ #include "nnacl/op_base.h" -typedef struct ApplyMomentumParameter { +typedef struct { OpParameter op_parameter_; bool use_nesterov_; float grad_scale_; } ApplyMomentumParameter; -typedef struct SgdParameter { +typedef struct { OpParameter op_parameter_; float dampening_; bool use_nesterov_; float weight_decay_; } SgdParameter; -typedef struct AdamParameter { +typedef struct { OpParameter op_parameter_; bool use_nesterov_; } AdamParameter; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/smooth_l1_loss.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/smooth_l1_loss.h index d7c340710e7..7f65b910e24 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/smooth_l1_loss.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/smooth_l1_loss.h @@ -19,7 +19,7 @@ #include "nnacl/op_base.h" -typedef struct SmoothL1LossParameter { +typedef struct { OpParameter op_parameter_; float beta_; } SmoothL1LossParameter; diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc index 0180eca5ae4..5df07d0556b 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.cc +++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc @@ -185,7 +185,7 @@ float NetRunner::CalculateAccuracy(int max_tests) { Rescaler rescale(kScalePoint); - loop_->Eval(test_ds_.get(), std::vector{&rescale}); + loop_->Eval(test_ds_.get(), std::vector{&rescale}, nullptr, INT_MAX); std::cout << "Accuracy is " << acc_metrics_->Eval() << std::endl; return 0.0; @@ -222,12 +222,13 @@ int NetRunner::TrainLoop() { Measurement 
measure(epochs_); if (virtual_batch_ > 0) { - loop_->Train(epochs_, train_ds_.get(), std::vector{&rescale, &lm, &cs, &am, &measure}); + loop_->Train(epochs_, train_ds_.get(), std::vector{&rescale, &lm, &cs, &am, &measure}, + nullptr); } else { struct mindspore::lite::StepLRLambda step_lr_lambda(1, kGammaFactor); mindspore::lite::LRScheduler step_lr_sched(mindspore::lite::StepLRLambda, static_cast(&step_lr_lambda), 1); loop_->Train(epochs_, train_ds_.get(), - std::vector{&rescale, &lm, &cs, &am, &step_lr_sched, &measure}); + std::vector{&rescale, &lm, &cs, &am, &step_lr_sched, &measure}, nullptr); } return 0; diff --git a/mindspore/lite/include/train/accuracy_monitor.h b/mindspore/lite/include/train/accuracy_monitor.h index 30f839d6732..9572975ed66 100644 --- a/mindspore/lite/include/train/accuracy_monitor.h +++ b/mindspore/lite/include/train/accuracy_monitor.h @@ -21,8 +21,6 @@ #include #include "include/train/train_loop.h" -using GraphPoint = std::pair; - namespace mindspore { namespace lite { diff --git a/mindspore/lite/include/train/ckpt_saver.h b/mindspore/lite/include/train/ckpt_saver.h index e33d5fa55eb..d4ae1806eac 100644 --- a/mindspore/lite/include/train/ckpt_saver.h +++ b/mindspore/lite/include/train/ckpt_saver.h @@ -15,24 +15,22 @@ */ #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ #define MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ -#include +#include #include #include #include #include #include "include/train/train_loop.h" -using GraphPoint = std::pair; - namespace mindspore { namespace lite { class CkptSaver : public session::TrainLoopCallBack { public: - CkptSaver(int save_every_n, const std::string &filename_prefix) - : save_every_n_(save_every_n), filename_prefix_(filename_prefix) {} + CkptSaver(size_t save_every_n, std::string filename_prefix) + : save_every_n_(save_every_n), filename_prefix_(std::move(filename_prefix)) {} - ~CkptSaver() = default; + ~CkptSaver() override = default; int EpochEnd(const session::TrainLoopCallBackData &cb_data) override { if ((cb_data.epoch_ + 1) % save_every_n_ == 0) { @@ -44,7 +42,7 @@ class CkptSaver : public session::TrainLoopCallBack { } private: - int save_every_n_; + size_t save_every_n_; std::string filename_prefix_; }; diff --git a/mindspore/lite/include/train/classification_train_accuracy_monitor.h b/mindspore/lite/include/train/classification_train_accuracy_monitor.h index 3df8af97ab7..acff9920065 100644 --- a/mindspore/lite/include/train/classification_train_accuracy_monitor.h +++ b/mindspore/lite/include/train/classification_train_accuracy_monitor.h @@ -24,8 +24,6 @@ #include "include/train/train_loop.h" #include "include/train/accuracy_metrics.h" -using GraphPoint = std::pair; - namespace mindspore { namespace lite { diff --git a/mindspore/lite/include/train/loss_monitor.h b/mindspore/lite/include/train/loss_monitor.h index a5adba81776..ffffe9d9040 100644 --- a/mindspore/lite/include/train/loss_monitor.h +++ b/mindspore/lite/include/train/loss_monitor.h @@ -22,8 +22,6 @@ #include #include "include/train/train_loop_callback.h" -using GraphPoint = std::pair; - namespace mindspore { namespace lite { diff --git a/mindspore/lite/include/train/train_cfg.h b/mindspore/lite/include/train/train_cfg.h index 1aa501729dd..11db4c3bcd0 100644 --- a/mindspore/lite/include/train/train_cfg.h +++ b/mindspore/lite/include/train/train_cfg.h @@ -36,13 +36,7 @@ class MixPrecisionCfg { this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_; this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_; } - MixPrecisionCfg &operator=(MixPrecisionCfg 
const &rhs) { - this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_; - this->loss_scale_ = rhs.loss_scale_; - this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_; - this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_; - return *this; - } + MixPrecisionCfg &operator=(MixPrecisionCfg const &rhs) = default; bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */ float loss_scale_; /**< Initial loss scale factor */ bool keep_batchnorm_fp32_ = true; /**< Keep batch norm in FP32 while training */ @@ -58,12 +52,7 @@ class TrainCfg { this->mix_precision_cfg_ = rhs.mix_precision_cfg_; this->accumulate_gradients_ = rhs.accumulate_gradients_; } - TrainCfg &operator=(const TrainCfg &rhs) { - this->loss_name_ = rhs.loss_name_; - this->mix_precision_cfg_ = rhs.mix_precision_cfg_; - this->accumulate_gradients_ = rhs.accumulate_gradients_; - return *this; - } + TrainCfg &operator=(const TrainCfg &rhs) = default; std::vector loss_name_ = {"loss_fct"}; /**< Set part of the name that identify a loss kernel */ MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */ bool accumulate_gradients_ = false; /**< If true gardents are accmulated and can be read by GetGradients */ diff --git a/mindspore/lite/include/train/train_loop.h b/mindspore/lite/include/train/train_loop.h index 2d6c13bec8c..8a2164ad261 100644 --- a/mindspore/lite/include/train/train_loop.h +++ b/mindspore/lite/include/train/train_loop.h @@ -18,7 +18,6 @@ #include #include #include -#include #include #include "include/train/train_loop_callback.h" #include "include/train/metrics.h" @@ -87,7 +86,7 @@ class TrainLoop { /// /// \return 0 on success or -1 in case of error virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector cbs, - LoadDataFunc load_func = nullptr) = 0; + LoadDataFunc load_func) = 0; /// \brief Performs loop over all data in Eval Mode /// @@ -97,8 +96,8 @@ class TrainLoop { /// \param[in] max_steps (with default = INT_MAX the method iterates all dataset) /// /// \return 0 on success or -1 in case of error - virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector cbs, - LoadDataFunc load_func = nullptr, int max_steps = INT_MAX) = 0; + virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector cbs, LoadDataFunc load_func, + int max_steps) = 0; }; } // namespace session } // namespace mindspore diff --git a/mindspore/lite/include/train/train_loop_callback.h b/mindspore/lite/include/train/train_loop_callback.h index 93ae265c519..87fa0490cc0 100644 --- a/mindspore/lite/include/train/train_loop_callback.h +++ b/mindspore/lite/include/train/train_loop_callback.h @@ -18,7 +18,9 @@ #include #include #include +#include #include +#include "include/api/callback/callback.h" namespace mindspore { namespace session { diff --git a/mindspore/lite/src/cxx_api/train/model.cc b/mindspore/lite/src/cxx_api/train/model.cc index 9e91a40c4ac..80d840f86ca 100644 --- a/mindspore/lite/src/cxx_api/train/model.cc +++ b/mindspore/lite/src/cxx_api/train/model.cc @@ -58,7 +58,7 @@ Status Model::Train(int epochs, std::shared_ptr ds, std::vecto return status; } - auto ret = loop->Train(epochs, ds.get(), cbs); + auto ret = loop->Train(epochs, ds.get(), cbs, nullptr); clearVectorOfPointers(&adapter_metrics); clearVectorOfPointers(&adapter_cbs); @@ -98,7 +98,7 @@ Status Model::Evaluate(std::shared_ptr ds, std::vectorEval(ds.get(), cbs); + auto ret = loop->Eval(ds.get(), cbs, nullptr, INT_MAX); clearVectorOfPointers(&adapter_metrics); 
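
Note on the TrainLoop::Train/Eval signature change above: the defaulted load_func and max_steps parameters are removed from the virtual declarations, and call sites (model.cc here, net_runner.cc earlier) now pass nullptr and INT_MAX explicitly. This is presumably to avoid default arguments on virtual functions, which C++ binds to the static type at the call site rather than to the dynamic type. A minimal, self-contained sketch (hypothetical Base/Derived classes, not MindSpore types):

#include <iostream>

struct Base {
  virtual ~Base() = default;
  // Default argument lives on the declaration seen at the call site.
  virtual void Run(int max_steps = 100) { std::cout << "Base, max_steps=" << max_steps << "\n"; }
};

struct Derived : Base {
  void Run(int max_steps = 5) override { std::cout << "Derived, max_steps=" << max_steps << "\n"; }
};

int main() {
  Derived d;
  Base *b = &d;
  b->Run();  // Derived's body runs, but with Base's default: prints max_steps=100
  d.Run();   // prints max_steps=5
  return 0;
}

Requiring callers to pass the arguments explicitly, as this patch does, sidesteps that mismatch entirely.
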
clearVectorOfPointers(&adapter_cbs); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/bn_fp16_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/bn_fp16_grad.cc index 264155b3229..9b923ea2460 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/bn_fp16_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/bn_fp16_grad.cc @@ -15,7 +15,7 @@ */ #include "src/runtime/kernel/arm/fp16_grad/bn_fp16_grad.h" -#include +#include #include #include #include diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc index fcb78bfb566..123cfd5b81b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc @@ -99,7 +99,6 @@ int AdamCPUKernel::Execute(int task_id) { } int AdamRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { - CHECK_NULL_RETURN(cdata); auto adam_kernel = reinterpret_cast(cdata); CHECK_NULL_RETURN(adam_kernel); auto error_code = RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc index be8e4a68791..5d66c57d10e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc @@ -85,13 +85,13 @@ int ArithmeticGradCPUKernel::Prepare() { int ArithmeticGradCPUKernel::ArithmeticGradAdd(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { if (dx1_size == dy_size) { - memcpy(dx1, dy, dy_size * sizeof(float)); + memcpy(dx1, dy, static_cast(dy_size) * sizeof(float)); } else { ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_, arithmeticParameter_->ndim_); } if (dx2_size == dy_size) { - memcpy(dx2, dy, dy_size * sizeof(float)); + memcpy(dx2, dy, static_cast(dy_size) * sizeof(float)); } else { ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx2, arithmeticParameter_->in_shape1_, arithmeticParameter_->ndim_); @@ -102,7 +102,7 @@ int ArithmeticGradCPUKernel::ArithmeticGradAdd(float *dy, int dy_size, float *dx int ArithmeticGradCPUKernel::ArithmeticGradSub(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { if (dx1_size == dy_size) { - memcpy(dx1, dy, dy_size * sizeof(float)); + memcpy(dx1, dy, static_cast(dy_size) * sizeof(float)); } else { ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_, arithmeticParameter_->ndim_); @@ -263,7 +263,6 @@ int ArithmeticGradCPUKernel::Execute(int task_id) { } int ArithmeticGradRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { - CHECK_NULL_RETURN(cdata); auto Arithmetic_kernel = reinterpret_cast(cdata); CHECK_NULL_RETURN(Arithmetic_kernel); auto error_code = Arithmetic_kernel->Execute(task_id); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h index 4debc87dfc3..5036a202003 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h @@ -39,8 +39,12 @@ class ArithmeticGradCPUKernel : public InnerKernel { public: explicit ArithmeticGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) - : InnerKernel(parameter, inputs, outputs, ctx), tile_data0(NULL), tile_data1(NULL), tile_data2(NULL) { - 
switch (type()) { + : InnerKernel(parameter, inputs, outputs, ctx), + arithmetic_grad_(nullptr), + tile_data0(nullptr), + tile_data1(nullptr), + tile_data2(nullptr) { + switch (parameter->type_) { case PrimitiveType_MulGrad: arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul; // this will be adjusted in InferShape break; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h index 94e5306d763..47805926a5c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h @@ -29,8 +29,8 @@ class ArithmeticSelfGradCPUKernel : public InnerKernel { public: ArithmeticSelfGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) - : InnerKernel(parameter, inputs, outputs, ctx), thread_count_(ctx->thread_num_) {} - ~ArithmeticSelfGradCPUKernel() override {} + : InnerKernel(parameter, inputs, outputs, ctx), thread_count_(ctx->thread_num_), self_grad_operation_(nullptr) {} + ~ArithmeticSelfGradCPUKernel() override = default; int Prepare() override; int ReSize() override; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc index 7959f2e8e0b..65d5aefda7b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc @@ -36,7 +36,7 @@ int AssignCPUKernel::Execute(int task_id) { CHECK_NULL_RETURN(y); int length = in_tensors_.at(0)->ElementsNum(); int stride = UP_DIV(length, thread_count_); - int count = MSMIN(stride, length - stride * task_id); + size_t count = MSMIN(stride, length - stride * task_id); int start = stride * task_id; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc index 9d034d5a9d2..98dbb664590 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc @@ -62,7 +62,7 @@ int BiasGradCPUKernel::Execute(int task_id) { size_t nhw_size = 1; size_t channels = bias_param->in_shape0_[bias_param->ndim_ - 1]; // C in NHWC for (unsigned int i = 0; i < bias_param->ndim_ - 1; i++) { - nhw_size *= bias_param->in_shape0_[i]; + nhw_size *= static_cast(bias_param->in_shape0_[i]); } size_t total_size = channels * nhw_size; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index c295f15caf8..71696cec12d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -15,7 +15,6 @@ */ #include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" -#include #include #include #include @@ -121,6 +120,9 @@ int BNGradCPUKernel::Execute(int task_id) { case 2: std::fill(dscale, dscale + channels, 0.f); break; + default: + MS_LOG(ERROR) << "Exceeds the maximum thread"; + return RET_ERROR; } } if (thread_num == 1) { @@ -137,6 +139,9 @@ int BNGradCPUKernel::Execute(int task_id) { scale, count, total, channels, dx + task_id * stride * channels); break; } + default: + MS_LOG(ERROR) << "Unsupported stage"; + return RET_ERROR; } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc 
b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc index 96db5b44450..3bf70532c6b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc @@ -54,22 +54,18 @@ int ConvolutionTrainCPUKernel::ReSize() { const int n = conv_param_->output_channel_ * conv_param_->group_; const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; - do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && - (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && - (conv_param_->dilation_h_ == 1) && (conv_param_->dilation_w_ == 1) && - (conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1) - ? false - : true; + do_img2col_ = !((conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && + (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && + (conv_param_->dilation_h_ == 1) && (conv_param_->dilation_w_ == 1) && (conv_param_->stride_h_ == 1) && + (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1)); do_dw_ = (conv_param_->output_channel_ == conv_param_->group_) && - (conv_param_->input_channel_ == conv_param_->output_channel_) && (conv_param_->dilation_h_ == 1) && - (conv_param_->dilation_w_ == 1) - ? true - : false; + (conv_param_->input_channel_ == conv_param_->output_channel_) && (conv_param_->dilation_h_ == 1) && + (conv_param_->dilation_w_ == 1); ws_size_ = chunk_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_; ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param_->group_; int mat_alloc = MatSizeTotal(chunk_, n, k, 0); - set_workspace_size((ws_size_ + mat_alloc) * sizeof(float)); + set_workspace_size(static_cast(ws_size_ + mat_alloc) * sizeof(float)); return RET_OK; } @@ -139,7 +135,7 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) { } } else { mat_b = w_addr; - const size_t in_plane_size = in_ch * in_h * in_w; + const int in_plane_size = in_ch * in_h * in_w; for (int i = 0; i < batch; ++i) { im = x_addr + i * in_plane_size; for (int ci = 0; ci < m; ci += chunk_) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc index 9caedc82624..2228687d965 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc @@ -61,26 +61,21 @@ int ConvolutionGradFilterCPUKernel::ReSize() { conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; NNACL_CHECK_ZERO_RETURN_ERR(conv_param->group_); - do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && - (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && - (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && - (conv_param->stride_w_ == 1) && (conv_param->group_ == 1) - ? 
false - : true; + do_img2col_ = !((conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && + (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && + (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && + (conv_param->stride_w_ == 1) && (conv_param->group_ == 1)); do_dw_ = (conv_param->output_channel_ == conv_param->group_) && - (conv_param->input_channel_ == conv_param->output_channel_) && (conv_param->dilation_h_ == 1) && - (conv_param->dilation_w_ == 1) - ? true - : false; + (conv_param->input_channel_ == conv_param->output_channel_) && (conv_param->dilation_h_ == 1) && + (conv_param->dilation_w_ == 1); ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; - ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param->group_; + ws_size_ = do_dw_ ? ws_size_ : ws_size_ / static_cast(conv_param->group_); int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; int k = conv_param->output_channel_ / conv_param->group_; - int thread_num = op_parameter_->thread_num_; + auto thread_num = static_cast(op_parameter_->thread_num_); mat_alloc_ = MatSizeTotal(k, n, chunk_, 0); - set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float)); - + set_workspace_size((ws_size_ + mat_alloc_ + static_cast(k * n)) * thread_num * sizeof(float)); return RET_OK; } @@ -97,27 +92,24 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { auto dy_addr = reinterpret_cast(input_dy->MutableData()); auto dw_addr = reinterpret_cast(out_dw->MutableData()); - int nweights = out_dw->ElementsNum(); int in_ch = conv_param->input_channel_; int in_h = conv_param->input_h_; int in_w = conv_param->input_w_; int k_h = conv_param->kernel_h_; int k_w = conv_param->kernel_w_; - int batch = conv_param->output_batch_; int out_ch = conv_param->output_channel_; int groups = conv_param->group_; - int out_h = conv_param->output_h_; - int out_w = conv_param->output_w_; - int m = out_h * out_w; + int m = conv_param->output_h_ * conv_param->output_w_; int n = k_h * k_w * in_ch / groups; int k = out_ch / groups; int thread_num = op_parameter_->thread_num_; - float *workspace_temp = reinterpret_cast(workspace()); - float *mat_workspace = workspace_temp + ws_size_ * thread_num + task_id * (mat_alloc_ + k * n); + auto *workspace_temp = reinterpret_cast(workspace()); + float *mat_workspace = + workspace_temp + static_cast(ws_size_) * thread_num + task_id * (static_cast(mat_alloc_) + k * n); float *mat_tmp = mat_workspace + mat_alloc_; - int stride = UP_DIV(batch, thread_num); - int count = MSMIN(stride, batch - stride * task_id); + int stride = UP_DIV(conv_param->output_batch_, thread_num); + int count = MSMIN(stride, conv_param->output_batch_ - stride * task_id); count = (count < 0) ? 
0 : count; int start = stride * task_id; int end = start + count; @@ -140,7 +132,7 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { end = start + count; const int kernel_spatial = k_h * k_w; - for (int i = 0; i < batch; ++i) { + for (int i = 0; i < conv_param->output_batch_; ++i) { for (int ci = 0; ci < m; ci += chunk_) { real_chunk = MSMIN(m - ci, chunk_); mat_b = workspace_temp + task_id * ws_size_; @@ -148,7 +140,7 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { RollingIm2ColPackDwUnitFp32(im, conv_param, mat_b, real_chunk, ci); for (int j = start; j < end; ++j) { mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; - mat_c = dw_addr + j * nweights / groups; + mat_c = dw_addr + j * out_dw->ElementsNum() / groups; GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b + (j * kernel_spatial), n * groups, 1, mat_c, n, mat_workspace); } @@ -161,8 +153,8 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { for (int j = 0; j < groups; ++j) { real_chunk = MSMIN(m - ci, chunk_); mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; - mat_b = workspace_temp + task_id * ws_size_; - mat_c = dw_addr + j * nweights / groups; + mat_b = workspace_temp + task_id * static_cast(ws_size_); + mat_c = dw_addr + j * out_dw->ElementsNum() / groups; im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups); RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci); GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 0, mat_tmp, n, mat_workspace); @@ -172,17 +164,16 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { } } } else { - NNACL_CHECK_ZERO_RETURN_ERR(out_w * conv_param->stride_h_); - NNACL_CHECK_ZERO_RETURN_ERR(out_w * conv_param->stride_w_); + NNACL_CHECK_ZERO_RETURN_ERR(conv_param->output_w_); mat_c = dw_addr; - const size_t in_plane_size = in_ch * in_h * in_w; + auto in_plane_size = in_ch * in_h * in_w; for (int i = start; i < end; ++i) { for (int ci = 0; ci < m; ci += chunk_) { real_chunk = MSMIN(m - ci, chunk_); mat_a = dy_addr + i * m * k + ci * out_ch; im = x_addr + i * in_plane_size; - int input_h = ci / out_w * conv_param->stride_h_; - int input_w = ci % out_w * conv_param->stride_w_; + int input_h = ci / conv_param->output_w_ * conv_param->stride_h_; + int input_w = ci % conv_param->output_w_ * conv_param->stride_w_; int offset = (input_h * in_w + input_w) * in_ch; GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, im + offset, n, 0, mat_tmp, n, mat_workspace); std::unique_lock merge_lock(lock_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc index 01012d53b21..1568dff8c92 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc @@ -54,7 +54,7 @@ int ConvolutionGradInputCPUKernel::ReSize() { int n = conv_param->kernel_w_ * conv_param->kernel_h_ * conv_param->input_channel_ / conv_param->group_; int k = conv_param->output_channel_ / conv_param->group_; - int thread_num = op_parameter_->thread_num_; + auto thread_num = static_cast(op_parameter_->thread_num_); mat_alloc_ = MatSizeTotal(chunk_, n, k, 0); set_workspace_size((ws_size_ + mat_alloc_) * sizeof(float) * thread_num); @@ -102,7 +102,8 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) { int m = out_h * out_w; int n = k_w * k_h * in_ch / groups; int k = out_ch / groups; - float 
*workspace_temp = reinterpret_cast(workspace()) + task_id * (mat_alloc_ + ws_size_); + float *workspace_temp = + reinterpret_cast(workspace()) + static_cast(task_id) * (mat_alloc_ + ws_size_); float *mat_workspace = workspace_temp + ws_size_; int stride = UP_DIV(batch, thread_num); int count = MSMIN(stride, batch - stride * task_id); @@ -169,10 +170,10 @@ int ConvolutionGradInputRun(void *cdata, int task_id, float lhs_scale, float rhs int ConvolutionGradInputCPUKernel::Run() { auto conv_param = reinterpret_cast(op_parameter_); - int batch = conv_param->output_batch_; - int in_ch = conv_param->input_channel_; - int in_h = conv_param->input_h_; - int in_w = conv_param->input_w_; + auto batch = static_cast(conv_param->output_batch_); + auto in_ch = static_cast(conv_param->input_channel_); + auto in_h = static_cast(conv_param->input_h_); + auto in_w = static_cast(conv_param->input_w_); auto *out_dx = out_tensors_.at(0); auto dx_addr = reinterpret_cast(out_dx->MutableData()); memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc index 1ce0e9c6fd9..f0696ce1af5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc @@ -80,7 +80,7 @@ int DropoutCPUKernel::Execute(int task_id) { std::bernoulli_distribution distribution(param->ratio_); for (int i = start; i < end; i++) { - mask[i] = distribution(generator); + mask[i] = static_cast(distribution(generator)); output_ptr[i] = input_ptr[i] * mask[i] * scale_; } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h index 812099124e4..14d459ea8cc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h @@ -34,7 +34,7 @@ class DropoutGradCPUKernel : public InnerKernel { int Execute(int task_id); private: - float scale_; + float scale_ = 1.0f; int thread_count_ = 1; }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.cc index 83df9af4664..d68333ff12e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.cc @@ -73,13 +73,13 @@ int ResizeGradCPUKernel::Execute(int task_id) { CHECK_NULL_RETURN(param); auto batch_size = in_tensors_.at(0)->Batch(); auto channel = in_tensors_.at(0)->Channel(); - int error_code = NNACL_OK; + int error_code; if (param->method == static_cast(schema::ResizeMethod_NEAREST)) { error_code = ResizeNearestNeighborGrad(in_addr, out_addr, batch_size, channel, in_tensors_.at(0)->format(), param); } else { error_code = ResizeBiLinearGrad(in_addr, out_addr, batch_size, channel, in_tensors_.at(0)->format(), param); } - if (error_code != NNACL_OK) { + if (error_code != static_cast(NNACL_OK)) { MS_LOG(ERROR) << "Resize fp32 grad failed."; return error_code; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.cc index e49b2254987..b4293276045 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.cc @@ -15,7 +15,6 @@ */ 
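
Several of the grad kernels touched in this patch (assign.cc, smooth_l1_loss.cc, and the convolution-grad kernels above) slice their work across threads with the same UP_DIV/MSMIN stride idiom, and a number of the edits here are signedness and clamping fixes around it. The sketch below is a standalone illustration of that idiom, with local helpers standing in for the nnacl macros (assumed to be ceiling division and minimum, not the actual MindSpore definitions); it shows why the last task's count must be clamped and why doing the subtraction in an unsigned type would wrap instead of going negative.

#include <algorithm>
#include <cstdio>

// Stand-in for UP_DIV: ceiling division.
static int UpDiv(int a, int b) { return (a + b - 1) / b; }

// Computes the [start, start + count) slice handled by one task.
// With length = 10 and 4 tasks, stride is 3 and task 3 gets only
// 10 - 3 * 3 = 1 element; a hypothetical 5th task would compute a
// negative count, which must be clamped to 0 (and would wrap to a
// huge value if the arithmetic were done in an unsigned type).
static void TaskSlice(int length, int thread_num, int task_id, int *start, int *count) {
  int stride = UpDiv(length, thread_num);
  *count = std::min(stride, length - stride * task_id);
  if (*count < 0) *count = 0;
  *start = stride * task_id;
}

int main() {
  for (int task_id = 0; task_id < 4; ++task_id) {
    int start = 0, count = 0;
    TaskSlice(10, 4, task_id, &start, &count);
    std::printf("task %d: start=%d count=%d\n", task_id, start, count);
  }
  return 0;
}
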
#include "src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.h" -#include #include "src/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits_grad.cc index e2c0d8b8cb5..822005ac81e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits_grad.cc @@ -15,7 +15,6 @@ */ #include "src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits_grad.h" -#include #include "src/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc index 877b42afed3..af6b77e8e8b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc @@ -30,7 +30,7 @@ constexpr static int kOutputIdx = 0; int SmoothL1LossCPUKernel::ReSize() { return RET_OK; } -int SmoothL1LossCPUKernel::Execute(int task_id) { +int SmoothL1LossCPUKernel::Execute(size_t task_id) { SmoothL1LossParameter *smooth_l1_loss_param = reinterpret_cast(op_parameter_); CHECK_NULL_RETURN(smooth_l1_loss_param); auto predict = reinterpret_cast(in_tensors_.at(kPredictIdx)->MutableData()); @@ -42,7 +42,7 @@ int SmoothL1LossCPUKernel::Execute(int task_id) { const size_t length = in_tensors_.at(kPredictIdx)->ElementsNum(); size_t stride = UP_DIV(length, thread_count_); - int count = MSMIN(stride, length - stride * task_id); + size_t count = MSMIN(stride, length - stride * task_id); size_t start = stride * task_id; size_t end = start + count; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h index d00c8c9aec6..ae6da391d8c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h @@ -33,11 +33,11 @@ class SmoothL1LossCPUKernel : public InnerKernel { int Prepare() override; int ReSize() override; int Run() override; - int Execute(int task_id); + int Execute(size_t task_id); private: SmoothL1LossParameter *smooth_l1_param_; - int thread_count_ = 1; + size_t thread_count_ = 1; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc index e05c8271bff..1ea783ed83a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc @@ -33,7 +33,7 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *lab float *output2) const { float eps = 1e-6; if (grads != nullptr) { - for (int i = 0; i < param_->batch_size_; ++i) { + for (size_t i = 0; i < static_cast(param_->batch_size_); ++i) { float loss = 0.f; for (size_t j = 0; j < param_->number_of_classes_; ++j) { float logit = @@ -45,7 +45,7 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *lab output2[i] = loss; } } else { - for (int i = 0; i < param_->batch_size_; ++i) { + for (size_t i = 0; i < static_cast(param_->batch_size_); ++i) { float loss = 
0.f; for (size_t j = 0; j < param_->number_of_classes_; ++j) { float logit = @@ -123,7 +123,7 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { } size_t data_size = in_tensors_.at(0)->ElementsNum(); - set_workspace_size((data_size + dims.at(0)) * sizeof(float)); + set_workspace_size((data_size + static_cast(dims.at(0))) * sizeof(float)); sm_params_.n_dim_ = 2; sm_params_.element_size_ = data_size; sm_params_.axis_ = 1; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h index 27259cf6ee4..c99e61bbd7a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h @@ -33,7 +33,7 @@ class SoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { : LossKernel(parameter, inputs, outputs, ctx) { param_ = reinterpret_cast(parameter); } - ~SoftmaxCrossEntropyWithLogitsCPUKernel() override {} + ~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default; void ForwardPostExecute(const float *labels, const float *logits, float *output1, float *output2) const; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc index df8c21fae2f..de62bc84e54 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc @@ -49,9 +49,9 @@ int SoftmaxGradCPUKernel::Prepare() { inner_size_ = 1; for (size_t i = axis + 1; i < in_dims; i++) { - inner_size_ *= in_shape.at(i); + inner_size_ *= static_cast(in_shape.at(i)); } - set_workspace_size(inner_size_ * (1 + in_shape.at(axis)) * sizeof(float)); + set_workspace_size(inner_size_ * (1 + static_cast(in_shape.at(axis))) * sizeof(float)); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h index 19f1469b095..ffaba9c8a22 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h @@ -37,7 +37,7 @@ class SoftmaxGradCPUKernel : public InnerKernel { private: SoftmaxParameter *param; - size_t inner_size_; + size_t inner_size_ = 0; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc index 9db82c3e727..b596f6053de 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc @@ -34,7 +34,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int * float *output) const { float total_loss = 0; MS_CHECK_GT(param->batch_size_, 0, RET_ERROR); - for (int i = 0; i < param->batch_size_; ++i) { + for (size_t i = 0; i < static_cast(param->batch_size_); ++i) { if (labels[i] < 0) { MS_LOG(ERROR) << "label value must >= 0"; return RET_ERROR; @@ -91,7 +91,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { float *losses = static_cast(workspace()); CHECK_NULL_RETURN(losses); float *sum_data = losses + data_size; - int length = sm_params_.input_shape_[sm_params_.axis_]; + int length = 
sm_params_->input_shape_[sm_params_->axis_]; int stride = UP_DIV(outter_size_, threads_); int count = MSMIN(stride, outter_size_ - stride * task_id); if (count <= 0) return RET_OK; @@ -108,6 +108,9 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { } else { return ForwardPostExecute(labels, losses, out); } + default: + MS_LOG(ERROR) << "Unsupported stage"; + return RET_ERROR; } return RET_OK; } @@ -125,9 +128,9 @@ int SparseSoftmaxCrossEntropyWithLogitsRun(void *cdata, int task_id, float lhs_s } int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { - int axis = sm_params_.axis_; - int n_dim = sm_params_.n_dim_; - const int *input_shape = sm_params_.input_shape_; + int axis = sm_params_->axis_; + int n_dim = sm_params_->n_dim_; + const int *input_shape = sm_params_->input_shape_; int inner_size = 1; int outter_size = 1; CHECK_NULL_RETURN(in_tensors_.at(0)); @@ -136,7 +139,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { CHECK_NULL_RETURN(losses); float *sum_data = losses + data_size; std::fill(losses, losses + data_size, 0.f); - std::fill(sum_data, sum_data + sm_params_.input_shape_[0], 0.f); + std::fill(sum_data, sum_data + sm_params_->input_shape_[0], 0.f); for (int i = 0; i < axis; i++) { outter_size *= input_shape[i]; } @@ -182,12 +185,17 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Prepare() { return RET_ERROR; } size_t data_size = in_tensors_.at(0)->ElementsNum(); - set_workspace_size((data_size + dims.at(0)) * sizeof(float)); - sm_params_.n_dim_ = 2; - sm_params_.element_size_ = static_cast(data_size); - sm_params_.axis_ = 1; + set_workspace_size((data_size + static_cast(dims.at(0))) * sizeof(float)); + sm_params_ = new (std::nothrow) SoftmaxParameter(); + if (sm_params_ == nullptr) { + MS_LOG(ERROR) << "new softmax param failed."; + return RET_ERROR; + } + sm_params_->n_dim_ = 2; + sm_params_->element_size_ = static_cast(data_size); + sm_params_->axis_ = 1; for (size_t i = 0; i < dims.size(); i++) { - sm_params_.input_shape_[i] = dims.at(i); + sm_params_->input_shape_[i] = dims.at(i); } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h index c91dc8f4300..3233634686e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h @@ -34,7 +34,12 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { : LossKernel(parameter, inputs, outputs, ctx) { param = reinterpret_cast(parameter); } - ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override {} + ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override { + if (sm_params_ != nullptr) { + delete sm_params_; + sm_params_ = nullptr; + } + } int ForwardPostExecute(const int *labels, const float *losses, float *output) const; int GradPostExecute(const int *labels, const float *losses, float *grads) const; @@ -46,11 +51,11 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { private: SoftmaxCrossEntropyParameter *param; - SoftmaxParameter sm_params_; + SoftmaxParameter *sm_params_ = nullptr; int inner_size_ = 1; int outter_size_ = 1; - int stage_; - int threads_; + int stage_ = 0; + int threads_ = 0; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc 
b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc index f16f603a984..5413873ac43 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc @@ -46,13 +46,13 @@ int UnsortedSegmentSumCPUKernel::Prepare() { for (size_t i = 0; i < input_shape.size(); ++i) { unit_num_ *= input_shape[i]; if (i >= segment_ids_shape.size()) { - input_dim1_ *= input_shape[i]; + input_dim1_ *= static_cast(input_shape[i]); } } output_dim0_ = output_shape[0]; output_dim1_ = 1; for (size_t j = 1; j < output_shape.size(); j++) { - output_dim1_ *= output_shape[j]; + output_dim1_ *= static_cast(output_shape[j]); } return RET_OK; } diff --git a/mindspore/lite/src/train/loss_kernel.h b/mindspore/lite/src/train/loss_kernel.h index 09dc5bc5549..fe375f469f8 100644 --- a/mindspore/lite/src/train/loss_kernel.h +++ b/mindspore/lite/src/train/loss_kernel.h @@ -25,7 +25,7 @@ class LossKernel : public InnerKernel { LossKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) : InnerKernel(parameter, inputs, outputs, ctx) {} - ~LossKernel() = default; + ~LossKernel() override = default; }; } // namespace mindspore::kernel
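
Several destructors in this patch become "override = default" (CkptSaver, ArithmeticSelfGradCPUKernel, SoftmaxCrossEntropyWithLogitsCPUKernel, and LossKernel just above). Besides dropping empty bodies, marking a derived destructor override makes the compiler verify that the base class destructor is virtual, turning a potential delete-through-base bug into a compile error. A minimal sketch with hypothetical class names, not MindSpore types:

#include <memory>

struct KernelBase {
  virtual ~KernelBase() = default;  // must be virtual for delete-through-base to be well defined
  virtual int Run() = 0;
};

struct MyKernel : KernelBase {
  // "override" asks the compiler to confirm this overrides a virtual base
  // destructor; if ~KernelBase() were not virtual, this line would fail to
  // compile instead of leaving delete via KernelBase* as undefined behavior.
  ~MyKernel() override = default;
  int Run() override { return 0; }
};

int main() {
  std::unique_ptr<KernelBase> k = std::make_unique<MyKernel>();
  return k->Run();
}
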