From 3afea5ce9de91e7f2149d22e95689e96cf22e2a7 Mon Sep 17 00:00:00 2001
From: yangruoqi713
Date: Mon, 7 Feb 2022 17:19:21 +0800
Subject: [PATCH] [MSLITE][DEVELOP] code review for lite: base op directory, include directory

---
 include/api/kernel.h                          |   1 -
 .../cpu/nnacl/base/depth_to_space_base.c      |   2 +-
 .../cpu/nnacl/base/zeroslike_base.h           |   2 +-
 mindspore/lite/include/context.h              |   2 +-
 .../lite/include/registry/register_kernel.h   |   4 +-
 .../registry/register_kernel_interface.h      |   4 +-
 .../lite/src/cxx_api/callback/callback_impl.h |  11 --
 mindspore/lite/src/cxx_api/context.cc         |   5 +
 mindspore/lite/src/cxx_api/converters.cc      |   8 +-
 mindspore/lite/src/cxx_api/graph/graph_data.h |   2 +-
 .../src/cxx_api/metrics/metrics_adapter.h     |  11 --
 mindspore/lite/src/cxx_api/serialization.cc   |   2 +-
 .../lite/src/cxx_api/tensor/tensor_impl.h     |   4 +-
 mindspore/lite/src/cxx_api/types.cc           |   4 +-
 .../lite/src/runtime/kernel/arm/base/call.h   |   2 +-
 .../src/runtime/kernel/arm/base/carry_data.cc | 139 ------------------
 .../src/runtime/kernel/arm/base/carry_data.h  |  46 ------
 .../kernel/arm/base/convolution_base.h        |   2 +-
 .../kernel/arm/base/group_convolution_base.cc |   5 +-
 .../kernel/arm/base/group_convolution_base.h  |   2 +-
 .../arm/base/group_convolution_creator.cc     |   5 +-
 .../arm/base/group_convolution_creator.h      |  10 +-
 .../runtime/kernel/arm/base/one_hot_base.cc   |   1 -
 .../runtime/kernel/arm/base/one_hot_base.h    |   1 -
 .../runtime/kernel/arm/base/partial_fusion.h  |   2 -
 .../runtime/kernel/arm/base/pooling_base.h    |   3 +-
 .../src/runtime/kernel/arm/base/prior_box.cc  |   5 +-
 .../runtime/kernel/arm/base/reduce_base.cc    |   2 +-
 .../src/runtime/kernel/arm/base/reduce_base.h |   1 -
 .../runtime/kernel/arm/base/reshape_base.h    |   1 -
 .../src/runtime/kernel/arm/base/select.cc     | 116 +++++++++++++++
 .../lite/src/runtime/kernel/arm/base/select.h |   5 +-
 .../src/runtime/kernel/arm/base/slice_base.h  |   2 -
 .../runtime/kernel/arm/base/softmax_base.h    |   3 +-
 .../src/runtime/kernel/arm/base/split_base.h  |   2 -
 .../arm/base/split_with_over_lap_base.cc      |   2 +-
 .../arm/base/split_with_over_lap_base.h       |   1 -
 .../src/runtime/kernel/arm/base/stack_base.cc |   5 +-
 .../src/runtime/kernel/arm/base/stack_base.h  |   2 +-
 .../arm/fp16/convolution_delegate_fp16.cc     |   2 +-
 .../arm/fp32/convolution_delegate_fp32.cc     |   2 +-
 .../arm/int8/convolution_int8_creator.cc      |   2 +-
 42 files changed, 162 insertions(+), 271 deletions(-)
 delete mode 100644 mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc
 delete mode 100644 mindspore/lite/src/runtime/kernel/arm/base/carry_data.h

diff --git a/include/api/kernel.h b/include/api/kernel.h
index c5461f3edd6..9b322711e06 100644
--- a/include/api/kernel.h
+++ b/include/api/kernel.h
@@ -139,7 +139,6 @@ class MS_API Kernel {
   /// \param[in] value define the kernel's attribute value.
   void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; }
-
  protected:
   std::string name_;
   const mindspore::Context *context_ = nullptr;
   std::vector<mindspore::MSTensor> inputs_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/depth_to_space_base.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/depth_to_space_base.c
index dd159efb3e7..e8821b8ba84 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/depth_to_space_base.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/depth_to_space_base.c
@@ -21,7 +21,7 @@ void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceParameter *param) {
   int32_t block_size = param->block_size_;
   int32_t in_shape_dim2 = in_shape[2];
   int32_t in_shape_dim1 = in_shape[1];
-  size_t copy_size = block_size * param->out_stride_dim2_ * param->data_type_size_;
+  size_t copy_size = (size_t)block_size * param->out_stride_dim2_ * param->data_type_size_;
   for (int i = 0; i < in_shape[0]; ++i) {
     int in_offset_n = i * param->in_stride_dim0_;
     int out_offset_n = i * param->out_stride_dim0_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/zeroslike_base.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/zeroslike_base.h
index 2374014d276..9ba0d904a42 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/zeroslike_base.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/zeroslike_base.h
@@ -23,7 +23,7 @@ extern "C" {
 #endif
 
 static inline void ApproximateZerosLike(void *output, int data_size) {
-  memset(output, 0.0, data_size);
+  (void)memset(output, 0, data_size);
   return;
 }
diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h
index c842e778a38..f1492cbd3da 100644
--- a/mindspore/lite/include/context.h
+++ b/mindspore/lite/include/context.h
@@ -46,7 +46,7 @@ typedef struct NpuDeviceInfo {
 
 /// \brief AscendDeviceInfo defined for Ascend's configuration information.
 typedef struct AscendDeviceInfo {
-  uint32_t device_id_;
+  uint32_t device_id_ = 0;
   std::string batch_size_;
   std::string image_size_;
 } AscendDeviceInfo;
diff --git a/mindspore/lite/include/registry/register_kernel.h b/mindspore/lite/include/registry/register_kernel.h
index 6c73317c80c..764675b8bfc 100644
--- a/mindspore/lite/include/registry/register_kernel.h
+++ b/mindspore/lite/include/registry/register_kernel.h
@@ -116,7 +116,7 @@ class MS_API KernelReg {
   /// \param[in] creator Define a function pointer to create a kernel.
   KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
             const CreateKernel creator) {
-    RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator);
+    (void)RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator);
   }
 
   /// \brief Method to register customized op.
@@ -128,7 +128,7 @@ class MS_API KernelReg {
   /// \param[in] creator Define a function pointer to create a kernel.
   KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type,
             const CreateKernel creator) {
-    RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator);
+    (void)RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator);
   }
 };
 
diff --git a/mindspore/lite/include/registry/register_kernel_interface.h b/mindspore/lite/include/registry/register_kernel_interface.h
index 93e02a45741..4d99de57fbd 100644
--- a/mindspore/lite/include/registry/register_kernel_interface.h
+++ b/mindspore/lite/include/registry/register_kernel_interface.h
@@ -83,7 +83,7 @@ class MS_API KernelInterfaceReg {
   /// \param[in] op_type Define the ordinary op type.
   /// \param[in] creator Define the KernelInterface create function.
   KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
-    RegisterKernelInterface::Reg(provider, op_type, creator);
+    (void)RegisterKernelInterface::Reg(provider, op_type, creator);
   }
 
   /// \brief Constructor of KernelInterfaceReg to register custom op.
@@ -92,7 +92,7 @@ class MS_API KernelInterfaceReg {
   /// \param[in] op_type Define the concrete type of a custom op.
   /// \param[in] creator Define the KernelInterface create function.
   KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator) {
-    RegisterKernelInterface::CustomReg(provider, op_type, creator);
+    (void)RegisterKernelInterface::CustomReg(provider, op_type, creator);
   }
 
   virtual ~KernelInterfaceReg() = default;
diff --git a/mindspore/lite/src/cxx_api/callback/callback_impl.h b/mindspore/lite/src/cxx_api/callback/callback_impl.h
index 6d293833146..7760d3f4fe8 100644
--- a/mindspore/lite/src/cxx_api/callback/callback_impl.h
+++ b/mindspore/lite/src/cxx_api/callback/callback_impl.h
@@ -17,17 +17,6 @@
 #ifndef MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_IMPL_H_
 #define MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_IMPL_H_
 
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "include/api/model.h"
-#include "include/api/context.h"
-#include "include/api/cell.h"
-#include "include/lite_session.h"
 #include "include/train/train_loop_callback.h"
 
 namespace mindspore {
diff --git a/mindspore/lite/src/cxx_api/context.cc b/mindspore/lite/src/cxx_api/context.cc
index 046d096efc4..7c42a3ac5e1 100644
--- a/mindspore/lite/src/cxx_api/context.cc
+++ b/mindspore/lite/src/cxx_api/context.cc
@@ -18,6 +18,7 @@
 #include
 #include "include/api/types.h"
 #include "include/api/data_type.h"
+#include "include/lite_types.h"
 #include "src/runtime/inner_allocator.h"
 #include "src/common/log_adapter.h"
 #include "src/delegate/tensorrt/distribution/distribution_base.h"
@@ -109,6 +110,10 @@ void Context::SetThreadAffinity(int mode) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
+  if (mode < lite::NO_BIND || mode > lite::MID_CPU) {
+    MS_LOG(WARNING) << "Invalid thread affinity mode: " << mode << ", change to NO_BIND mode.";
+    mode = lite::NO_BIND;
+  }
   data_->affinity_mode_ = mode;
   return;
 }
diff --git a/mindspore/lite/src/cxx_api/converters.cc b/mindspore/lite/src/cxx_api/converters.cc
index 2b79e284a74..fb4c802541f 100644
--- a/mindspore/lite/src/cxx_api/converters.cc
+++ b/mindspore/lite/src/cxx_api/converters.cc
@@ -36,7 +36,7 @@ Status ContextUtils::AddCpuDevice(const std::shared_ptr<Allocator> &allocator, int affinity_mode,
     MS_LOG(ERROR) << "Invalid affinity mode, only supports 0:no affinities, 1:big cores first, 2:little cores first.";
     return kLiteInputParamInvalid;
   }
-  lite::DeviceInfo device_info = {0};
+  lite::DeviceInfo device_info;
   device_info.cpu_device_info_ = {enable_fp16, static_cast<lite::CpuBindMode>(affinity_mode)};
   inner_context->device_list_.push_back({lite::DT_CPU, device_info, provider, provider_device, allocator});
   return kSuccess;
@@ -46,7 +46,7 @@ Status ContextUtils::AddGpuDevice(bool enable_fp16, uint32_t device_id, int rank
                                   bool enable_gl_texture, void *gl_context, void *gl_display,
                                   const std::string &provider, const std::string &provider_device,
                                   const std::shared_ptr<Allocator> &allocator, lite::InnerContext *inner_context) {
-  lite::DeviceInfo device_info = {0};
+  lite::DeviceInfo device_info;
   device_info.gpu_device_info_ = {enable_fp16,       device_id,  rank_id,   group_size,
                                   enable_gl_texture, gl_context, gl_display};
   inner_context->device_list_.push_back({lite::DT_GPU, device_info, provider, provider_device, allocator});
@@ -54,14 +54,14 @@
 }
 
 Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_context) {
-  lite::DeviceInfo device_info = {0};
+  lite::DeviceInfo device_info;
   device_info.npu_device_info_ = {frequency};
   inner_context->device_list_.push_back({lite::DT_NPU, device_info});
   return kSuccess;
 }
 
 Status ContextUtils::AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device) {
-  lite::DeviceInfo device_info = {0};
+  lite::DeviceInfo device_info;
   auto ascend_context = device->Cast<AscendDeviceInfo>();
   device_info.ascend_device_info_ = {ascend_context->GetDeviceID(), ascend_context->GetDynamicBatchSize(),
                                      ascend_context->GetDynamicImageSize()};
diff --git a/mindspore/lite/src/cxx_api/graph/graph_data.h b/mindspore/lite/src/cxx_api/graph/graph_data.h
index 26603ed6eb2..d83d0cfd0dc 100644
--- a/mindspore/lite/src/cxx_api/graph/graph_data.h
+++ b/mindspore/lite/src/cxx_api/graph/graph_data.h
@@ -36,7 +36,7 @@ class Graph::GraphData {
 
   std::shared_ptr<lite::Model> lite_model() { return lite_model_; }
 
-  bool IsTrainModel() { return true; }
+  bool IsTrainModel() const { return true; }
 
  private:
   std::shared_ptr<lite::Model> lite_model_ = nullptr;
diff --git a/mindspore/lite/src/cxx_api/metrics/metrics_adapter.h b/mindspore/lite/src/cxx_api/metrics/metrics_adapter.h
index c8b66c22173..56cca00db2a 100644
--- a/mindspore/lite/src/cxx_api/metrics/metrics_adapter.h
+++ b/mindspore/lite/src/cxx_api/metrics/metrics_adapter.h
@@ -17,18 +17,7 @@
 #ifndef MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_ADAPTER_H_
 #define MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_ADAPTER_H_
 
-#include
-#include
-#include
 #include
-#include
-#include
-#include
-#include "include/api/model.h"
-#include "include/api/context.h"
-#include "include/api/cell.h"
-#include "include/api/metrics/metrics.h"
-#include "include/lite_session.h"
 #include "include/train/metrics.h"
 
 namespace mindspore {
diff --git a/mindspore/lite/src/cxx_api/serialization.cc b/mindspore/lite/src/cxx_api/serialization.cc
index f839a8ee65d..44798c04182 100644
--- a/mindspore/lite/src/cxx_api/serialization.cc
+++ b/mindspore/lite/src/cxx_api/serialization.cc
@@ -34,7 +34,7 @@ Key::Key(const char *dec_key, size_t key_len) {
     return;
   }
 
-  memcpy(key, dec_key, key_len);
+  (void)memcpy(key, dec_key, key_len);
   len = key_len;
 }
diff --git a/mindspore/lite/src/cxx_api/tensor/tensor_impl.h b/mindspore/lite/src/cxx_api/tensor/tensor_impl.h
index 3151a583e35..ef704011a4c 100644
--- a/mindspore/lite/src/cxx_api/tensor/tensor_impl.h
+++ b/mindspore/lite/src/cxx_api/tensor/tensor_impl.h
@@ -69,7 +69,7 @@ class MSTensor::Impl {
 #endif
 
   virtual const std::string &Name() const {
-    static std::string empty = "";
+    static const std::string empty = "";
     if (lite_tensor_ == nullptr) {
       MS_LOG(ERROR) << "Invalid tensor.";
       return empty;
@@ -111,7 +111,7 @@ class MSTensor::Impl {
   }
 
   virtual const std::vector<int64_t> &Shape() const {
-    static std::vector<int64_t> empty;
+    static std::vector<int64_t> empty{};
     if (lite_tensor_ == nullptr) {
       MS_LOG(ERROR) << "Invalid tensor.";
       return empty;
diff --git a/mindspore/lite/src/cxx_api/types.cc b/mindspore/lite/src/cxx_api/types.cc
index bb271873292..982eb20daf8 100644
--- a/mindspore/lite/src/cxx_api/types.cc
+++ b/mindspore/lite/src/cxx_api/types.cc
@@ -15,8 +15,8 @@
  */
 
 #include "include/api/types.h"
-#include
-#include
+#include
+#include
 #include
 #include "include/api/status.h"
 #include "include/api/dual_abi_helper.h"
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/call.h b/mindspore/lite/src/runtime/kernel/arm/base/call.h
index b8b137cf85a..eb382e21e1f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/call.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/call.h
@@ -17,8 +17,8 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CALL_H_
 
 #include
-#include "src/runtime/kernel/arm/base/carry_data.h"
 #include "src/tensor.h"
+#include "src/inner_kernel.h"
 #ifndef CONTROLFLOW_TENSORLIST_CLIP
 #include "src/tensorlist.h"
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc
deleted file mode 100644
index 347751c0912..00000000000
--- a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/runtime/kernel/arm/base/carry_data.h"
-#include "include/errorcode.h"
-
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_NOT_SUPPORT;
-using mindspore::lite::RET_OK;
-
-namespace mindspore::kernel {
-int CarryDataKernel::MoveData(const std::vector<lite::Tensor *>::iterator &dst_begin,
-                              const std::vector<lite::Tensor *>::iterator &dst_end,
-                              const std::vector<lite::Tensor *>::iterator &src_begin,
-                              const std::vector<lite::Tensor *>::iterator &src_limit) {
-  for (auto dst_iter = dst_begin, src_iter = src_begin; dst_iter != dst_end; dst_iter++, src_iter++) {
-    if (src_iter == src_limit) {
-      MS_LOG(ERROR) << "out of range of input tensor";
-      return RET_ERROR;
-    }
-    auto *dst_tensor = *dst_iter;
-    auto *src_tensor = *src_iter;
-    if (dst_tensor == nullptr || src_tensor == nullptr) {
-      MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr";
-      return RET_ERROR;
-    }
-    lite::STATUS ret = RET_OK;
-    if (src_tensor->IsConst() || src_tensor->IsGraphInput()) {
-      MS_LOG(DEBUG) << "Carry const data and graph inputs.";
-      dst_tensor->set_data(src_tensor->data());
-      dst_tensor->set_own_data(false);
-    } else {
-      if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
-#ifndef CONTROLFLOW_TENSORLIST_CLIP
-        MS_LOG(DEBUG) << "Carry MoveTensorListData";
-        ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
-                                 reinterpret_cast<lite::TensorList *>(src_tensor));
-#else
-        MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
-        return RET_NOT_SUPPORT;
-#endif
-      } else {
-        MS_LOG(DEBUG) << "Carry MoveTensorData";
-        ret = MoveTensorData(dst_tensor, src_tensor);
-      }
-    }
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Move data failed : " << ret;
-      return ret;
-    }
-  }
-  return RET_OK;
-}
-
-int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
-  if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
-      !(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
-    MS_LOG(ERROR) << "input tensor and output tensor is incompatible.";
-    MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
-                  << "output tensor data_type: " << dst_tensor->data_type()
-                  << "input tensor format: " << src_tensor->format() << " vs "
-                  << "output tensor format: " << dst_tensor->format() << " input tensor shape: " << src_tensor->shape()
-                  << " vs "
-                  << "output tensor shape: " << dst_tensor->shape();
-    return RET_ERROR;
-  }
-  if (src_tensor->allocator() == nullptr) {
-    MS_LOG(ERROR) << "src_tensor allocator is nullptr.";
-    return RET_ERROR;
-  }
-
-  CHECK_NULL_RETURN(src_tensor->data());
-  CHECK_NULL_RETURN(dst_tensor->data());
-  // need replace with increase data ref count
-  MS_CHECK_FALSE(src_tensor->Size() == 0, RET_ERROR);
-  memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
-  return RET_OK;
-}
-
-#ifndef CONTROLFLOW_TENSORLIST_CLIP
-int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist) {
-  // shape may change, because tensors.size() can be change in RunGraph
-  if (dst_tensorlist->data_type() != src_tensorlist->data_type() ||
-      dst_tensorlist->format() != src_tensorlist->format()) {
-    MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
-    MS_LOG(ERROR) << "input tensor data_type: " << src_tensorlist->data_type() << " vs "
-                  << "output tensor data_type: " << dst_tensorlist->data_type()
-                  << "input tensor format: " << src_tensorlist->format() << " vs "
-                  << "output tensor format: " << dst_tensorlist->format();
-    return RET_ERROR;
-  }
-  // when tensorlist malloc is done. this need to check element_shape compatibility
-  dst_tensorlist->set_element_shape(src_tensorlist->element_shape());
-
-  auto update_data_type = kTypeUnknown;
-  auto dst_tensor_data_type = dst_tensorlist->tensors_data_type();
-  auto src_tensor_data_type = src_tensorlist->tensors_data_type();
-  if (dst_tensor_data_type != src_tensor_data_type) {
-    if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) {
-      MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible";
-      return RET_ERROR;
-    }
-    update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type;
-  }
-  if (update_data_type != kTypeUnknown) {
-    src_tensorlist->set_tensors_data_type(update_data_type);
-    dst_tensorlist->set_tensors_data_type(update_data_type);
-  }
-  size_t src_tensorlist_tensors_size = src_tensorlist->tensors().size();
-  for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
-    auto &src_tensor = src_tensorlist->tensors()[i];
-    auto &dst_tensor = dst_tensorlist->tensors()[i];
-
-    if (src_tensor->allocator() != nullptr) {
-      src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
-    }
-    dst_tensor->set_own_data(src_tensor->own_data());
-    if (src_tensor->data() != nullptr) {
-      dst_tensor->set_data(src_tensor->data());
-    }
-    dst_tensor->set_shape(src_tensor->shape());
-  }
-  return RET_OK;
-}
-#endif
-}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h
deleted file mode 100644
index 638d340fee9..00000000000
--- a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CARRY_DATA_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CARRY_DATA_H_
-
-#include
-#include "src/inner_kernel.h"
-#include "src/tensor.h"
-#ifndef CONTROLFLOW_TENSORLIST_CLIP
-#include "src/tensorlist.h"
-#endif
-
-namespace mindspore::kernel {
-class CarryDataKernel : public InnerKernel {
- public:
-  CarryDataKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                  const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
-      : InnerKernel(parameter, inputs, outputs, ctx) {}
-  ~CarryDataKernel() override = default;
-
- protected:
-  int MoveData(const std::vector<lite::Tensor *>::iterator &dst_begin,
-               const std::vector<lite::Tensor *>::iterator &dst_end,
-               const std::vector<lite::Tensor *>::iterator &src_begin,
-               const std::vector<lite::Tensor *>::iterator &src_limit);
-  int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
-#ifndef CONTROLFLOW_TENSORLIST_CLIP
-  int MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist);
-#endif
-};
-}  // namespace mindspore::kernel
-
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CARRY_DATA_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
index b26ba724865..b0b66290391 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
@@ -79,7 +79,7 @@ class ConvolutionBaseCPUKernel : public InnerKernel {
   virtual int MallocWeightBiasData() { return RET_OK; }
   virtual void PackWeight() {}
-  bool IsRepack() { return is_repack_; }
+  bool IsRepack() const { return is_repack_; }
   std::unordered_map addr_map;
   void *packed_weight_ = nullptr;
 #ifdef SERVER_INFERENCE
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.cc
index ca6d823ed72..b6f687bbfbe 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.cc
@@ -55,7 +55,7 @@ int GroupConvolutionBaseCPUKernel::ReSize() {
   return RET_OK;
 }
 
-void GroupConvolutionBaseCPUKernel::FreeSubKernel() {
+GroupConvolutionBaseCPUKernel::~GroupConvolutionBaseCPUKernel() {
   for (auto &sub_conv : group_convs_) {
     // free sub conv input tensors / output tensors manually
     auto sub_in_tensors = sub_conv->in_tensors();
@@ -102,7 +102,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
         sub_kernel_in_tensor->set_shape(in_shape);
         ret = sub_kernel_in_tensor->MallocData();
         if (ret != RET_OK) {
-          FreeSubKernel();
           MS_LOG(ERROR) << "sub kernel in tensor malloc data failed.";
           return ret;
         }
@@ -116,7 +115,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
         tensor->set_shape(out_shape);
         ret = tensor->MallocData();
         if (ret != RET_OK) {
-          FreeSubKernel();
           MS_LOG(ERROR) << "sub kernel out tensor malloc data failed.";
           return ret;
         }
@@ -134,7 +132,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
     CHECK_NULL_RETURN(output);
     auto ret = output->MallocData();
     if (ret != RET_OK) {
-      FreeSubKernel();
       MS_LOG(ERROR) << "group conv out tensor malloc data failed.";
       return ret;
     }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.h
index e119963d3f1..04f53a9bea6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_base.h
@@ -35,7 +35,7 @@ class GroupConvolutionBaseCPUKernel : public ConvolutionBaseCPUKernel {
         group_conv_creator_(group_conv_creator),
         group_num_(group_num) {}
   // opParameter(in channel, out channel) in this kernel has been split to groups, if
   // you want to get real params, multiply in channel / out channel with group num
-  ~GroupConvolutionBaseCPUKernel() override { FreeSubKernel(); }
+  ~GroupConvolutionBaseCPUKernel() override;
 
   int Prepare() override;
   int ReSize() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.cc
index 46c20022145..417ce8deb8f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.cc
@@ -81,7 +81,8 @@ lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
-  void *new_tensor_data = reinterpret_cast<uint8_t *>(tensor->data()) + index * new_tensor->Size();
+  void *new_tensor_data =
+    reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor->data()) + index * new_tensor->Size());
   memcpy(new_tensor->data(), reinterpret_cast<uint8_t *>(new_tensor_data), new_tensor->Size());
   return new_tensor;
 }
@@ -141,7 +142,7 @@ int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {
   return lite::RET_OK;
 }
 
-int GroupConvCreator::NewOutputTensor(std::vector<lite::Tensor *> *tensors, lite::Tensor *output) {
+int GroupConvCreator::NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const {
   auto out_tensor = CreateVarTensor({output_shape_, output->format(), data_type_, output->category(), false}, infered_);
   if (out_tensor == nullptr) {
     return lite::RET_ERROR;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.h b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.h
index 03d12b0c8ec..976ba5f43e3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/group_convolution_creator.h
@@ -35,12 +35,11 @@ struct TensorInfo {
 class GroupConvCreator {
  public:
   GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter,
-                   const lite::InnerContext *ctx, bool is_quant, TypeId data_type)
+                   bool is_quant, TypeId data_type)
       : origin_inputs_(std::move(inputs)),
        origin_outputs_(std::move(outputs)),
         is_quant_(is_quant),
-        data_type_(data_type),
-        ctx_(ctx) {
+        data_type_(data_type) {
     auto shape = origin_outputs_.front()->shape();
     infered_ = std::find(shape.begin(), shape.end(), -1) == shape.end();
     conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter);
   }
 
   ~GroupConvCreator() = default;
 
-  public:
   void SetShapeOfTensors();
   int CreateConvs(std::vector<kernel::InnerKernel *> *group_convs);
-  std::vector<kernel::InnerKernel *> *get_group_conv() { return &group_convs_; }
   void CopyQuantParam(const std::vector<lite::Tensor *> *tensors);
   int GetSingleConvParam(ConvParameter *conv_param, std::vector<lite::Tensor *> *new_inputs,
                          std::vector<lite::Tensor *> *new_outputs, int group_id);
   void FreeGroupConvs();
   int NewInputTensor(std::vector<lite::Tensor *> *tensors);
   int NewConstTensor(std::vector<lite::Tensor *> *tensors, int group_id);
-  int NewOutputTensor(std::vector<lite::Tensor *> *tensors, lite::Tensor *output);
+  int NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const;
 
  private:
   std::vector<lite::Tensor *> origin_inputs_;
   bool infered_ = false;
   bool is_quant_ = false;
   TypeId data_type_;
-  const lite::InnerContext *ctx_ = nullptr;
 };
 
 ConvParameter *CreateNewConvParameter(const ConvParameter *parameter);
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.cc
index 5d293a93e47..6ada7977462 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.cc
@@ -87,7 +87,6 @@ int OneHotCPUKernel::ReSize() {
 }
 
 int RunOneHot(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
-  CHECK_NULL_RETURN(cdata);
   auto onehot_kernel = reinterpret_cast<OneHotCPUKernel *>(cdata);
   if (onehot_kernel == nullptr) {
     MS_LOG(ERROR) << "cast OneHotCPUKernel failed";
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.h b/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.h
index 1ebec711039..6ee73affed6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/one_hot_base.h
@@ -41,7 +41,6 @@ class OneHotCPUKernel : public InnerKernel {
   int InitOnOffValueForThreeInputs();
   int InitOnOffValueForFourInputs();
 
- private:
   int thread_num_ = 1;
   int axis_ = 0;
   int outer_size_ = 0;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.h b/mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.h
index 6bfc59ff8bc..8fcd86e28b8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.h
@@ -33,8 +33,6 @@ class PartialFusionKernel : public InnerKernel {
   int Run() override;
   void set_subgraph_kernels(const std::vector<LiteKernel *> &subgraph_kernels) { subgraph_kernels_ = subgraph_kernels; }
   std::vector<LiteKernel *> subgraph_kernels() const { return subgraph_kernels_; }
-
- private:
   // One partial corresponds to a subgraph at offline stage, after graph schedule, a subgraph may be split into many
   // graphs, so use a vector.
   std::vector<LiteKernel *> subgraph_kernels_{};
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h
index e4a3b2485de..706733ee4c1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h
@@ -30,7 +30,7 @@ class PoolingBaseCPUKernel : public InnerKernel {
  public:
   PoolingBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                        const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx)
-      : InnerKernel(parameter, inputs, outputs, ctx), ctx_(ctx), thread_count_(ctx->thread_num_) {
+      : InnerKernel(parameter, inputs, outputs, ctx), thread_count_(ctx->thread_num_) {
     pooling_param_ = reinterpret_cast<PoolingParameter *>(op_parameter_);
   }
   ~PoolingBaseCPUKernel() = default;
@@ -42,7 +42,6 @@ class PoolingBaseCPUKernel : public InnerKernel {
   void FreeQuantParam();
 
  protected:
-  const InnerContext *ctx_;
   int thread_count_;
   PoolingParameter *pooling_param_ = nullptr;
   QuantArg **pooling_quant_arg_ = nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc
index d25afddfd7a..06e33c5a49e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc
@@ -159,12 +159,11 @@ int PriorBoxCPUKernel::GeneratePriorBox() {
 
 int PriorBoxCPUKernel::PriorBoxImpl(int task_id) {
   auto src = output_.data();
-  MS_CHECK_TRUE_RET(src != nullptr, RET_NULL_PTR);
+  CHECK_NULL_RETURN(src);
   auto output = out_tensors_.at(0);
   CHECK_NULL_RETURN(output);
-  MS_CHECK_TRUE_RET(output != nullptr, RET_NULL_PTR);
   auto output_data = reinterpret_cast<float *>(output->data());
-  MS_CHECK_TRUE_RET(output_data != nullptr, RET_NULL_PTR);
+  CHECK_NULL_RETURN(output_data);
   auto ret = PriorBox(src, output_data, output_.size(), task_id, thread_count_);
   return ret;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc
index 33ff1b08bf4..9d09724c6c1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc
@@ -106,7 +106,7 @@ int ReduceBaseCPUKernel::Prepare() {
   MS_CHECK_FALSE_MSG((axes_tensor->data_type() != kNumberTypeInt && axes_tensor->data_type() != kNumberTypeInt32),
                      RET_ERROR, "The data type of axes tensor should be int32");
   num_axes_ = axes_tensor->ElementsNum();
-  if (num_axes_ <= 0 && num_axes_ > MAX_SHAPE_SIZE) {
+  if (num_axes_ <= 0 || num_axes_ > MAX_SHAPE_SIZE) {
     MS_LOG(ERROR) << "input axes invalid.";
     return RET_ERROR;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h
index e5e3be9b8f5..c25e566d8b0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h
@@ -19,7 +19,6 @@
 #include
 #include "src/inner_kernel.h"
-
 #include "nnacl/reduce_parameter.h"
 
 namespace mindspore::kernel {
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h
index eb65d110439..2660ec09c8b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h
@@ -20,7 +20,6 @@
 #include "src/inner_kernel.h"
 #include "include/context.h"
 #include "include/errorcode.h"
-#include "src/runtime/kernel/arm/base/carry_data.h"
 
 using mindspore::lite::InnerContext;
 
 namespace mindspore::kernel {
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/select.cc b/mindspore/lite/src/runtime/kernel/arm/base/select.cc
index 48171d33515..f91a2f10e03 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/select.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/select.cc
@@ -22,6 +22,7 @@
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NOT_SUPPORT;
 using mindspore::lite::RET_NULL_PTR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Select;
@@ -34,6 +35,121 @@ int SelectCPUKernel::Prepare() { return RET_OK; }
 
 int SelectCPUKernel::ReSize() { return RET_OK; }
 
+int MoveTensorData(lite::Tensor *dst_tensor, const lite::Tensor *src_tensor) {
+  if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
+      !(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
+    MS_LOG(ERROR) << "input tensor and output tensor is incompatible.";
+    MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
+                  << "output tensor data_type: " << dst_tensor->data_type()
+                  << "input tensor format: " << src_tensor->format() << " vs "
+                  << "output tensor format: " << dst_tensor->format() << " input tensor shape: " << src_tensor->shape()
+                  << " vs "
+                  << "output tensor shape: " << dst_tensor->shape();
+    return RET_ERROR;
+  }
+  if (src_tensor->allocator() == nullptr) {
+    MS_LOG(ERROR) << "src_tensor allocator is nullptr.";
+    return RET_ERROR;
+  }
+
+  CHECK_NULL_RETURN(src_tensor->data());
+  CHECK_NULL_RETURN(dst_tensor->data());
+  // need replace with increase data ref count
+  MS_CHECK_FALSE(src_tensor->Size() == 0, RET_ERROR);
+  (void)memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
+  return RET_OK;
+}
+
+#ifndef CONTROLFLOW_TENSORLIST_CLIP
+int MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist) {
+  // shape may change, because tensors.size() can be change in RunGraph
+  if (dst_tensorlist->data_type() != src_tensorlist->data_type() ||
+      dst_tensorlist->format() != src_tensorlist->format()) {
+    MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
+    MS_LOG(ERROR) << "input tensor data_type: " << src_tensorlist->data_type() << " vs "
+                  << "output tensor data_type: " << dst_tensorlist->data_type()
+                  << "input tensor format: " << src_tensorlist->format() << " vs "
+                  << "output tensor format: " << dst_tensorlist->format();
+    return RET_ERROR;
+  }
+  // when tensorlist malloc is done. this need to check element_shape compatibility
+  dst_tensorlist->set_element_shape(src_tensorlist->element_shape());
+
+  auto update_data_type = kTypeUnknown;
+  auto dst_tensor_data_type = dst_tensorlist->tensors_data_type();
+  auto src_tensor_data_type = src_tensorlist->tensors_data_type();
+  if (dst_tensor_data_type != src_tensor_data_type) {
+    if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) {
+      MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible";
+      return RET_ERROR;
+    }
+    update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type;
+  }
+  if (update_data_type != kTypeUnknown) {
+    src_tensorlist->set_tensors_data_type(update_data_type);
+    dst_tensorlist->set_tensors_data_type(update_data_type);
+  }
+  size_t src_tensorlist_tensors_size = src_tensorlist->tensors().size();
+  for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
+    auto &src_tensor = src_tensorlist->tensors()[i];
+    auto &dst_tensor = dst_tensorlist->tensors()[i];
+
+    if (src_tensor->allocator() != nullptr) {
+      src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
+    }
+    dst_tensor->set_own_data(src_tensor->own_data());
+    if (src_tensor->data() != nullptr) {
+      dst_tensor->set_data(src_tensor->data());
+    }
+    dst_tensor->set_shape(src_tensor->shape());
+  }
+  return RET_OK;
+}
+#endif
+
+int MoveData(const std::vector<lite::Tensor *>::iterator &dst_begin,
+             const std::vector<lite::Tensor *>::iterator &dst_end,
+             const std::vector<lite::Tensor *>::iterator &src_begin,
+             const std::vector<lite::Tensor *>::iterator &src_limit) {
+  for (auto dst_iter = dst_begin, src_iter = src_begin; dst_iter != dst_end; dst_iter++, src_iter++) {
+    if (src_iter == src_limit) {
+      MS_LOG(ERROR) << "out of range of input tensor";
+      return RET_ERROR;
+    }
+    auto *dst_tensor = *dst_iter;
+    auto *src_tensor = *src_iter;
+    if (dst_tensor == nullptr || src_tensor == nullptr) {
+      MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr";
+      return RET_ERROR;
+    }
+    lite::STATUS ret = RET_OK;
+    if (src_tensor->IsConst() || src_tensor->IsGraphInput()) {
+      MS_LOG(DEBUG) << "Carry const data and graph inputs.";
+      dst_tensor->set_data(src_tensor->data());
+      dst_tensor->set_own_data(false);
+    } else {
+      if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
+#ifndef CONTROLFLOW_TENSORLIST_CLIP
+        MS_LOG(DEBUG) << "Carry MoveTensorListData";
+        ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
+                                 reinterpret_cast<lite::TensorList *>(src_tensor));
+#else
+        MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
+        return RET_NOT_SUPPORT;
+#endif
+      } else {
+        MS_LOG(DEBUG) << "Carry MoveTensorData";
+        ret = MoveTensorData(dst_tensor, src_tensor);
+      }
+    }
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Move data failed : " << ret;
+      return ret;
+    }
+  }
+  return RET_OK;
+}
+
 // inputs: bool*1 true-data*n false-data*n
 // output: data*n
 int SelectCPUKernel::Run() {
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/select.h b/mindspore/lite/src/runtime/kernel/arm/base/select.h
index eefc8c37717..5b3e96f8641 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/select.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/select.h
@@ -17,18 +17,17 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SELECT_H_
 
 #include
-#include "src/runtime/kernel/arm/base/carry_data.h"
 #include "src/inner_kernel.h"
 #ifndef CONTROLFLOW_TENSORLIST_CLIP
 #include "src/tensorlist.h"
 #endif
 
 namespace mindspore::kernel {
-class SelectCPUKernel : public CarryDataKernel {
+class SelectCPUKernel : public InnerKernel {
  public:
   SelectCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
-      : CarryDataKernel(parameter, inputs, outputs, ctx) {}
+      : InnerKernel(parameter, inputs, outputs, ctx) {}
   ~SelectCPUKernel() override = default;
   int Prepare() override;
   int ReSize() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.h b/mindspore/lite/src/runtime/kernel/arm/base/slice_base.h
index 312b0a7eac8..eddbf95aa4f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/slice_base.h
@@ -33,8 +33,6 @@ class SliceCPUKernel : public InnerKernel {
   int Prepare() override;
   int ReSize() override;
   int Run() override;
-
-  public:
   int SliceParallelRun(int thread_id);
 
  protected:
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h
index f8ac5a22bde..2649236066b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h
@@ -26,7 +26,7 @@ class SoftmaxBaseCPUKernel : public InnerKernel {
  public:
   SoftmaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                        const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
-      : InnerKernel(parameter, inputs, outputs, ctx), ctx_(ctx), thread_count_(ctx->thread_num_) {
+      : InnerKernel(parameter, inputs, outputs, ctx), thread_count_(ctx->thread_num_) {
     softmax_param_ = reinterpret_cast<SoftmaxParameter *>(op_parameter_);
   }
   ~SoftmaxBaseCPUKernel() = default;
@@ -36,7 +36,6 @@ class SoftmaxBaseCPUKernel : public InnerKernel {
   int Run() override { return 0; }
 
  protected:
-  const lite::InnerContext *ctx_;
   int thread_count_;
   SoftmaxParameter *softmax_param_;
 };
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.h b/mindspore/lite/src/runtime/kernel/arm/base/split_base.h
index 6778a65607b..4f67f0f0ea4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/split_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.h
@@ -41,8 +41,6 @@ class SplitBaseCPUKernel : public InnerKernel {
   int Prepare() override;
   int ReSize() override;
   int Run() override;
-
-  public:
   virtual int Split(int task_id);
   static int CheckAndInitSplitParam(const lite::Tensor &in_tensor, SplitParameter *param);
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc
index 32cacdef9f5..90829c47efa 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.cc
@@ -25,9 +25,9 @@ using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_NULL_PTR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_SplitWithOverlap;
-#define MIN_NUM_SPLIT 2
 
 namespace mindspore::kernel {
+const int MIN_NUM_SPLIT = 2;
 
 int SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const std::vector<int> &shape) {
   int total_block_count = 0;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h
index 819a7659900..7272a10a7bc 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/split_with_over_lap_base.h
@@ -41,7 +41,6 @@ class SplitWithOverlapBaseCPUKernel : public InnerKernel {
  private:
   int CalculateSplitedShapes(const std::vector<int> &shape);
 
- private:
   // range: [start, end)
   std::vector<int> start_indices_;
   std::vector<int> end_indices_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc
index 494470ba71e..5421a3cd49d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc
@@ -68,7 +68,6 @@ int StackBaseCPUKernel::ReSize() {
     copy_size_ = GetCopyNum(input0_shape, axis_, input0_shape.size()) * data_type_size_;
     outer_size_ = GetOuterSize(input0_shape, axis_);
   }
-  MS_CHECK_GT(copy_size_, 0, RET_ERROR);
   return RET_OK;
 }
@@ -83,7 +82,7 @@ int StackBaseCPUKernel::Prepare() {
   return ReSize();
 }
 
-int StackBaseCPUKernel::Execute(int task_id) {
+int StackBaseCPUKernel::StackExecute(int task_id) {
   auto output_data = reinterpret_cast(out_tensors_.at(0)->data());
   if (output_data == nullptr) {
     return RET_NULL_PTR;
@@ -103,7 +102,7 @@ static int StackRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   CHECK_NULL_RETURN(cdata);
   auto stack = reinterpret_cast<StackBaseCPUKernel *>(cdata);
-  if (stack->Execute(task_id) != RET_OK) {
+  if (stack->StackExecute(task_id) != RET_OK) {
     return RET_ERROR;
   }
   return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/stack_base.h b/mindspore/lite/src/runtime/kernel/arm/base/stack_base.h
index 967b2e4d6cd..29c5c7ed17a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/stack_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/stack_base.h
@@ -34,7 +34,7 @@ class StackBaseCPUKernel : public InnerKernel {
   int Prepare() override;
   int ReSize() override;
   int Run() override;
-  int Execute(int task_id);
+  int StackExecute(int task_id);
 
  protected:
   StackParameter *stack_param_ = nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
index 225ae3a34ce..9053aefe0d5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
@@ -202,7 +202,7 @@ kernel::InnerKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                    const std::vector<lite::Tensor *> &outputs,
                                                    OpParameter *op_parameter, const InnerContext *ctx) {
   auto *group_conv_creator =
-    new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, ctx, false, kNumberTypeFloat16);
+    new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16);
   if (group_conv_creator == nullptr) {
     MS_LOG(ERROR) << "new GroupConvCreator fail";
     free(op_parameter);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
index e47732467a3..fe341e94aeb 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
@@ -289,7 +289,7 @@ kernel::InnerKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
 kernel::InnerKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                    const std::vector<lite::Tensor *> &outputs,
                                                    OpParameter *op_parameter, const lite::InnerContext *ctx) {
-  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, ctx, false, kNumberTypeFloat32);
+  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32);
   auto group_kernel = new (std::nothrow) GroupConvolutionFp32CPUKernel(
     op_parameter, inputs, outputs, ctx, group_conv_creator, reinterpret_cast<ConvParameter *>(op_parameter)->group_);
   if (group_kernel == nullptr) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8_creator.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8_creator.cc
index f1b8e546d88..42c01a50ac8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8_creator.cc
@@ -108,7 +108,7 @@ kernel::InnerKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                   << conv_param->input_channel_;
     return nullptr;
   }
-  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, ctx, true, kNumberTypeInt8);
+  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8);
   return new (std::nothrow)
     GroupConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, group_conv_creator, group);
 }
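
Note on the CarryData removal: MoveData/MoveTensorData/MoveTensorListData now live in select.cc as free functions, and the ownership rule they implement is easy to miss in the diff. Constant tensors and graph inputs are carried by aliasing the source buffer and clearing ownership on the destination; everything else is deep-copied. Below is a minimal standalone sketch of that rule; SimpleTensor is a hypothetical stand-in for lite::Tensor, not part of this patch.

#include <cstddef>
#include <cstring>

// Reduced stand-in for lite::Tensor: only the fields the move rule touches.
struct SimpleTensor {
  void *data = nullptr;
  std::size_t size = 0;
  bool own_data = false;  // whether this tensor frees `data` when destroyed
  bool is_const = false;  // constant weights / graph inputs are never copied
};

// Mirrors the branch in select.cc's MoveData: alias when the source buffer is
// stable for the whole run (const / graph input), deep-copy otherwise.
int MoveData(SimpleTensor *dst, const SimpleTensor *src) {
  if (src->is_const) {
    dst->data = src->data;  // carry by aliasing, no copy
    dst->own_data = false;  // dst must not free a buffer it does not own
    return 0;
  }
  if (dst->data == nullptr || src->data == nullptr || src->size == 0) {
    return -1;  // mirrors the CHECK_NULL_RETURN / MS_CHECK_FALSE guards
  }
  (void)std::memcpy(dst->data, src->data, src->size);  // dst keeps its own buffer
  return 0;
}

The (void) casts on memset/memcpy and on the registration calls elsewhere in this patch follow the same convention: the return value is deliberately discarded, and the cast records that decision for static checkers.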