diff --git a/mindspore/lite/examples/runtime_extend/main.cc b/mindspore/lite/examples/runtime_extend/main.cc
index 9f558a16abb..1fd18ee98f7 100644
--- a/mindspore/lite/examples/runtime_extend/main.cc
+++ b/mindspore/lite/examples/runtime_extend/main.cc
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <random>
 #include
 #include
 #include "include/api/status.h"
@@ -28,14 +29,16 @@ namespace lite {
 namespace {
 constexpr int kNumPrintOfOutData = 20;
 Status FillInputData(const std::vector<mindspore::MSTensor> &inputs) {
+  std::mt19937 random_engine;
   for (auto tensor : inputs) {
     auto input_data = tensor.MutableData();
     if (input_data == nullptr) {
       std::cerr << "MallocData for inTensor failed.\n";
       return kLiteError;
     }
-    std::vector<float> temp(tensor.ElementNum(), 1.0f);
-    memcpy(input_data, temp.data(), tensor.DataSize());
+    auto distribution = std::uniform_real_distribution<float>(1.0f, 1.0f);
+    (void)std::generate_n(static_cast<float *>(input_data), tensor.ElementNum(),
+                          [&]() { return distribution(random_engine); });
   }
   return kSuccess;
 }
diff --git a/mindspore/lite/src/litert/kernel/cpu/base/convolution_base.cc b/mindspore/lite/src/litert/kernel/cpu/base/convolution_base.cc
index 367798a5322..cfd83bf76a6 100644
--- a/mindspore/lite/src/litert/kernel/cpu/base/convolution_base.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/base/convolution_base.cc
@@ -109,8 +109,12 @@ int ConvolutionBaseCPUKernel::Prepare() {
   auto input = this->in_tensors_.front();
   auto output = this->out_tensors_.front();
   CHECK_NULL_RETURN(input);
+  CHECK_NULL_RETURN(in_tensors_[1]);
   CHECK_NULL_RETURN(output);
   CHECK_NULL_RETURN(conv_param_);
+  MS_CHECK_TRUE_MSG(input->shape().size() == C4NUM, RET_ERROR, "Conv-like: input-shape should be 4D.");
+  MS_CHECK_TRUE_MSG(in_tensors_[1]->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
+  MS_CHECK_TRUE_MSG(output->shape().size() == C4NUM, RET_ERROR, "Conv-like: out-shape should be 4D.");
   conv_param_->input_batch_ = input->Batch();
   conv_param_->input_h_ = input->Height();
   conv_param_->input_w_ = input->Width();
@@ -494,7 +498,8 @@ void ConvolutionBaseCPUKernel::UpdateOriginWeightAndBias() {
   if (in_tensors_.at(kWeightIndex)->data() != nullptr) {
     origin_weight_ = in_tensors_.at(kWeightIndex)->data();
   }
-  if (in_tensors_.size() == kInputSize2 && in_tensors_.at(kBiasIndex)->data() != nullptr) {
+  if (in_tensors_.size() == kInputSize2 && in_tensors_.at(kBiasIndex) != nullptr &&
+      in_tensors_.at(kBiasIndex)->data() != nullptr) {
     origin_bias_ = in_tensors_.at(kBiasIndex)->data();
   }
 }
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/activation_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/activation_fp16.cc
index 8c40f156192..8ea39b2aefa 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/activation_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/activation_fp16.cc
@@ -38,6 +38,8 @@ namespace mindspore::kernel {
 int ActivationFp16CPUKernel::Prepare() {
   CHECK_LESS_RETURN(in_tensors_.size(), 1);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_[0]);
+  CHECK_NULL_RETURN(out_tensors_[0]);
   return RET_OK;
 }
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/addn_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/addn_fp16.cc
index 045b1ad4aab..cf6cf76de0d 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/addn_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/addn_fp16.cc
@@ -36,7 +36,14 @@ int AddNLaunch(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   }
 }  // namespace
 
-int AddNFp16CPUKernel::Prepare() { return RET_OK; }
+int AddNFp16CPUKernel::Prepare() {
+  CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_[0]);
+  CHECK_NULL_RETURN(in_tensors_[1]);
+  CHECK_NULL_RETURN(out_tensors_[0]);
+  return RET_OK;
+}
 
 int AddNFp16CPUKernel::ReSize() { return RET_OK; }
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/arithmetic_compare_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/arithmetic_compare_fp16.cc
index c881b5cca2c..040c3186541 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/arithmetic_compare_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/arithmetic_compare_fp16.cc
@@ -68,6 +68,9 @@ ArithmeticCompareOptFuncFp16 GetOptimizedArithmeticCompareFun(int primitive_type
 int ArithmeticCompareFP16CPUKernel::Prepare() {
   CHECK_LESS_RETURN(in_tensors_.size(), 2);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_[0]);
+  CHECK_NULL_RETURN(in_tensors_[1]);
+  CHECK_NULL_RETURN(out_tensors_[0]);
   if (!InferShapeDone()) {
     return RET_OK;
   }
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/biasadd_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/biasadd_fp16.cc
index 91de8d2e713..78bdbaed4f4 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/biasadd_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/biasadd_fp16.cc
@@ -118,6 +118,9 @@ int BiasAddCPUFp16Kernel::GetBiasData() {
 int BiasAddCPUFp16Kernel::Prepare() {
   CHECK_LESS_RETURN(in_tensors_.size(), 2);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_[0]);
+  CHECK_NULL_RETURN(in_tensors_[1]);
+  CHECK_NULL_RETURN(out_tensors_[0]);
   bias_tensor_ = in_tensors_.at(1);
   CHECK_NULL_RETURN(bias_tensor_);
   if (!InferShapeDone()) {
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/common_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/common_fp16.cc
index 31935c9c0e3..47ac4de3d0d 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/common_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/common_fp16.cc
@@ -22,10 +22,12 @@ using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext *ctx) {
+  MS_CHECK_TRUE_MSG(input != nullptr, nullptr, "input must be not a nullptr.");
   float16_t *fp16_data = nullptr;
   auto data_type = input->data_type();
   if (data_type == kNumberTypeFloat32) {
     auto ele_num = input->ElementsNum();
+    MS_CHECK_TRUE_MSG(ctx != nullptr, nullptr, "ctx must be not a nullptr.");
     fp16_data = reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
     if (fp16_data == nullptr) {
       MS_LOG(ERROR) << "malloc fp16_data failed.";
@@ -40,10 +42,12 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext
 }
 
 float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx) {
+  MS_CHECK_TRUE_MSG(output != nullptr, nullptr, "output must be not a nullptr.");
   float16_t *fp16_data = nullptr;
   auto data_type = output->data_type();
   if (data_type == kNumberTypeFloat32) {
     auto ele_num = output->ElementsNum();
+    MS_CHECK_TRUE_MSG(ctx != nullptr, nullptr, "ctx must be not a nullptr.");
     fp16_data = reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
     if (fp16_data == nullptr) {
       MS_LOG(ERROR) << "malloc fp16_data failed.";
@@ -56,9 +60,11 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx)
 }
 
 int ConvertFp32TensorToFp16(lite::Tensor *tensor, const lite::InnerContext *ctx) {
+  MS_CHECK_TRUE_MSG(tensor != nullptr, RET_ERROR, "ConvertFp32TensorToFp16 failed because the tensor is a nullptr.");
   if (tensor->data_type() == TypeId::kNumberTypeFloat16) {
     return RET_OK;
   }
+  MS_CHECK_TRUE_MSG(ctx != nullptr, RET_ERROR, "ConvertFp32TensorToFp16 failed because the ctx is a nullptr.");
   auto fp32_data = tensor->data();
   tensor->set_data(nullptr);
   tensor->set_data_type(TypeId::kNumberTypeFloat16);
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc
index 85e27e3f0d8..7e41540b3ce 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc
@@ -50,6 +50,7 @@ void ConvolutionDelegateFP16CPUKernel::FreeCopiedData() {
 }
 
 void *ConvolutionDelegateFP16CPUKernel::CopyData(const lite::Tensor *tensor) {
+  MS_CHECK_TRUE_MSG(tensor != nullptr, nullptr, "tensor must be not a nullptr.");
   auto data_type = tensor->data_type();
   if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16) {
     MS_LOG(ERROR) << "Not supported data type: " << data_type;
@@ -68,19 +69,21 @@ void *ConvolutionDelegateFP16CPUKernel::CopyData(const lite::Tensor *tensor) {
 int ConvolutionDelegateFP16CPUKernel::Prepare() {
   CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  auto weight_tensor = in_tensors_.at(kWeightIndex);
+  CHECK_NULL_RETURN(weight_tensor);
   if (!InferShapeDone()) {
-    auto weight_tensor = in_tensors_.at(kWeightIndex);
-    CHECK_NULL_RETURN(weight_tensor);
     origin_weight_ = weight_tensor->data() != nullptr ? CopyData(weight_tensor) : nullptr;
     need_free_ = need_free_ | WEIGHT_NEED_FREE;
     if (in_tensors_.size() == C3NUM) {
+      CHECK_NULL_RETURN(in_tensors_.at(kBiasIndex));
       origin_bias_ = CopyData(in_tensors_.at(kBiasIndex));
       need_free_ = need_free_ | BIAS_NEED_FREE;
     }
     return RET_OK;
   }
-  origin_weight_ = in_tensors_.at(kWeightIndex)->data();
+  origin_weight_ = weight_tensor->data();
   if (in_tensors_.size() == C3NUM) {
+    CHECK_NULL_RETURN(in_tensors_.at(kBiasIndex));
     origin_bias_ = in_tensors_.at(kBiasIndex)->data();
     MS_ASSERT(origin_bias_ != nullptr);
   }
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc
index 15479cb2cb7..435560d74cc 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc
@@ -68,6 +68,7 @@ int ConvolutionDepthwise3x3Fp16CPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int channel = weight_tensor->Batch();
     int c8 = UP_ROUND(channel, C8NUM);
     int pack_weight_size = c8 * C12NUM;
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc
index 5da5ae8f032..04312bb66e3 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc
@@ -65,6 +65,7 @@ int ConvolutionDepthwiseFp16CPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int channel = weight_tensor->Batch();
     int pack_weight_size = channel * weight_tensor->Height() * weight_tensor->Width();
     set_workspace_size(pack_weight_size * sizeof(float16_t));
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc
index d157966ebfa..bdeb2e574b6 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc
@@ -97,6 +97,7 @@ int ConvolutionDepthwiseSWFp16CPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
     int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
     set_workspace_size(pack_weight_size * sizeof(float16_t));
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc
index 842493ff0a6..cebd9ea740f 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc
@@ -104,6 +104,7 @@ int ConvolutionFP16CPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto filter_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(filter_tensor);
+    MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int in_channel = filter_tensor->Channel();
     int out_channel = filter_tensor->Batch();
     int oc8 = UP_ROUND(out_channel, col_tile_);
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.cc
index 3e132d7e40e..569c5a862bc 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.cc
@@ -175,6 +175,7 @@ int ConvolutionWinogradFP16CPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int in_channel = weight_tensor->Channel();
     int out_channel = weight_tensor->Batch();
     int oc_block_num = UP_DIV(out_channel, col_tile_);
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc
index 73e3b318837..b42ccdc216c 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc
@@ -110,6 +110,8 @@ int DeconvolutionDepthwiseFp16CPUKernel::Prepare() {
 
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
+    CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
     int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
     set_workspace_size(pack_weight_size * sizeof(float16_t));
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/activation_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/activation_fp32.cc
index f0ad3e6de77..04566c2c942 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/activation_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/activation_fp32.cc
@@ -36,6 +36,8 @@ namespace mindspore::kernel {
 int ActivationCPUKernel::Prepare() {
   CHECK_LESS_RETURN(in_tensors_.size(), 1);
   CHECK_NOT_EQUAL_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_[0]);
+  CHECK_NULL_RETURN(out_tensors_[0]);
 
   if (in_tensors().front()->data_type() == kNumberTypeInt32) {
     if (type_ != schema::ActivationType_RELU) {
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/affine_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/affine_fp32.cc
index ad67b18abeb..8f2d96b6e73 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/affine_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/affine_fp32.cc
@@ -61,9 +61,16 @@ bool AffineFp32CPUKernel::CheckAffineValid() {
   if (in_tensors_.size() < kAffineMinInputNum) {
     return false;
   }
+  if (std::any_of(in_tensors_.begin(), in_tensors_.end(), [](const auto &in_tensor) { return in_tensor == nullptr; })) {
+    return false;
+  }
   if (out_tensors_.size() != kAffineMaxOutputNum) {
     return false;
   }
+  if (std::any_of(out_tensors_.begin(), out_tensors_.end(),
+                  [](const auto &out_tensor) { return out_tensor == nullptr; })) {
+    return false;
+  }
   return true;
 }
 
@@ -153,6 +160,7 @@ int AffineFp32CPUKernel::FullRunInit() {
 int AffineFp32CPUKernel::IncrementInit() {
   auto out_tensor = out_tensors_.at(kOutputIndex);
   auto out_shape = out_tensor->shape();
+  MS_CHECK_TRUE_MSG(out_shape.size() >= C3NUM, RET_ERROR, "Out-shape is invalid, which must be 3D or bigger.");
   matmul_col_ = out_shape.at(kInputCol);
   matmul_row_ = out_shape.at(kInputRow);
   MS_CHECK_INT_MUL_NOT_OVERFLOW(matmul_row_, matmul_col_, RET_ERROR);
@@ -279,6 +287,7 @@ kernel::LiteKernel *AffineFp32CPUKernel::FullMatmulKernelCreate() {
 
 kernel::LiteKernel *AffineFp32CPUKernel::IncrementMatmulKernelCreate() {
   auto input_shape = in_tensors_.front()->shape();
+  MS_CHECK_TRUE_MSG(!input_shape.empty(), nullptr, "First input-shape is empty.");
   int src_col = input_shape.at(input_shape.size() - 1);
   int context_dims = affine_parameter_->context_size_;
   int affine_splice_output_col = affine_parameter_->output_dim_;
@@ -294,7 +303,9 @@ kernel::LiteKernel *AffineFp32CPUKernel::IncrementMatmulKernelCreate() {
   MS_CHECK_TRUE_MSG(increment_input_ != nullptr, nullptr, "Create a new-tensor failed.");
 
   // matmul_output == 1 * matmul_col
-  int matmul_col = out_tensors_.front()->shape().back();
+  auto out_shape = out_tensors_.front()->shape();
+  MS_CHECK_TRUE_MSG(!out_shape.empty(), nullptr, "Out-shape is empty.");
+  int matmul_col = out_shape.back();
   increment_output_ = new (std::nothrow) lite::Tensor(kNumberTypeFloat32, {1, 1, matmul_col});
   MS_CHECK_TRUE_MSG(increment_output_ != nullptr, nullptr, "Create a new-tensor failed.");
   increment_output_->MallocData();
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/bias_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/bias_fp32.cc
index 414efe33667..b45572f57d3 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/bias_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/bias_fp32.cc
@@ -41,6 +41,9 @@ int BiasAddRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
 int BiasCPUKernel::Prepare() {
   CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_[0]);
+  CHECK_NULL_RETURN(in_tensors_[1]);
+  CHECK_NULL_RETURN(out_tensors_[0]);
   if (!InferShapeDone()) {
     return RET_OK;
   }
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc
index a82ed767110..add7e698801 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc
@@ -16,6 +16,7 @@
 #include "src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h"
 #include "include/errorcode.h"
+#include "nnacl/op_base.h"
 #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
 
 using mindspore::lite::RET_ERROR;
@@ -31,6 +32,7 @@ int ConvolutionDepthwise3x3CPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int channel = weight_tensor->Batch();
     int c4 = UP_ROUND(channel, C4NUM);
     MS_CHECK_INT_MUL_NOT_OVERFLOW(c4, C12NUM, RET_ERROR);
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc
index 6a514b3ae38..ef2cc681095 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc
@@ -31,6 +31,7 @@ int ConvolutionDepthwiseCPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Height(), weight_tensor->Width(), RET_ERROR);
     int weight_size_hw = weight_tensor->Height() * weight_tensor->Width();
     MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Batch(), weight_size_hw, RET_ERROR);
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.cc
index 6f37d2f35b6..e18c5a34de6 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.cc
@@ -41,6 +41,7 @@ int ConvolutionDepthwiseIndirectCPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_[kWeightIndex];
     CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
 #ifdef ENABLE_AVX
     int div_flag = C8NUM;
 #else
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.cc
index 83d91733725..81e9b4b489b 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.cc
@@ -73,6 +73,8 @@ int ConvolutionDepthwiseSWCPUKernel::Prepare() {
   }
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
+    CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
     MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Height(), weight_tensor->Width(), RET_ERROR);
     int weight_size_hw = weight_tensor->Height() * weight_tensor->Width();
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.cc
index de9c334f149..255e547e679 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.cc
@@ -74,6 +74,8 @@ int ConvolutionDepthwiseSWCPUKernelX86::Prepare() {
 #endif
   if (op_parameter_->is_train_session_) {
     auto weight_tensor = in_tensors_.at(kWeightIndex);
+    CHECK_NULL_RETURN(weight_tensor);
+    MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int oc_algin = UP_DIV(weight_tensor->Batch(), oc_tile_);
     MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Height(), weight_tensor->Width(), RET_ERROR);
     int weight_size_hw = weight_tensor->Height() * weight_tensor->Width();
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc
index 7ae635d7a32..c6c28d3a081 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc
@@ -93,6 +93,7 @@ int ConvolutionCPUKernel::Prepare() {
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
   if (op_parameter_->is_train_session_) {
     auto filter_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(filter_tensor);
+    MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     size_t in_channel = filter_tensor->Channel();
     size_t out_channel = filter_tensor->Batch();
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc
index 7da82796230..d036b562a80 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc
@@ -44,6 +44,7 @@ int ConvolutionSWCPUKernel::Prepare() {
   if (op_parameter_->is_train_session_) {
     auto filter_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(filter_tensor);
+    MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     auto input_channel = filter_tensor->Channel();
     auto output_channel = filter_tensor->Batch();
     int kernel_h = filter_tensor->Height();
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc
index 55fa84f4b97..fce83c5e929 100644
--- a/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc
@@ -124,6 +124,7 @@ int ConvolutionWinogradBaseCPUKernel::Prepare() {
   conv_param_->output_unit_ = output_unit_;
   if (op_parameter_->is_train_session_) {
     auto filter_tensor = in_tensors_.at(kWeightIndex);
     CHECK_NULL_RETURN(filter_tensor);
+    MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
     int in_channel = filter_tensor->Channel();
     int out_channel = filter_tensor->Batch();