forked from mindspore-Ecosystem/mindspore
ms model quant param
This commit is contained in:
parent
81833943ba
commit
921e2cdbc2
|
@ -47,7 +47,15 @@ class PrimitiveTValue : public Value {
|
|||
}
|
||||
}
|
||||
|
||||
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {}
|
||||
|
||||
void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
|
||||
this->input_quant_param_ = input_quant_param;
|
||||
}
|
||||
|
||||
void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
|
||||
this->output_quant_param_ = output_quant_param;
|
||||
}
|
||||
|
||||
|
||||
void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
|
||||
this->input_quant_param_.emplace_back(quant_param);
|
||||
|
|
|
@ -37,8 +37,13 @@ int Nchw2NhwcCPUKernel::Run() {
|
|||
auto output = out_tensors_[0];
|
||||
|
||||
if (input->shape().size() == 4) {
|
||||
PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
|
||||
output->Channel());
|
||||
if (input->data_type() == kNumberTypeFloat32) {
|
||||
PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
|
||||
output->Channel());
|
||||
} else if (input->data_type() == kNumberTypeInt8) {
|
||||
PackNCHWToNHWCInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
|
||||
output->Channel());
|
||||
}
|
||||
} else {
|
||||
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
|
||||
}
|
||||
|
@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor
|
|||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -37,8 +37,13 @@ int Nhwc2NchwCPUKernel::Run() {
|
|||
auto output = out_tensors_[0];
|
||||
|
||||
if (input->shape().size() == 4) {
|
||||
PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
|
||||
output->Channel());
|
||||
if (input->data_type() == kNumberTypeFloat32) {
|
||||
PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
|
||||
output->Channel());
|
||||
} else if (input->data_type() == kNumberTypeInt8) {
|
||||
PackNHWCToNCHWInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
|
||||
output->Channel());
|
||||
}
|
||||
} else {
|
||||
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
|
||||
}
|
||||
|
@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor
|
|||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -978,6 +978,19 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
|
|||
return;
|
||||
}
|
||||
|
||||
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
for (int n = 0; n < batch; n++) {
|
||||
for (int c = 0; c < channel; c++) {
|
||||
for (int hw = 0; hw < plane; hw++) {
|
||||
int nhwc_index = n * channel * plane + hw * channel + c;
|
||||
int nchw_index = n * channel * plane + c * plane + hw;
|
||||
((int8_t *)dst)[nchw_index] = ((int8_t *)src)[nhwc_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
return PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
|
||||
}
|
||||
|
|
|
@ -60,6 +60,8 @@ void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int c
|
|||
|
||||
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
|
|
@ -122,8 +122,10 @@ void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, f
|
|||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
|
||||
void AnfConvPopulater::PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
|
||||
auto narrow_range = prim->GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim->GetAttr("num_bits");
|
||||
|
@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
int biasQuantSize = 0;
|
||||
|
@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
|
@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quantParam.min = 0.0;
|
||||
quantParam.max = 0.0;
|
||||
quantParam.zeroPoint = 0;
|
||||
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
|
||||
|
||||
quantParam.scale =
|
||||
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim->GetAttr("output_minq");
|
||||
|
@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecOutputQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
|
|||
PopulaterConv2DSingleGroup(prim, primitive, group);
|
||||
}
|
||||
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
||||
|
||||
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
||||
PopulaterQuantParam(prim, &vecQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
||||
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
||||
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
||||
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
|
||||
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -20,9 +20,10 @@
|
|||
#ifndef MINDSPORE_ANF_CONV_PARSER_H
|
||||
#define MINDSPORE_ANF_CONV_PARSER_H
|
||||
|
||||
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AnfConvPopulater : public AnfNodePopulater {
|
||||
public:
|
||||
|
@ -32,12 +33,18 @@ class AnfConvPopulater : public AnfNodePopulater {
|
|||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
|
||||
private:
|
||||
void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive,
|
||||
const int &group);
|
||||
void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive,
|
||||
const int &group);
|
||||
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
|
||||
void PopulaterConv2DMultiGroup(
|
||||
const PrimitivePtr &prim,
|
||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
|
||||
void PopulaterConv2DSingleGroup(
|
||||
const PrimitivePtr &prim,
|
||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
|
||||
void PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
|
||||
float *mMax);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -31,8 +31,10 @@ void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &
|
|||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
|
||||
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
|
||||
auto narrow_range = prim->GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim->GetAttr("num_bits");
|
||||
|
@ -63,7 +65,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
int biasQuantSize = 0;
|
||||
|
@ -82,7 +84,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
|
@ -90,10 +92,12 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quantParam.min = 0.0;
|
||||
quantParam.max = 0.0;
|
||||
quantParam.zeroPoint = 0;
|
||||
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
|
||||
|
||||
quantParam.scale =
|
||||
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim->GetAttr("output_minq");
|
||||
|
@ -108,7 +112,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecOutputQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -177,10 +181,12 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
|
|||
MS_ASSERT(primitiveTValuePtr != nullptr);
|
||||
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
||||
|
||||
if (primitiveTValuePtr->GetQuantType()) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
||||
PopulaterQuantParam(prim, &vecQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
||||
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
||||
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
||||
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
|
||||
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -28,8 +28,12 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
|
|||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
|
||||
private:
|
||||
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
|
||||
void PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
|
||||
float *mMax);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -30,8 +30,10 @@ void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev,
|
|||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
|
||||
void AnfMatmulPopulater::PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
|
||||
auto narrow_range = prim->GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim->GetAttr("num_bits");
|
||||
|
@ -62,7 +64,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto filterMin = prim->GetAttr("filter_minq");
|
||||
|
@ -79,7 +81,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
|
@ -95,7 +97,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
vecOutputQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,12 +112,13 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim
|
|||
primitive->value.value = attr.release();
|
||||
MS_ASSERT(primitiveTValuePtr != nullptr);
|
||||
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
||||
if (primitiveTValuePtr->GetQuantType()) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
||||
PopulaterQuantParam(prim, &vecQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
||||
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
||||
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
||||
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
|
||||
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
|
||||
|
|
|
@ -26,8 +26,12 @@ class AnfMatmulPopulater : public AnfNodePopulater {
|
|||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
|
||||
private:
|
||||
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
|
||||
void PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
|
||||
float *mMax);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -15,20 +15,22 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/quantizer/aware_quantizer.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/converter/quantizer/calc_quant_param.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/common/converter_op_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/quantizer/calc_quant_param.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
@ -42,7 +44,8 @@ struct InputArray {
|
|||
int numBits = 8;
|
||||
TypeId dataType = TypeId::kTypeUnknown;
|
||||
|
||||
InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) {
|
||||
InputArray(float mean, float stdDev,
|
||||
TypeId dataType = TypeId::kNumberTypeFloat) {
|
||||
this->dataType = dataType;
|
||||
constexpr float qmin = 0;
|
||||
constexpr float qmax = 255;
|
||||
|
@ -52,7 +55,8 @@ struct InputArray {
|
|||
|
||||
STATUS InitQuantParam() {
|
||||
this->quantParam = std::make_unique<schema::QuantParamT>();
|
||||
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits);
|
||||
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax,
|
||||
narrowRange, numBits);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
|
@ -66,7 +70,8 @@ struct InputArray {
|
|||
if (!tensor->quantParams.empty()) {
|
||||
auto param = GetTensorQuantParam(tensor);
|
||||
if (param != nullptr && param->inited) {
|
||||
MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam";
|
||||
MS_LOG(DEBUG) << "tensor " << inputTensorIdx
|
||||
<< " already has quantParam";
|
||||
return RET_OK;
|
||||
}
|
||||
tensor->quantParams.clear();
|
||||
|
@ -83,11 +88,14 @@ struct InputArray {
|
|||
};
|
||||
|
||||
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
|
||||
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
|
||||
schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
|
||||
schema::PrimitiveType_DetectionPostProcess}};
|
||||
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize,
|
||||
schema::PrimitiveType_Reshape, schema::PrimitiveType_Squeeze,
|
||||
schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
|
||||
schema::PrimitiveType_DetectionPostProcess}};
|
||||
|
||||
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues,
|
||||
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
|
||||
const string &inputInferType,
|
||||
const string &stdValues,
|
||||
const string &meanValues)
|
||||
: FbQuantizer(graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
|
@ -110,9 +118,11 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
|
|||
// MS_LOGE("GenerateDefaultQuantParam failed: %d", status);
|
||||
// return RET_ERROR;
|
||||
// }
|
||||
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) {
|
||||
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();
|
||||
// iter++) {
|
||||
// auto *node = (*iter).get();
|
||||
// if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) {
|
||||
// if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars &&
|
||||
// GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) {
|
||||
// continue;
|
||||
// }
|
||||
// auto inputIndexes = node->inputIndex;
|
||||
|
@ -144,41 +154,43 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
|
|||
// auto *maxData = reinterpret_cast<const float *>(tensor2->data.data());
|
||||
// MS_ASSERT(minData != nullptr);
|
||||
// MS_ASSERT(maxData != nullptr);
|
||||
// std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT());
|
||||
// if (quantParam == nullptr) {
|
||||
// std::unique_ptr<QuantParamT> quantParam(new (std::nothrow)
|
||||
// QuantParamT()); if (quantParam == nullptr) {
|
||||
// MS_LOGE("new quantParam failed");
|
||||
// return RET_ERROR;
|
||||
// }
|
||||
// auto realMin = (double)minData[0];
|
||||
// auto realMax = (double)maxData[0];
|
||||
// status = CalQuantizationParams(quantParam.get(), realMin, realMax, narrorRange, numBits);
|
||||
// if (status != RET_OK) {
|
||||
// MS_LOGE("in aware quantization run CalQuantizationParams failed, node: %s", node->name.c_str());
|
||||
// return RET_ERROR;
|
||||
// status = CalQuantizationParams(quantParam.get(), realMin, realMax,
|
||||
// narrorRange, numBits); if (status != RET_OK) {
|
||||
// MS_LOGE("in aware quantization run CalQuantizationParams failed,
|
||||
// node: %s", node->name.c_str()); return RET_ERROR;
|
||||
// }
|
||||
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT) {
|
||||
// CalFakeNode(tensor0, quantParam.get());
|
||||
// }
|
||||
// std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow) QuantParamArrayT());
|
||||
// if (quantParamArray == nullptr) {
|
||||
// std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow)
|
||||
// QuantParamArrayT()); if (quantParamArray == nullptr) {
|
||||
// MS_LOGE("new quantParamArray failed");
|
||||
// return RET_ERROR;
|
||||
// }
|
||||
// quantParamArray->param.push_back(std::move(quantParam));
|
||||
// auto quantParamArrayCopy = CopyQuantParamArrayT(quantParamArray);
|
||||
// if (quantParamArrayCopy == nullptr) {
|
||||
// MS_LOGE("CopyQuantParamArray %s return nullptr", iter->get()->name.c_str());
|
||||
// return RET_ERROR;
|
||||
// MS_LOGE("CopyQuantParamArray %s return nullptr",
|
||||
// iter->get()->name.c_str()); return RET_ERROR;
|
||||
// }
|
||||
// node->quantParam.emplace_back(std::move(quantParamArrayCopy));
|
||||
// node->quantParam.emplace_back(nullptr); // secondInTensor and thirdInTensor are weightTensors who have no
|
||||
// preNode node->quantParam.emplace_back(nullptr); node->quantParam.emplace_back(std::move(quantParamArray));
|
||||
// node->quantParam.emplace_back(nullptr); // secondInTensor and
|
||||
// thirdInTensor are weightTensors who have no preNode
|
||||
// node->quantParam.emplace_back(nullptr);
|
||||
// node->quantParam.emplace_back(std::move(quantParamArray));
|
||||
//
|
||||
// // BroadCast fakeQuantNode QuantParam
|
||||
// status = BroadCastQuantParam(subGraph, *iter);
|
||||
// if (status != RET_OK) {
|
||||
// MS_LOGE("BroadCastQuantParam %s failed: %d", iter->get()->name.c_str(), status);
|
||||
// return status;
|
||||
// MS_LOGE("BroadCastQuantParam %s failed: %d",
|
||||
// iter->get()->name.c_str(), status); return status;
|
||||
// }
|
||||
// // save post node index for SetAttrToConvolution
|
||||
// auto postNodeIdxes = GetOutputNodeIdx(*subGraph, *node);
|
||||
|
@ -189,10 +201,13 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
|
|||
// return RET_ERROR;
|
||||
// }
|
||||
// // set filter param to node
|
||||
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && !postNodeIdxes.empty()) {
|
||||
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT &&
|
||||
// !postNodeIdxes.empty()) {
|
||||
// auto postNode = subGraph->nodes.at(postNodeIdxes.front()).get();
|
||||
// if (GetCNodeTType(*postNode) == OpT_Conv2D || GetCNodeTType(*postNode) == OpT_DepthwiseConv2D ||
|
||||
// GetCNodeTType(*postNode) == OpT_DeConv2D || GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) {
|
||||
// if (GetCNodeTType(*postNode) == OpT_Conv2D ||
|
||||
// GetCNodeTType(*postNode) == OpT_DepthwiseConv2D ||
|
||||
// GetCNodeTType(*postNode) == OpT_DeConv2D ||
|
||||
// GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) {
|
||||
// auto status = SetAttrToConvolution(subGraph.get(), postNode);
|
||||
// if (status != RET_OK) {
|
||||
// MS_LOGE("in aware quant SetAttrToConvolution failed!");
|
||||
|
@ -203,7 +218,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
|
|||
// }
|
||||
//
|
||||
// // remove IsolatedNode
|
||||
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();) {
|
||||
// for (auto iter = subGraph->nodes.begin(); iter !=
|
||||
// subGraph->nodes.end();) {
|
||||
// if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) {
|
||||
// iter = subGraph->nodes.erase(iter);
|
||||
// } else {
|
||||
|
@ -213,8 +229,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
|
|||
// // set graphInputNode inputTensor quantParams
|
||||
// MS_ASSERT(subGraph->inputIndex.size() == 1);
|
||||
// for (auto graphInputIndex : subGraph->inputIndex) {
|
||||
// auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), graphInputIndex);
|
||||
// for (auto nodeIdx : linkedPostIdx) {
|
||||
// auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()),
|
||||
// graphInputIndex); for (auto nodeIdx : linkedPostIdx) {
|
||||
// MS_ASSERT(subGraph->nodes.size() > nodeIdx);
|
||||
// mInputArray->SetInputArrayQP(subGraph->nodes.at(nodeIdx).get());
|
||||
// }
|
||||
|
@ -223,7 +239,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) {
|
||||
STATUS AwareQuantizer::GenerateDefaultQuantParam(
|
||||
const schema::MetaGraphT *subGraph) {
|
||||
MS_ASSERT(subGraph != nullptr);
|
||||
for (const auto &tensor : subGraph->allTensors) {
|
||||
if (!tensor->quantParams.empty()) {
|
||||
|
@ -235,15 +252,18 @@ STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGr
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
|
||||
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph,
|
||||
schema::CNodeT *node) {
|
||||
// MS_ASSERT(subGraph != nullptr);
|
||||
// MS_ASSERT(node != nullptr);
|
||||
// auto inputIndexes = node->inputIndex;
|
||||
// MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == OpT_DepthwiseConv2D ||
|
||||
// GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == OpT_DeDepthwiseConv2D);
|
||||
// MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) ==
|
||||
// OpT_DepthwiseConv2D ||
|
||||
// GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) ==
|
||||
// OpT_DeDepthwiseConv2D);
|
||||
// if (inputIndexes.size() < 2) {
|
||||
// MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", node->name.c_str(), inputIndexes.size());
|
||||
// return RET_ERROR;
|
||||
// MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!",
|
||||
// node->name.c_str(), inputIndexes.size()); return RET_ERROR;
|
||||
// }
|
||||
// TensorDefT *filterTensor = subGraph->allTensors.at(inputIndexes[1]).get();
|
||||
// MS_ASSERT(filterTensor != nullptr);
|
||||
|
@ -267,14 +287,16 @@ STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph,
|
|||
// if (GetCNodeTType(*node) == OpT_DepthwiseConv2D) {
|
||||
// if (node->fmkType == FmkType_MS) {
|
||||
// node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[0];
|
||||
// node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[1];
|
||||
// node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[2];
|
||||
// node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[3];
|
||||
// node->attr.AsDepthwiseConv2D()->channelMultiplier =
|
||||
// (int32_t)filterDims[1]; node->attr.AsDepthwiseConv2D()->kernelH =
|
||||
// (int32_t)filterDims[2]; node->attr.AsDepthwiseConv2D()->kernelW =
|
||||
// (int32_t)filterDims[3];
|
||||
// } else if (node->fmkType == FmkType_TF) {
|
||||
// node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[0];
|
||||
// node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[1];
|
||||
// node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[2];
|
||||
// node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[3];
|
||||
// node->attr.AsDepthwiseConv2D()->channelMultiplier =
|
||||
// (int32_t)filterDims[3];
|
||||
// } else {
|
||||
// MS_LOGE("Unsupport");
|
||||
// }
|
||||
|
@ -313,15 +335,19 @@ STATUS AwareQuantizer::GenerateQuantParam() {
|
|||
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
|
||||
MS_ASSERT(false);
|
||||
}
|
||||
auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
|
||||
auto *quantParamCalcer =
|
||||
quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
|
||||
if (quantParamCalcer == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
||||
MS_LOG(ERROR) << "Can not find QuantParamCalcer for "
|
||||
<< node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str()
|
||||
<< " set node to QuantNone and skip";
|
||||
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
|
||||
} else {
|
||||
status = quantParamCalcer->Calc(graph, *node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
|
||||
MS_LOG(ERROR) << "quantParamCalcer failed: " << status
|
||||
<< " node: " << node->name.c_str();
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
node->quantType = schema::QuantType_AwareTraining;
|
||||
|
@ -345,7 +371,8 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
auto inputIndexes = node->inputIndex;
|
||||
if (inputIndexes.size() < 2) {
|
||||
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
|
||||
MS_LOG(ERROR) << node->name.c_str()
|
||||
<< " node input has invalid inputs tensor count";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// quant weight
|
||||
|
@ -362,7 +389,8 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
|
||||
} else if (GetCNodeTType(*node) ==
|
||||
schema::PrimitiveType_DetectionPostProcess) {
|
||||
status = QuantDetectionPostProcessConstTensor(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
|
||||
|
@ -388,7 +416,8 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
|
||||
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
|
||||
schema::CNodeT *node) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
for (size_t i = 0; i < node->inputIndex.size(); i++) {
|
||||
|
@ -407,7 +436,8 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche
|
|||
void *inData = inTensor->data.data();
|
||||
auto *castedInData = static_cast<float *>(inData);
|
||||
for (size_t j = 0; j < constTensorShapeSize; j++) {
|
||||
qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
|
||||
qDatas[j] =
|
||||
QuantizeData<uint8_t>(castedInData[j], quantParam.get());
|
||||
}
|
||||
inTensor->data = std::move(qDatas);
|
||||
inTensor->dataType = kNumberTypeUInt8;
|
||||
|
@ -423,14 +453,17 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
|
||||
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(
|
||||
const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
|
||||
MS_ASSERT(subGraph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]);
|
||||
MS_ASSERT(constTensor != nullptr);
|
||||
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
|
||||
const auto *constData =
|
||||
reinterpret_cast<const float *>(constTensor->data.data());
|
||||
|
||||
if (constTensor->refCount == 999 && constTensor->dataType == TypeId::kNumberTypeFloat) {
|
||||
if (constTensor->refCount == 999 &&
|
||||
constTensor->dataType == TypeId::kNumberTypeFloat) {
|
||||
size_t constTensorShapeSize = GetShapeSize(*constTensor);
|
||||
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
|
||||
if (quantParam == nullptr) {
|
||||
|
@ -448,7 +481,8 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) {
|
||||
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
|
||||
mindspore::schema::CNodeT *node) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
auto inputIndexes = node->inputIndex;
|
||||
|
@ -507,7 +541,8 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
|
|||
biasTensor->dataType = TypeId::kNumberTypeInt32;
|
||||
biasTensor->data.clear();
|
||||
biasTensor->data.resize(bShapeSize * sizeof(int32_t));
|
||||
auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas, bShapeSize * sizeof(int32_t));
|
||||
auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t),
|
||||
qDatas, bShapeSize * sizeof(int32_t));
|
||||
if (ret != EOK) {
|
||||
// MS_LOGE("memcpy_s failed: %d", ret);
|
||||
return RET_ERROR;
|
||||
|
@ -516,10 +551,12 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
|
||||
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph,
|
||||
schema::CNodeT *node) {
|
||||
MS_ASSERT(subGraph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size());
|
||||
MS_ASSERT(node->quantParam.size() ==
|
||||
node->inputIndex.size() + node->outputIndex.size());
|
||||
auto inputIndexes = node->inputIndex;
|
||||
MS_ASSERT(inputIndexes.size() >= 2);
|
||||
MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1));
|
||||
|
@ -527,8 +564,11 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
|
|||
if (weightTensor->dataType == TypeId::kNumberTypeInt8) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) {
|
||||
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 &&
|
||||
weightTensor->dataType != TypeId::kNumberTypeFloat &&
|
||||
weightTensor->dataType != TypeId::kNumberTypeUInt8) {
|
||||
MS_LOG(ERROR) << "conv " << node->name.c_str()
|
||||
<< "'s weight data is not float or uint8";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
|
||||
|
@ -536,7 +576,8 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
|
|||
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
|
||||
vector<int8_t> qDatas(wShapeSize);
|
||||
auto weightQauntParam = GetTensorQuantParam(weightTensor);
|
||||
if (weightTensor->dataType == TypeId::kNumberTypeFloat) { // normal awareing quant
|
||||
if (weightTensor->dataType ==
|
||||
TypeId::kNumberTypeFloat) { // normal awareing quant
|
||||
auto *weightData = static_cast<float *>(oriWeightData);
|
||||
for (size_t j = 0; j < wShapeSize; j++) {
|
||||
qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
|
||||
|
@ -564,7 +605,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
|
|||
MS_ASSERT(graph->allTensors.size() > inTensorIdx);
|
||||
auto &inTensor = graph->allTensors.at(inTensorIdx);
|
||||
MS_ASSERT(inTensor != nullptr);
|
||||
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr ||
|
||||
if (inTensor->quantParams.empty() ||
|
||||
inTensor->quantParams.front() == nullptr ||
|
||||
!inTensor->quantParams.front()->inited) {
|
||||
canQuant = false;
|
||||
break;
|
||||
|
@ -576,7 +618,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
|
|||
MS_ASSERT(graph->allTensors.size() > outTensorIdx);
|
||||
auto &outTensor = graph->allTensors.at(outTensorIdx);
|
||||
MS_ASSERT(outTensor != nullptr);
|
||||
if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
|
||||
if (outTensor->quantParams.empty() ||
|
||||
outTensor->quantParams.front() == nullptr ||
|
||||
!outTensor->quantParams.front()->inited) {
|
||||
canQuant = false;
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue