forked from mindspore-Ecosystem/mindspore
!6055 [MSLITE] Fix bug of several quantized operators inference.
Merge pull request !6055 from wangshaocong/lite_bugfix
This commit is contained in:
commit
6240189190
|
@ -79,7 +79,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
|
|||
return RET_OK;
|
||||
}
|
||||
int thread_offset = task_id * thread_n_stride_;
|
||||
auto quant_arg = in_tensors_.front()->GetQuantParams().front();
|
||||
if (in_tensors_.front()->GetQuantParams().empty() && out_tensors_.front()->GetQuantParams().empty()) {
|
||||
MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front() :
|
||||
out_tensors_.front()->GetQuantParams().front();
|
||||
int ret;
|
||||
if (inverse_) {
|
||||
ret = DoDequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
|
||||
|
|
|
@ -92,17 +92,10 @@ int QuantizedAddCPUKernel::Run() {
|
|||
input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
input1_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
|
||||
ArithmeticParameter tile_para;
|
||||
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
|
||||
for (size_t i = 0; i < tile_para.ndim_; i++) {
|
||||
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
|
||||
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
|
||||
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
|
||||
}
|
||||
TileDimensionsUint8(static_cast<uint8_t *>(in_tensors_.at(0)->MutableData()),
|
||||
static_cast<uint8_t *>(in_tensors_.at(1)->MutableData()),
|
||||
reinterpret_cast<uint8_t *>(input0_data_), reinterpret_cast<uint8_t *>(input1_data_),
|
||||
&tile_para);
|
||||
arith_para_);
|
||||
ret = ParallelLaunch(THREAD_POOL_DEFAULT, AddInt8Run, this, thread_count_);
|
||||
ctx_->allocator->Free(input0_data_);
|
||||
ctx_->allocator->Free(input1_data_);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "nnacl/int8/add_int8.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -27,7 +28,9 @@ class QuantizedAddCPUKernel : public LiteKernel {
|
|||
explicit QuantizedAddCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {}
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {
|
||||
arith_para_ = reinterpret_cast<ArithmeticParameter *>(parameter);
|
||||
}
|
||||
~QuantizedAddCPUKernel() override {}
|
||||
|
||||
int Init() override;
|
||||
|
@ -38,6 +41,7 @@ class QuantizedAddCPUKernel : public LiteKernel {
|
|||
private:
|
||||
const lite::Context *ctx_;
|
||||
AddQuantParameter para_;
|
||||
ArithmeticParameter *arith_para_;
|
||||
int thread_count_;
|
||||
int64_t elements_num_;
|
||||
int64_t count_unit_;
|
||||
|
|
|
@ -91,7 +91,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
|
|||
MS_LOG(ERROR) << "tensor_data is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = memcpy_s(tensor_data, size * sizeof(float), tensor->MutableData(), size * sizeof(float));
|
||||
auto ret = memcpy_s(tensor_data, tensor->Size(), tensor->MutableData(), tensor->Size());
|
||||
if (ret != EOK) {
|
||||
delete[] tensor_data;
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
|
@ -234,6 +234,9 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|||
return nullptr;
|
||||
}
|
||||
lite::Context context;
|
||||
if (context.allocator == nullptr) {
|
||||
context.allocator = lite::Allocator::Create();
|
||||
}
|
||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, &context, lite_primitive.get());
|
||||
if (lite_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
||||
|
|
Loading…
Reference in New Issue