!43342 [lite]codex-bug fix

Merge pull request !43342 from 徐安越/master6
i-robot 2022-11-08 01:51:27 +00:00 committed by Gitee
commit 5c797a41d9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
25 changed files with 74 additions and 8 deletions

View File

@ -17,6 +17,7 @@
#include <iostream>
#include <cstring>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "include/api/status.h"
@ -28,14 +29,16 @@ namespace lite {
namespace {
constexpr int kNumPrintOfOutData = 20;
Status FillInputData(const std::vector<mindspore::MSTensor> &inputs) {
std::mt19937 random_engine;
for (auto tensor : inputs) {
auto input_data = tensor.MutableData();
if (input_data == nullptr) {
std::cerr << "MallocData for inTensor failed.\n";
return kLiteError;
}
std::vector<float> temp(tensor.ElementNum(), 1.0f);
memcpy(input_data, temp.data(), tensor.DataSize());
auto distribution = std::uniform_real_distribution<float>(1.0f, 1.0f);
(void)std::generate_n(static_cast<float *>(input_data), tensor.ElementNum(),
[&]() { return distribution(random_engine); });
}
return kSuccess;
}
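
For reference, a minimal standalone sketch of the fill pattern used above: std::generate_n driven by a std::mt19937 engine and a std::uniform_real_distribution. It is written against plain std::vector<float> buffers so it compiles without the MindSpore Lite headers; the (0.0f, 1.0f) range and the buffer sizes are chosen purely for illustration.

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

// Fill every buffer with values drawn from a fixed-seed PRNG, mirroring the
// generate_n + uniform_real_distribution pattern in FillInputData above.
void FillBuffersWithRandomFloats(std::vector<std::vector<float>> *buffers) {
  std::mt19937 random_engine;  // default seed, so results are reproducible
  std::uniform_real_distribution<float> distribution(0.0f, 1.0f);  // illustrative range
  for (auto &buffer : *buffers) {
    (void)std::generate_n(buffer.data(), buffer.size(),
                          [&]() { return distribution(random_engine); });
  }
}

int main() {
  std::vector<std::vector<float>> inputs{std::vector<float>(4), std::vector<float>(6)};
  FillBuffersWithRandomFloats(&inputs);
  for (const auto &buffer : inputs) {
    for (float value : buffer) {
      std::cout << value << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}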

View File

@ -109,8 +109,12 @@ int ConvolutionBaseCPUKernel::Prepare() {
auto input = this->in_tensors_.front();
auto output = this->out_tensors_.front();
CHECK_NULL_RETURN(input);
CHECK_NULL_RETURN(in_tensors_[1]);
CHECK_NULL_RETURN(output);
CHECK_NULL_RETURN(conv_param_);
MS_CHECK_TRUE_MSG(input->shape().size() == C4NUM, RET_ERROR, "Conv-like: input-shape should be 4D.");
MS_CHECK_TRUE_MSG(in_tensors_[1]->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
MS_CHECK_TRUE_MSG(output->shape().size() == C4NUM, RET_ERROR, "Conv-like: out-shape should be 4D.");
conv_param_->input_batch_ = input->Batch();
conv_param_->input_h_ = input->Height();
conv_param_->input_w_ = input->Width();
@ -494,7 +498,8 @@ void ConvolutionBaseCPUKernel::UpdateOriginWeightAndBias() {
if (in_tensors_.at(kWeightIndex)->data() != nullptr) {
origin_weight_ = in_tensors_.at(kWeightIndex)->data();
}
if (in_tensors_.size() == kInputSize2 && in_tensors_.at(kBiasIndex)->data() != nullptr) {
if (in_tensors_.size() == kInputSize2 && in_tensors_.at(kBiasIndex) != nullptr &&
in_tensors_.at(kBiasIndex)->data() != nullptr) {
origin_bias_ = in_tensors_.at(kBiasIndex)->data();
}
}

View File

@ -38,6 +38,8 @@ namespace mindspore::kernel {
int ActivationFp16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 1);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_[0]);
CHECK_NULL_RETURN(out_tensors_[0]);
return RET_OK;
}

View File

@ -36,7 +36,14 @@ int AddNLaunch(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
}
} // namespace
int AddNFp16CPUKernel::Prepare() { return RET_OK; }
int AddNFp16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_[0]);
CHECK_NULL_RETURN(in_tensors_[1]);
CHECK_NULL_RETURN(out_tensors_[0]);
return RET_OK;
}
int AddNFp16CPUKernel::ReSize() { return RET_OK; }

View File

@ -68,6 +68,9 @@ ArithmeticCompareOptFuncFp16 GetOptimizedArithmeticCompareFun(int primitive_type
int ArithmeticCompareFP16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 2);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_[0]);
CHECK_NULL_RETURN(in_tensors_[1]);
CHECK_NULL_RETURN(out_tensors_[0]);
if (!InferShapeDone()) {
return RET_OK;
}

View File

@ -118,6 +118,9 @@ int BiasAddCPUFp16Kernel::GetBiasData() {
int BiasAddCPUFp16Kernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 2);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_[0]);
CHECK_NULL_RETURN(in_tensors_[1]);
CHECK_NULL_RETURN(out_tensors_[0]);
bias_tensor_ = in_tensors_.at(1);
CHECK_NULL_RETURN(bias_tensor_);
if (!InferShapeDone()) {

View File

@ -22,10 +22,12 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel {
float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext *ctx) {
MS_CHECK_TRUE_MSG(input != nullptr, nullptr, "input must be not a nullptr.");
float16_t *fp16_data = nullptr;
auto data_type = input->data_type();
if (data_type == kNumberTypeFloat32) {
auto ele_num = input->ElementsNum();
MS_CHECK_TRUE_MSG(ctx != nullptr, nullptr, "ctx must be not a nullptr.");
fp16_data = reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
if (fp16_data == nullptr) {
MS_LOG(ERROR) << "malloc fp16_data failed.";
@ -40,10 +42,12 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext
}
float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx) {
MS_CHECK_TRUE_MSG(output != nullptr, nullptr, "output must be not as nullptr.");
float16_t *fp16_data = nullptr;
auto data_type = output->data_type();
if (data_type == kNumberTypeFloat32) {
auto ele_num = output->ElementsNum();
MS_CHECK_TRUE_MSG(ctx != nullptr, nullptr, "ctx must be not a nullptr.");
fp16_data = reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
if (fp16_data == nullptr) {
MS_LOG(ERROR) << "malloc fp16_data failed.";
@ -56,9 +60,11 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx)
}
int ConvertFp32TensorToFp16(lite::Tensor *tensor, const lite::InnerContext *ctx) {
MS_CHECK_TRUE_MSG(tensor != nullptr, RET_ERROR, "ConvertFp32TensorToFp16 failed, due to the tensor is a nullptr.");
if (tensor->data_type() == TypeId::kNumberTypeFloat16) {
return RET_OK;
}
MS_CHECK_TRUE_MSG(ctx != nullptr, RET_ERROR, "ConvertFp32TensorToFp16 failed, due to the ctx is a nullptr.");
auto fp32_data = tensor->data();
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat16);

View File

@ -50,6 +50,7 @@ void ConvolutionDelegateFP16CPUKernel::FreeCopiedData() {
}
void *ConvolutionDelegateFP16CPUKernel::CopyData(const lite::Tensor *tensor) {
MS_CHECK_TRUE_MSG(tensor != nullptr, nullptr, "tensor must be not a nullptr.");
auto data_type = tensor->data_type();
if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16) {
MS_LOG(ERROR) << "Not supported data type: " << data_type;
@ -68,19 +69,21 @@ void *ConvolutionDelegateFP16CPUKernel::CopyData(const lite::Tensor *tensor) {
int ConvolutionDelegateFP16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
if (!InferShapeDone()) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
origin_weight_ = weight_tensor->data() != nullptr ? CopyData(weight_tensor) : nullptr;
need_free_ = need_free_ | WEIGHT_NEED_FREE;
if (in_tensors_.size() == C3NUM) {
CHECK_NULL_RETURN(in_tensors_.at(kBiasIndex));
origin_bias_ = CopyData(in_tensors_.at(kBiasIndex));
need_free_ = need_free_ | BIAS_NEED_FREE;
}
return RET_OK;
}
origin_weight_ = in_tensors_.at(kWeightIndex)->data();
origin_weight_ = weight_tensor->data();
if (in_tensors_.size() == C3NUM) {
CHECK_NULL_RETURN(in_tensors_.at(kBiasIndex));
origin_bias_ = in_tensors_.at(kBiasIndex)->data();
MS_ASSERT(origin_bias_ != nullptr);
}

View File

@ -68,6 +68,7 @@ int ConvolutionDepthwise3x3Fp16CPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int channel = weight_tensor->Batch();
int c8 = UP_ROUND(channel, C8NUM);
int pack_weight_size = c8 * C12NUM;

View File

@ -65,6 +65,7 @@ int ConvolutionDepthwiseFp16CPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int channel = weight_tensor->Batch();
int pack_weight_size = channel * weight_tensor->Height() * weight_tensor->Width();
set_workspace_size(pack_weight_size * sizeof(float16_t));

View File

@ -97,6 +97,7 @@ int ConvolutionDepthwiseSWFp16CPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
set_workspace_size(pack_weight_size * sizeof(float16_t));

View File

@ -104,6 +104,7 @@ int ConvolutionFP16CPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto filter_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(filter_tensor);
MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
int oc8 = UP_ROUND(out_channel, col_tile_);

View File

@ -175,6 +175,7 @@ int ConvolutionWinogradFP16CPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int in_channel = weight_tensor->Channel();
int out_channel = weight_tensor->Batch();
int oc_block_num = UP_DIV(out_channel, col_tile_);

View File

@ -110,6 +110,8 @@ int DeconvolutionDepthwiseFp16CPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
set_workspace_size(pack_weight_size * sizeof(float16_t));

View File

@ -36,6 +36,8 @@ namespace mindspore::kernel {
int ActivationCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 1);
CHECK_NOT_EQUAL_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_[0]);
CHECK_NULL_RETURN(out_tensors_[0]);
if (in_tensors().front()->data_type() == kNumberTypeInt32) {
if (type_ != schema::ActivationType_RELU) {

View File

@ -61,9 +61,16 @@ bool AffineFp32CPUKernel::CheckAffineValid() {
if (in_tensors_.size() < kAffineMinInputNum) {
return false;
}
if (std::any_of(in_tensors_.begin(), in_tensors_.end(), [](const auto &in_tensor) { return in_tensor == nullptr; })) {
return false;
}
if (out_tensors_.size() != kAffineMaxOutputNum) {
return false;
}
if (std::any_of(out_tensors_.begin(), out_tensors_.end(),
[](const auto &out_tensor) { return out_tensor == nullptr; })) {
return false;
}
return true;
}
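
A self-contained sketch of the std::any_of guard used in CheckAffineValid above, checking a vector of pointers for nulls; the Tensor struct here is a stand-in for illustration, not the real lite::Tensor class.

#include <algorithm>
#include <iostream>
#include <vector>

struct Tensor {};  // placeholder type; the kernel actually holds lite::Tensor pointers

// Returns false if any pointer in the list is null, mirroring the early-return
// guards above.
bool AllTensorsNonNull(const std::vector<Tensor *> &tensors) {
  return !std::any_of(tensors.begin(), tensors.end(),
                      [](const Tensor *tensor) { return tensor == nullptr; });
}

int main() {
  Tensor a;
  Tensor b;
  std::vector<Tensor *> valid{&a, &b};
  std::vector<Tensor *> broken{&a, nullptr};
  std::cout << std::boolalpha << AllTensorsNonNull(valid) << ' ' << AllTensorsNonNull(broken) << '\n';
  return 0;
}
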
@ -153,6 +160,7 @@ int AffineFp32CPUKernel::FullRunInit() {
int AffineFp32CPUKernel::IncrementInit() {
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_shape = out_tensor->shape();
MS_CHECK_TRUE_MSG(out_shape.size() >= C3NUM, RET_ERROR, "Out-shape is invalid, which must be 3D or bigger.");
matmul_col_ = out_shape.at(kInputCol);
matmul_row_ = out_shape.at(kInputRow);
MS_CHECK_INT_MUL_NOT_OVERFLOW(matmul_row_, matmul_col_, RET_ERROR);
@ -279,6 +287,7 @@ kernel::LiteKernel *AffineFp32CPUKernel::FullMatmulKernelCreate() {
kernel::LiteKernel *AffineFp32CPUKernel::IncrementMatmulKernelCreate() {
auto input_shape = in_tensors_.front()->shape();
MS_CHECK_TRUE_MSG(!input_shape.empty(), nullptr, "First input-shape is empty.");
int src_col = input_shape.at(input_shape.size() - 1);
int context_dims = affine_parameter_->context_size_;
int affine_splice_output_col = affine_parameter_->output_dim_;
@ -294,7 +303,9 @@ kernel::LiteKernel *AffineFp32CPUKernel::IncrementMatmulKernelCreate() {
MS_CHECK_TRUE_MSG(increment_input_ != nullptr, nullptr, "Create a new-tensor failed.");
// matmul_output == 1 * matmul_col
int matmul_col = out_tensors_.front()->shape().back();
auto out_shape = out_tensors_.front()->shape();
MS_CHECK_TRUE_MSG(!out_shape.empty(), nullptr, "Out-shape is empty.");
int matmul_col = out_shape.back();
increment_output_ = new (std::nothrow) lite::Tensor(kNumberTypeFloat32, {1, 1, matmul_col});
MS_CHECK_TRUE_MSG(increment_output_ != nullptr, nullptr, "Create a new-tensor failed.");
increment_output_->MallocData();

View File

@ -41,6 +41,9 @@ int BiasAddRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
int BiasCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_[0]);
CHECK_NULL_RETURN(in_tensors_[1]);
CHECK_NULL_RETURN(out_tensors_[0]);
if (!InferShapeDone()) {
return RET_OK;
}

View File

@ -16,6 +16,7 @@
#include "src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
using mindspore::lite::RET_ERROR;
@ -31,6 +32,7 @@ int ConvolutionDepthwise3x3CPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int channel = weight_tensor->Batch();
int c4 = UP_ROUND(channel, C4NUM);
MS_CHECK_INT_MUL_NOT_OVERFLOW(c4, C12NUM, RET_ERROR);

View File

@ -31,6 +31,7 @@ int ConvolutionDepthwiseCPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Height(), weight_tensor->Width(), RET_ERROR);
int weight_size_hw = weight_tensor->Height() * weight_tensor->Width();
MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Batch(), weight_size_hw, RET_ERROR);

View File

@ -41,6 +41,7 @@ int ConvolutionDepthwiseIndirectCPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_[kWeightIndex];
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
#ifdef ENABLE_AVX
int div_flag = C8NUM;
#else

View File

@ -73,6 +73,8 @@ int ConvolutionDepthwiseSWCPUKernel::Prepare() {
}
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Height(), weight_tensor->Width(), RET_ERROR);
int weight_size_hw = weight_tensor->Height() * weight_tensor->Width();

View File

@ -74,6 +74,8 @@ int ConvolutionDepthwiseSWCPUKernelX86::Prepare() {
#endif
if (op_parameter_->is_train_session_) {
auto weight_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(weight_tensor);
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int oc_algin = UP_DIV(weight_tensor->Batch(), oc_tile_);
MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_tensor->Height(), weight_tensor->Width(), RET_ERROR);
int weight_size_hw = weight_tensor->Height() * weight_tensor->Width();

View File

@ -93,6 +93,7 @@ int ConvolutionCPUKernel::Prepare() {
CHECK_LESS_RETURN(out_tensors_.size(), 1);
if (op_parameter_->is_train_session_) {
auto filter_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(filter_tensor);
MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
size_t in_channel = filter_tensor->Channel();
size_t out_channel = filter_tensor->Batch();

View File

@ -44,6 +44,7 @@ int ConvolutionSWCPUKernel::Prepare() {
if (op_parameter_->is_train_session_) {
auto filter_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(filter_tensor);
MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
int kernel_h = filter_tensor->Height();

View File

@ -124,6 +124,7 @@ int ConvolutionWinogradBaseCPUKernel::Prepare() {
conv_param_->output_unit_ = output_unit_;
if (op_parameter_->is_train_session_) {
auto filter_tensor = in_tensors_.at(kWeightIndex);
CHECK_NULL_RETURN(filter_tensor);
MS_CHECK_TRUE_MSG(filter_tensor->shape().size() == C4NUM, RET_ERROR, "Conv-like: weight-shape only support 4D.");
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();