!5637 fix quant MS model runtime

Merge pull request !5637 from yankai10/merge_0901
This commit is contained in:
mindspore-ci-bot 2020-09-03 11:09:22 +08:00 committed by Gitee
commit a6f8904212
16 changed files with 229 additions and 384 deletions

View File

@ -253,6 +253,16 @@ std::string Tensor::ToString() const {
}
}
} break;
case kNumberTypeInt8: {
auto data = static_cast<int8_t *>(this->data_);
if (data == nullptr) {
return "Data of tensor is nullptr";
} else {
for (int i = 0; i < 40 && i < this->ElementsNum(); i++) {
oss << " " << static_cast<int32_t >(data[i]);
}
}
} break;
default:
oss << "Unsupported data type to print";
break;

View File

@ -15,11 +15,16 @@
*/
#include "src/ops/conv2d.h"
#include <string>
#include <map>
#include <memory>
#include <string>
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#endif
@ -156,6 +161,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
@ -213,100 +225,20 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
// attr->padMode = schema::PadMode_SAME;
// attr->activationType = schema::ActivationType_RELU;
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
void Conv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
const float qmin = 0;
const float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void Conv2D::PopulaterQuantParam(const Primitive &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim.GetAttr("mean");
auto std_dev = prim.GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim.GetAttr("input_minq");
auto inputMax = prim.GetAttr("input_maxq");
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->Data());
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
auto filterMin = prim.GetAttr("filter_minq");
auto filterMax = prim.GetAttr("filter_maxq");
if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->Data());
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
biasQuantSize = filterMinPtr->DataSize();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = *(minBuf++);
quantParam.max = *(maxBuf++);
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim.GetAttr("output_minq");
auto outputMax = prim.GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->Data());
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecOutputQuantParam->emplace_back(quants);
}
}
int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;

View File

@ -57,9 +57,6 @@ class Conv2D : public PrimitiveC {
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs);
void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
#else
public:

View File

@ -15,6 +15,7 @@
*/
#include "src/ops/depthwise_conv2d.h"
#include <memory>
#include <string>
#ifdef PRIMITIVE_WRITEABLE
@ -69,96 +70,6 @@ void DepthwiseConv2D::SetActivationType(int activation_type) {
this->primitive_->value.AsDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type;
}
void DepthwiseConv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
const float qmin = 0;
const float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void DepthwiseConv2D::PopulaterQuantParam(const Primitive &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim.GetAttr("mean");
auto std_dev = prim.GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim.GetAttr("input_minq");
auto inputMax = prim.GetAttr("input_maxq");
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->Data());
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
auto filterMin = prim.GetAttr("filter_minq");
auto filterMax = prim.GetAttr("filter_maxq");
if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->Data());
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
biasQuantSize = filterMinPtr->DataSize();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = *(minBuf++);
quantParam.max = *(maxBuf++);
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim.GetAttr("output_minq");
auto outputMax = prim.GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->Data());
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecOutputQuantParam->emplace_back(quants);
}
}
int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
@ -197,7 +108,14 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
// attr->padMode = schema::PadMode_SAME;
// attr->activationType = schema::ActivationType_RELU;
auto channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
attr->channelMultiplier = channel_multiplier;

View File

@ -51,10 +51,6 @@ class DepthwiseConv2D : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
private:
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
#else
public:

View File

@ -30,83 +30,6 @@ bool MatMul::GetTransposeB() const { return this->primitive_->value.AsMatMul()->
void MatMul::SetTransposeA(bool transpose_a) { this->primitive_->value.AsMatMul()->transposeA = transpose_a; }
void MatMul::SetTransposeB(bool transpose_b) { this->primitive_->value.AsMatMul()->transposeB = transpose_b; }
void MatMul::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
const float qmin = 0;
const float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void MatMul::PopulaterQuantParam(const Primitive &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim.GetAttr("mean");
auto std_dev = prim.GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim.GetAttr("input_minq");
auto inputMax = prim.GetAttr("input_maxq");
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->Data());
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto filterMin = prim.GetAttr("filter_minq");
auto filterMax = prim.GetAttr("filter_maxq");
if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->Data());
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
for (int i = 0; i < filterMinPtr->DataSize(); ++i) {
quantParam.min = *(minBuf++);
quantParam.max = *(maxBuf++);
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
auto outputMin = prim.GetAttr("output_minq");
auto outputMax = prim.GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->Data());
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecOutputQuantParam->emplace_back(quants);
}
}
int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;

View File

@ -36,10 +36,6 @@ class MatMul : public PrimitiveC {
void SetTransposeA(bool transpose_a);
void SetTransposeB(bool transpose_b);
private:
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
#else
public:

View File

@ -16,6 +16,7 @@
#include "src/ops/primitive_c.h"
#include <memory>
#include <map>
#include "src/ops/space_to_batch.h"
#include "src/ops/space_to_batch_nd.h"
#include "src/ops/conv2d.h"
@ -121,10 +122,99 @@
#include "src/ops/l2_norm.h"
#include "src/ops/sparse_to_dense.h"
#include "src/ops/detection_post_process.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
const float qmin = 0;
const float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim.GetAttr("mean");
auto std_dev = prim.GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim.GetAttr("input_minq");
auto inputMax = prim.GetAttr("input_maxq");
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->Data());
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto filterMin = prim.GetAttr("filter_minq");
auto filterMax = prim.GetAttr("filter_maxq");
if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->Data());
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
quantParam.min = FLT_MAX;
quantParam.max = FLT_MIN;
for (int i = 0; i < filterMinPtr->DataSize(); ++i) {
quantParam.min = (*(minBuf) < quantParam.min) ? (*minBuf) : quantParam.min;
quantParam.max = (*(maxBuf) > quantParam.max) ? (*maxBuf) : quantParam.max;
minBuf++;
maxBuf++;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale;
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim.GetAttr("output_minq");
auto outputMax = prim.GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->Data());
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecOutputQuantParam->emplace_back(quants);
}
}
schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; }
void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; }
@ -152,7 +242,7 @@ void PrimitiveC::AddOutputQuantParam(std::vector<schema::QuantParamT> quant_para
}
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::GetOutputQuantParams() const { return output_quant_param_; }
void PrimitiveC::SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; }
void PrimitiveC::SetQuantType(const schema::QuantType &quant_type) { this->quant_type_ = quant_type; }
schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; }
@ -205,12 +295,14 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
}
template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType) {
auto primc = std::make_shared<T>();
if (primc == nullptr) {
MS_LOG(ERROR) << "make_shared PrimitiveC failed";
return nullptr;
}
primc->SetQuantType(quantType);
auto ret = primc->UnPackAttr(prim, inputs);
if (ret != RET_OK) {
MS_LOG(ERROR) << "UnPackAttr failed";
@ -220,46 +312,47 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vect
}
std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &prim,
const std::vector<AnfNodePtr> &inputs) {
const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType) {
const auto &op_type = prim.name();
if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") {
return NewPrimitiveC<Activation>(prim, inputs);
return NewPrimitiveC<Activation>(prim, inputs, quantType);
} else if (op_type == "BatchNorm") {
return NewPrimitiveC<BatchNorm>(prim, inputs);
return NewPrimitiveC<BatchNorm>(prim, inputs, quantType);
} else if (op_type == "BiasAdd") {
return NewPrimitiveC<BiasAdd>(prim, inputs);
return NewPrimitiveC<BiasAdd>(prim, inputs, quantType);
} else if (op_type == "Concat") {
return NewPrimitiveC<Concat>(prim, inputs);
return NewPrimitiveC<Concat>(prim, inputs, quantType);
} else if (op_type == "Conv2D") {
return NewPrimitiveC<Conv2D>(prim, inputs);
return NewPrimitiveC<Conv2D>(prim, inputs, quantType);
} else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") {
return NewPrimitiveC<DepthwiseConv2D>(prim, inputs);
return NewPrimitiveC<DepthwiseConv2D>(prim, inputs, quantType);
} else if (op_type == "Dequant") {
return NewPrimitiveC<Dequant>(prim, inputs);
return NewPrimitiveC<Dequant>(prim, inputs, quantType);
} else if (op_type == "Flatten") {
return NewPrimitiveC<Flatten>(prim, inputs);
return NewPrimitiveC<Flatten>(prim, inputs, quantType);
} else if (op_type == "make_tuple") {
return NewPrimitiveC<MakeTuple>(prim, inputs);
return NewPrimitiveC<MakeTuple>(prim, inputs, quantType);
} else if (op_type == "MatMul") {
return NewPrimitiveC<MatMul>(prim, inputs);
return NewPrimitiveC<MatMul>(prim, inputs, quantType);
} else if (op_type == "Mul") {
return NewPrimitiveC<Mul>(prim, inputs);
return NewPrimitiveC<Mul>(prim, inputs, quantType);
} else if (op_type == "MaxPool") {
return NewPrimitiveC<Pooling>(prim, inputs);
return NewPrimitiveC<Pooling>(prim, inputs, quantType);
} else if (op_type == "Quant") {
return NewPrimitiveC<Quant>(prim, inputs);
return NewPrimitiveC<Quant>(prim, inputs, quantType);
} else if (op_type == "ReduceMean") {
return NewPrimitiveC<Reduce>(prim, inputs);
return NewPrimitiveC<Reduce>(prim, inputs, quantType);
} else if (op_type == "Reshape") {
return NewPrimitiveC<Reshape>(prim, inputs);
return NewPrimitiveC<Reshape>(prim, inputs, quantType);
} else if (op_type == "TensorAdd") {
return NewPrimitiveC<Add>(prim, inputs);
return NewPrimitiveC<Add>(prim, inputs, quantType);
} else if (op_type == "Transpose") {
return NewPrimitiveC<Transpose>(prim, inputs);
return NewPrimitiveC<Transpose>(prim, inputs, quantType);
} else if (op_type == "tuple_getitem") {
return NewPrimitiveC<TupleGetItem>(prim, inputs);
return NewPrimitiveC<TupleGetItem>(prim, inputs, quantType);
} else if (op_type == "Softmax") {
return NewPrimitiveC<SoftMax>(prim, inputs);
return NewPrimitiveC<SoftMax>(prim, inputs, quantType);
} else {
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
return nullptr;

View File

@ -20,6 +20,7 @@
#include <set>
#include <vector>
#include <memory>
#include <map>
#ifdef PRIMITIVE_WRITEABLE
#include "ir/primitive.h"
#include "schema/inner/model_generated.h"
@ -44,6 +45,9 @@ const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNum
constexpr int kAnfPopulaterOne = 1;
constexpr int kAnfPopulaterTwo = 2;
constexpr int kAnfPopulaterThree = 3;
static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU},
{"ReLU6", schema::ActivationType_RELU6},
{"Sigmoid", schema::ActivationType_SIGMOID}};
class PrimitiveC : public mindspore::Primitive {
public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
@ -94,7 +98,7 @@ class PrimitiveC : public mindspore::Primitive {
std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const;
void SetQuantType(schema::QuantType quant_type);
void SetQuantType(const schema::QuantType &quant_type);
schema::QuantType GetQuantType() const;
@ -110,7 +114,11 @@ class PrimitiveC : public mindspore::Primitive {
static PrimitiveC *UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT *primitive);
static std::shared_ptr<PrimitiveC> UnPackFromPrimitive(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
static std::shared_ptr<PrimitiveC> UnPackFromPrimitive(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType);
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
protected:
virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; }

View File

@ -71,8 +71,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
auto input_quant_params = primitive->GetInputQuantParams();
auto node_type = (schema::PrimitiveType)primitive->Type();
if (input_quant_params.empty()) {
MS_LOG(ERROR) << "node: " << dst_node->name << " input quant params is empty";
return RET_ERROR;
MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty";
return RET_OK;
}
for (size_t i = 0; i < input_quant_params.size(); i++) {
if (i >= dst_node->inputIndex.size()) {

View File

@ -473,7 +473,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
}
inputs.push_back(anfnode_build_map_[input_name]);
}
auto primitivec_ptr = PrimitiveC::UnPackFromPrimitive(*prim, inputs);
auto primitivec_ptr = PrimitiveC::UnPackFromPrimitive(*prim, inputs, quantType);
if (primitivec_ptr == nullptr) {
MS_LOG(ERROR) << "Create PrimitiveC return nullptr, " << prim->name();
return nullptr;

View File

@ -50,7 +50,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze};
schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze,
schema::PrimitiveType_MatMul};
static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation,

View File

@ -155,11 +155,11 @@ STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node,
switch (this->quantType) {
case QuantType_AwareTraining: {
if (opType == schema::PrimitiveType_Conv2D) {
weightTensor->format = schema::Format_HWCK;
weightTensor->format = schema::Format_KCHW;
} else if (opType == PrimitiveType_DepthwiseConv2D) {
weightTensor->format = Format_CKHW;
} else {
weightTensor->format = schema::Format_HWKC;
weightTensor->format = schema::Format_KCHW;
}
} break;
case QuantType_QUANT_NONE: {

View File

@ -36,15 +36,13 @@ using std::vector;
namespace mindspore::lite::quant {
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize,
schema::PrimitiveType_Reshape, schema::PrimitiveType_Squeeze,
schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
schema::PrimitiveType_DetectionPostProcess}};
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
schema::PrimitiveType_DetectionPostProcess}};
STATUS InputArray::InitQuantParam() {
this->quantParam = std::make_unique<schema::QuantParamT>();
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax,
narrowRange, numBits);
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits);
if (status != RET_OK) {
return status;
}
@ -58,8 +56,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor
if (!tensor->quantParams.empty()) {
auto param = GetTensorQuantParam(tensor);
if (param != nullptr && param->inited) {
MS_LOG(DEBUG) << "tensor " << inputTensorIdx
<< " already has quantParam";
MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam";
return RET_OK;
}
tensor->quantParams.clear();
@ -74,9 +71,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor
return RET_OK;
}
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
const string &inputInferType,
const string &stdValues,
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues,
const string &meanValues)
: FbQuantizer(graph) {
MS_ASSERT(graph != nullptr);
@ -94,12 +89,9 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
mInputArray->InitQuantParam();
}
STATUS AwareQuantizer::RemoveFakeQuant() {
return RET_OK;
}
STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; }
STATUS AwareQuantizer::GenerateDefaultQuantParam(
const schema::MetaGraphT *subGraph) {
STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) {
MS_ASSERT(subGraph != nullptr);
for (const auto &tensor : subGraph->allTensors) {
if (!tensor->quantParams.empty()) {
@ -111,8 +103,7 @@ STATUS AwareQuantizer::GenerateDefaultQuantParam(
return RET_OK;
}
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph,
schema::CNodeT *node) {
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
// MS_ASSERT(subGraph != nullptr);
// MS_ASSERT(node != nullptr);
// auto inputIndexes = node->inputIndex;
@ -193,19 +184,15 @@ STATUS AwareQuantizer::GenerateQuantParam() {
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
MS_ASSERT(false);
}
auto *quantParamCalcer =
quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
if (quantParamCalcer == nullptr) {
MS_LOG(ERROR) << "Can not find QuantParamCalcer for "
<< node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str()
<< " set node to QuantNone and skip";
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
} else {
status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(ERROR) << "quantParamCalcer failed: " << status
<< " node: " << node->name.c_str();
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
} else {
node->quantType = schema::QuantType_AwareTraining;
@ -227,11 +214,11 @@ STATUS AwareQuantizer::DoQuantize() {
STATUS status;
if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
GetCNodeTType(*node) == schema::PrimitiveType_FullConnection) {
GetCNodeTType(*node) == schema::PrimitiveType_FullConnection ||
GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
auto inputIndexes = node->inputIndex;
if (inputIndexes.size() < 2) {
MS_LOG(ERROR) << node->name.c_str()
<< " node input has invalid inputs tensor count";
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
return RET_ERROR;
}
// quant weight
@ -248,8 +235,7 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_ERROR;
}
}
} else if (GetCNodeTType(*node) ==
schema::PrimitiveType_DetectionPostProcess) {
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
status = QuantDetectionPostProcessConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
@ -275,8 +261,7 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_OK;
}
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
schema::CNodeT *node) {
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
for (size_t i = 0; i < node->inputIndex.size(); i++) {
@ -295,8 +280,7 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
void *inData = inTensor->data.data();
auto *castedInData = static_cast<float *>(inData);
for (size_t j = 0; j < constTensorShapeSize; j++) {
qDatas[j] =
QuantizeData<uint8_t>(castedInData[j], quantParam.get());
qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
}
inTensor->data = std::move(qDatas);
inTensor->dataType = kNumberTypeUInt8;
@ -312,17 +296,14 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
return RET_OK;
}
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(
const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
MS_ASSERT(subGraph != nullptr);
MS_ASSERT(node != nullptr);
auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]);
MS_ASSERT(constTensor != nullptr);
const auto *constData =
reinterpret_cast<const float *>(constTensor->data.data());
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
if (constTensor->refCount == 999 &&
constTensor->dataType == TypeId::kNumberTypeFloat) {
if (constTensor->nodeType == schema::NodeType_ValueNode && constTensor->dataType == TypeId::kNumberTypeFloat) {
size_t constTensorShapeSize = GetShapeSize(*constTensor);
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
if (quantParam == nullptr) {
@ -340,8 +321,7 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(
return RET_OK;
}
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
mindspore::schema::CNodeT *node) {
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
auto inputIndexes = node->inputIndex;
@ -351,15 +331,10 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2));
auto &biasTensor = graph->allTensors.at(inputIndexes.at(2));
MS_ASSERT(biasTensor != nullptr);
if (biasTensor->dataType != TypeId::kNumberTypeFloat) {
// MS_LOGD("conv %s's bias data is not float", node->name.c_str());
return RET_OK;
}
if (biasTensor->dataType == TypeId::kNumberTypeInt32) {
return RET_OK;
}
if (biasTensor->dataType != TypeId::kNumberTypeFloat) {
if (biasTensor->dataType != TypeId::kNumberTypeFloat && biasTensor->dataType != TypeId::kNumberTypeFloat32) {
// MS_LOGE("conv %s's bias data is not float", node->name.c_str());
return RET_ERROR;
}
@ -400,8 +375,8 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
biasTensor->dataType = TypeId::kNumberTypeInt32;
biasTensor->data.clear();
biasTensor->data.resize(bShapeSize * sizeof(int32_t));
auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t),
qDatas.get(), bShapeSize * sizeof(int32_t));
auto ret =
memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed: " << ret;
return RET_ERROR;
@ -409,12 +384,10 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
return RET_OK;
}
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph,
schema::CNodeT *node) {
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
MS_ASSERT(subGraph != nullptr);
MS_ASSERT(node != nullptr);
MS_ASSERT(node->quantParam.size() ==
node->inputIndex.size() + node->outputIndex.size());
MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size());
auto inputIndexes = node->inputIndex;
MS_ASSERT(inputIndexes.size() >= 2);
MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1));
@ -422,11 +395,9 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph,
if (weightTensor->dataType == TypeId::kNumberTypeInt8) {
return RET_OK;
}
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 &&
weightTensor->dataType != TypeId::kNumberTypeFloat &&
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat &&
weightTensor->dataType != TypeId::kNumberTypeUInt8) {
MS_LOG(ERROR) << "conv " << node->name.c_str()
<< "'s weight data is not float or uint8";
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
return RET_ERROR;
}
size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
@ -434,8 +405,8 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph,
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
vector<int8_t> qDatas(wShapeSize);
auto weightQauntParam = GetTensorQuantParam(weightTensor);
if (weightTensor->dataType ==
TypeId::kNumberTypeFloat) { // normal awareing quant
if (weightTensor->dataType == TypeId::kNumberTypeFloat ||
weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant
auto *weightData = static_cast<float *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
@ -463,8 +434,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
MS_ASSERT(graph->allTensors.size() > inTensorIdx);
auto &inTensor = graph->allTensors.at(inTensorIdx);
MS_ASSERT(inTensor != nullptr);
if (inTensor->quantParams.empty() ||
inTensor->quantParams.front() == nullptr ||
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr ||
!inTensor->quantParams.front()->inited) {
canQuant = false;
break;
@ -476,8 +446,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
MS_ASSERT(graph->allTensors.size() > outTensorIdx);
auto &outTensor = graph->allTensors.at(outTensorIdx);
MS_ASSERT(outTensor != nullptr);
if (outTensor->quantParams.empty() ||
outTensor->quantParams.front() == nullptr ||
if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
!outTensor->quantParams.front()->inited) {
canQuant = false;
break;

View File

@ -13,13 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <cmath>
#include <string>
#include <algorithm>
#include <memory>
#include <vector>
#include "src/ops/primitive_c.h"
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h"
#include "src/common/utils.h"
#include "abstract/abstract_value.h"
@ -32,7 +33,7 @@ namespace mindspore {
namespace lite {
namespace quant {
const std::array<std::string, 4> QuantStrategy::mConvTypes = {
{"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}};
{"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}};
const std::array<std::string, 4> QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}};
QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
@ -99,10 +100,9 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection,
schema::PrimitiveType_MatMul,
schema::PrimitiveType_Activation};
schema::PrimitiveType_MatMul, schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type);
}
@ -164,8 +164,8 @@ bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
return true;
}
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange,
int quant_max, int quant_min, int num_bits) {
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
int quant_min, int num_bits) {
MS_ASSERT(quantParam != nullptr);
if (mMin > 0.0f) {
MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
@ -216,8 +216,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
return RET_OK;
}
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
bool narrowRange, int numBits) {
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int numBits) {
MS_ASSERT(quantParam != nullptr);
if (mMin > 0.0f) {
MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
@ -246,8 +245,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
return RET_OK;
}
int quantMin = narrowRange ? 1 : 0 - 128;
int quantMax = (1 << (unsigned int) numBits) - 1 - 128;
const int8_t quantMin = std::numeric_limits<int8_t>::min() + (narrowRange ? 1 : 0);
const int8_t quantMax = std::numeric_limits<int8_t>::max();
auto quantMinFloat = static_cast<double>(quantMin);
auto quantMaxFloat = static_cast<double>(quantMax);
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
@ -264,6 +263,9 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
} else {
zeroPoint = static_cast<int32_t>(std::round(zpDouble));
}
if (std::abs(mMin) == std::abs(mMax)) {
zeroPoint = 0;
}
// The zero point should always be in the range of quantized value,
// [qmin, qmax].
MS_ASSERT(zeroPoint >= quantMin);

View File

@ -23,6 +23,7 @@
#include <array>
#include <vector>
#include <algorithm>
#include <limits>
#include "tools/converter/quantizer/quantizer.h"
#include "src/ops/primitive_c.h"
#include "include/errorcode.h"
@ -75,13 +76,15 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
const auto zeroPoint = quantParam->zeroPoint;
const auto numBit = quantParam->numBits;
const auto narrowRange = quantParam->narrowRange;
const double maxLimit = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale;
double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1);
const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<int8_t>::min()) * scale;
double minLimit;
if (narrowRange) {
minLimit = static_cast<float>(1 - zeroPoint) * scale;
minLimit = static_cast<float>(std::numeric_limits<int8_t>::min() + 1 - zeroPoint) * scale;
} else {
minLimit = static_cast<float>(0 - zeroPoint) * scale;
minLimit = static_cast<float>(std::numeric_limits<int8_t>::min() - zeroPoint) * scale;
}
return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
double tmp = 0.0f;
if (originData > maxLimit) {
@ -91,10 +94,7 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
} else {
tmp = originData;
}
auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint));
if (quantData == 0 && narrowRange) {
quantData++;
}
auto quantData = static_cast<T>(std::round(zeroPoint + tmp / scale));
return quantData;
}();
}