!19926 [MS][LITE]Fix matmul and convolution delegate

Merge pull request !19926 from gongdaguo/r1.3_fix_matmul
This commit is contained in:
i-robot 2021-07-12 05:05:08 +00:00 committed by Gitee
commit 3bf1381e4b
6 changed files with 12 additions and 4 deletions

View File

@ -39,10 +39,12 @@ void ConvolutionDelegateFP16CPUKernel::FreeCopiedData() {
if ((origin_weight_ != nullptr) && (need_free_ & WEIGHT_NEED_FREE)) {
free(origin_weight_);
origin_weight_ = nullptr;
need_free_ = need_free_ & ~WEIGHT_NEED_FREE;
}
if ((origin_bias_ != nullptr) && (need_free_ & BIAS_NEED_FREE)) {
free(origin_bias_);
origin_bias_ = nullptr;
need_free_ = need_free_ & ~BIAS_NEED_FREE;
}
}

View File

@ -22,8 +22,8 @@
#include "nnacl/conv_parameter.h"
#include "nnacl/op_base.h"
#define WEIGHT_NEED_FREE 0b01
#define BIAS_NEED_FREE 0b10
#define WEIGHT_NEED_FREE 0001
#define BIAS_NEED_FREE 0010
namespace mindspore::kernel {
class ConvolutionDelegateFP16CPUKernel : public InnerKernel {

View File

@ -36,6 +36,10 @@ int MatmulBaseFP16Run(void *cdata, int task_id, float lhs_scale, float rhs_scale
}
MatmulBaseFP16CPUKernel::~MatmulBaseFP16CPUKernel() {
if (src_b_ != nullptr) {
free(src_b_);
src_b_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;

View File

@ -72,13 +72,14 @@ class ConvolutionDelegateCPUKernel : public InnerKernel {
if (origin_weight_ != nullptr && need_free_weight_) {
free(origin_weight_);
origin_weight_ = nullptr;
need_free_weight_ = false;
}
if (origin_bias_ != nullptr && need_free_bias_) {
free(origin_bias_);
origin_bias_ = nullptr;
need_free_bias_ = false;
}
}
// Train API
int Eval() override {
InnerKernel::Eval();

View File

@ -35,6 +35,7 @@ MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
FreeResizeBufA();
FreeResizeBufB();
FreeBiasBuf();
FreeBuffSrcB();
}
void MatmulFp32BaseCPUKernel::InitParameter() {

View File

@ -203,7 +203,7 @@ size_t Tensor::Size() const {
size_t element_size = DataTypeSize(this->data_type_);
auto element_num = (format_ == mindspore::NC4HW4 || format_ == mindspore::NHWC4) ? ElementsC4Num() : ElementsNum();
if (element_num < 0) {
MS_LOG(ERROR) << "Element number of tensor should large than 0 : " << element_num;
MS_LOG(INFO) << "Element number of tensor should large than 0 : " << element_num;
return 0;
}
return element_size * element_num;