MindSpore model quantization parameters: split quant params into separate input and output sets

This commit is contained in:
yankai 2020-08-15 17:54:31 +08:00
parent 81833943ba
commit 921e2cdbc2
12 changed files with 220 additions and 111 deletions

View File

@ -47,7 +47,15 @@ class PrimitiveTValue : public Value {
}
}
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {}
// Replaces the stored input-side quantization parameters with a copy of
// `input_quant_param`: one inner vector of QuantParamT per input tensor.
void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
this->input_quant_param_ = input_quant_param;
}
// Replaces the stored output-side quantization parameters with a copy of
// `output_quant_param`: one inner vector of QuantParamT per output tensor.
void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
this->output_quant_param_ = output_quant_param;
}
void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->input_quant_param_.emplace_back(quant_param);

View File

@ -37,8 +37,13 @@ int Nchw2NhwcCPUKernel::Run() {
auto output = out_tensors_[0];
if (input->shape().size() == 4) {
PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
if (input->data_type() == kNumberTypeFloat32) {
PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNCHWToNHWCInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
}
} else {
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
}
@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -37,8 +37,13 @@ int Nhwc2NchwCPUKernel::Run() {
auto output = out_tensors_[0];
if (input->shape().size() == 4) {
PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
if (input->data_type() == kNumberTypeFloat32) {
PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNHWCToNCHWInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
}
} else {
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
}
@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -978,6 +978,19 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
return;
}
// Transposes an int8 tensor from NHWC layout to NCHW layout.
// src: input buffer of batch * plane * channel int8 values in NHWC order
//      (plane = H * W, i.e. the spatial size).
// dst: output buffer of the same element count, written in NCHW order.
// Fix vs. original: the full linear index was recomputed from scratch for
// every element and the function ended with a redundant `return;`. The
// batch offset and the per-channel base pointers are loop invariants, so
// they are hoisted out of the inner loop; behavior is unchanged.
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) {
  const int8_t *src_i8 = (const int8_t *)src;
  int8_t *dst_i8 = (int8_t *)dst;
  for (int n = 0; n < batch; n++) {
    const int batch_offset = n * channel * plane;  // start of image n in both layouts
    for (int c = 0; c < channel; c++) {
      // In NHWC, channel c is a strided column; in NCHW it is a contiguous row.
      const int8_t *src_c = src_i8 + batch_offset + c;
      int8_t *dst_c = dst_i8 + batch_offset + c * plane;
      for (int hw = 0; hw < plane; hw++) {
        dst_c[hw] = src_c[hw * channel];
      }
    }
  }
}
// Transposes a float tensor from NCHW to NHWC layout by delegating to the
// inverse transpose with `plane` and `channel` swapped: an NCHW->NHWC
// permutation of (batch, C, HW) is exactly an NHWC->NCHW permutation of a
// tensor whose "plane" is C and whose "channel" is HW.
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
return PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
}

View File

@ -60,6 +60,8 @@ void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int c
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);

View File

@ -122,8 +122,10 @@ void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, f
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
void AnfConvPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quantParam.scale =
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim->GetAttr("output_minq");
@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecOutputQuantParam->emplace_back(quants);
}
}
@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
PopulaterConv2DSingleGroup(prim, primitive, group);
}
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
}
return 0;
}

View File

@ -20,9 +20,10 @@
#ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
#include <memory>
#include <vector>
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
namespace mindspore::lite {
class AnfConvPopulater : public AnfNodePopulater {
public:
@ -32,12 +33,18 @@ class AnfConvPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive,
const int &group);
void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive,
const int &group);
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
void PopulaterConv2DMultiGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
void PopulaterConv2DSingleGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
void PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite

View File

@ -31,8 +31,10 @@ void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
@ -63,7 +65,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
@ -82,7 +84,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
@ -90,10 +92,12 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quantParam.scale =
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim->GetAttr("output_minq");
@ -108,7 +112,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecOutputQuantParam->emplace_back(quants);
}
}
@ -177,10 +181,12 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
}
return 0;
}

View File

@ -28,8 +28,12 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
void PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite

View File

@ -30,8 +30,10 @@ void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev,
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
void AnfMatmulPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
@ -62,7 +64,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto filterMin = prim->GetAttr("filter_minq");
@ -79,7 +81,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
@ -95,7 +97,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecOutputQuantParam->emplace_back(quants);
}
}
@ -110,12 +112,13 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
}
return 0;
}
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());

View File

@ -26,8 +26,12 @@ class AnfMatmulPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
void PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite

View File

@ -15,20 +15,22 @@
*/
#include "tools/converter/quantizer/aware_quantizer.h"
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "schema/inner/model_generated.h"
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/calc_quant_param.h"
#include "tools/common/tensor_util.h"
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/calc_quant_param.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "utils/log_adapter.h"
using std::string;
using std::vector;
@ -42,7 +44,8 @@ struct InputArray {
int numBits = 8;
TypeId dataType = TypeId::kTypeUnknown;
InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) {
InputArray(float mean, float stdDev,
TypeId dataType = TypeId::kNumberTypeFloat) {
this->dataType = dataType;
constexpr float qmin = 0;
constexpr float qmax = 255;
@ -52,7 +55,8 @@ struct InputArray {
STATUS InitQuantParam() {
this->quantParam = std::make_unique<schema::QuantParamT>();
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits);
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax,
narrowRange, numBits);
if (status != RET_OK) {
return status;
}
@ -66,7 +70,8 @@ struct InputArray {
if (!tensor->quantParams.empty()) {
auto param = GetTensorQuantParam(tensor);
if (param != nullptr && param->inited) {
MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam";
MS_LOG(DEBUG) << "tensor " << inputTensorIdx
<< " already has quantParam";
return RET_OK;
}
tensor->quantParams.clear();
@ -83,11 +88,14 @@ struct InputArray {
};
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
schema::PrimitiveType_DetectionPostProcess}};
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize,
schema::PrimitiveType_Reshape, schema::PrimitiveType_Squeeze,
schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
schema::PrimitiveType_DetectionPostProcess}};
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues,
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
const string &inputInferType,
const string &stdValues,
const string &meanValues)
: FbQuantizer(graph) {
MS_ASSERT(graph != nullptr);
@ -110,9 +118,11 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
// MS_LOGE("GenerateDefaultQuantParam failed: %d", status);
// return RET_ERROR;
// }
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) {
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();
// iter++) {
// auto *node = (*iter).get();
// if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) {
// if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars &&
// GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) {
// continue;
// }
// auto inputIndexes = node->inputIndex;
@ -144,41 +154,43 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
// auto *maxData = reinterpret_cast<const float *>(tensor2->data.data());
// MS_ASSERT(minData != nullptr);
// MS_ASSERT(maxData != nullptr);
// std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT());
// if (quantParam == nullptr) {
// std::unique_ptr<QuantParamT> quantParam(new (std::nothrow)
// QuantParamT()); if (quantParam == nullptr) {
// MS_LOGE("new quantParam failed");
// return RET_ERROR;
// }
// auto realMin = (double)minData[0];
// auto realMax = (double)maxData[0];
// status = CalQuantizationParams(quantParam.get(), realMin, realMax, narrorRange, numBits);
// if (status != RET_OK) {
// MS_LOGE("in aware quantization run CalQuantizationParams failed, node: %s", node->name.c_str());
// return RET_ERROR;
// status = CalQuantizationParams(quantParam.get(), realMin, realMax,
// narrorRange, numBits); if (status != RET_OK) {
// MS_LOGE("in aware quantization run CalQuantizationParams failed,
// node: %s", node->name.c_str()); return RET_ERROR;
// }
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT) {
// CalFakeNode(tensor0, quantParam.get());
// }
// std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow) QuantParamArrayT());
// if (quantParamArray == nullptr) {
// std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow)
// QuantParamArrayT()); if (quantParamArray == nullptr) {
// MS_LOGE("new quantParamArray failed");
// return RET_ERROR;
// }
// quantParamArray->param.push_back(std::move(quantParam));
// auto quantParamArrayCopy = CopyQuantParamArrayT(quantParamArray);
// if (quantParamArrayCopy == nullptr) {
// MS_LOGE("CopyQuantParamArray %s return nullptr", iter->get()->name.c_str());
// return RET_ERROR;
// MS_LOGE("CopyQuantParamArray %s return nullptr",
// iter->get()->name.c_str()); return RET_ERROR;
// }
// node->quantParam.emplace_back(std::move(quantParamArrayCopy));
// node->quantParam.emplace_back(nullptr); // secondInTensor and thirdInTensor are weightTensors who have no
// preNode node->quantParam.emplace_back(nullptr); node->quantParam.emplace_back(std::move(quantParamArray));
// node->quantParam.emplace_back(nullptr); // secondInTensor and
// thirdInTensor are weightTensors who have no preNode
// node->quantParam.emplace_back(nullptr);
// node->quantParam.emplace_back(std::move(quantParamArray));
//
// // BroadCast fakeQuantNode QuantParam
// status = BroadCastQuantParam(subGraph, *iter);
// if (status != RET_OK) {
// MS_LOGE("BroadCastQuantParam %s failed: %d", iter->get()->name.c_str(), status);
// return status;
// MS_LOGE("BroadCastQuantParam %s failed: %d",
// iter->get()->name.c_str(), status); return status;
// }
// // save post node index for SetAttrToConvolution
// auto postNodeIdxes = GetOutputNodeIdx(*subGraph, *node);
@ -189,10 +201,13 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
// return RET_ERROR;
// }
// // set filter param to node
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && !postNodeIdxes.empty()) {
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT &&
// !postNodeIdxes.empty()) {
// auto postNode = subGraph->nodes.at(postNodeIdxes.front()).get();
// if (GetCNodeTType(*postNode) == OpT_Conv2D || GetCNodeTType(*postNode) == OpT_DepthwiseConv2D ||
// GetCNodeTType(*postNode) == OpT_DeConv2D || GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) {
// if (GetCNodeTType(*postNode) == OpT_Conv2D ||
// GetCNodeTType(*postNode) == OpT_DepthwiseConv2D ||
// GetCNodeTType(*postNode) == OpT_DeConv2D ||
// GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) {
// auto status = SetAttrToConvolution(subGraph.get(), postNode);
// if (status != RET_OK) {
// MS_LOGE("in aware quant SetAttrToConvolution failed!");
@ -203,7 +218,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
// }
//
// // remove IsolatedNode
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();) {
// for (auto iter = subGraph->nodes.begin(); iter !=
// subGraph->nodes.end();) {
// if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) {
// iter = subGraph->nodes.erase(iter);
// } else {
@ -213,8 +229,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
// // set graphInputNode inputTensor quantParams
// MS_ASSERT(subGraph->inputIndex.size() == 1);
// for (auto graphInputIndex : subGraph->inputIndex) {
// auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), graphInputIndex);
// for (auto nodeIdx : linkedPostIdx) {
// auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()),
// graphInputIndex); for (auto nodeIdx : linkedPostIdx) {
// MS_ASSERT(subGraph->nodes.size() > nodeIdx);
// mInputArray->SetInputArrayQP(subGraph->nodes.at(nodeIdx).get());
// }
@ -223,7 +239,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() {
return RET_OK;
}
STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) {
STATUS AwareQuantizer::GenerateDefaultQuantParam(
const schema::MetaGraphT *subGraph) {
MS_ASSERT(subGraph != nullptr);
for (const auto &tensor : subGraph->allTensors) {
if (!tensor->quantParams.empty()) {
@ -235,15 +252,18 @@ STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGr
return RET_OK;
}
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph,
schema::CNodeT *node) {
// MS_ASSERT(subGraph != nullptr);
// MS_ASSERT(node != nullptr);
// auto inputIndexes = node->inputIndex;
// MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == OpT_DepthwiseConv2D ||
// GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == OpT_DeDepthwiseConv2D);
// MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) ==
// OpT_DepthwiseConv2D ||
// GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) ==
// OpT_DeDepthwiseConv2D);
// if (inputIndexes.size() < 2) {
// MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", node->name.c_str(), inputIndexes.size());
// return RET_ERROR;
// MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!",
// node->name.c_str(), inputIndexes.size()); return RET_ERROR;
// }
// TensorDefT *filterTensor = subGraph->allTensors.at(inputIndexes[1]).get();
// MS_ASSERT(filterTensor != nullptr);
@ -267,14 +287,16 @@ STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph,
// if (GetCNodeTType(*node) == OpT_DepthwiseConv2D) {
// if (node->fmkType == FmkType_MS) {
// node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[0];
// node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[1];
// node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[2];
// node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[3];
// node->attr.AsDepthwiseConv2D()->channelMultiplier =
// (int32_t)filterDims[1]; node->attr.AsDepthwiseConv2D()->kernelH =
// (int32_t)filterDims[2]; node->attr.AsDepthwiseConv2D()->kernelW =
// (int32_t)filterDims[3];
// } else if (node->fmkType == FmkType_TF) {
// node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[0];
// node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[1];
// node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[2];
// node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[3];
// node->attr.AsDepthwiseConv2D()->channelMultiplier =
// (int32_t)filterDims[3];
// } else {
// MS_LOGE("Unsupport");
// }
@ -313,15 +335,19 @@ STATUS AwareQuantizer::GenerateQuantParam() {
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
MS_ASSERT(false);
}
auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
auto *quantParamCalcer =
quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
if (quantParamCalcer == nullptr) {
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
MS_LOG(ERROR) << "Can not find QuantParamCalcer for "
<< node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str()
<< " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
} else {
status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
MS_LOG(ERROR) << "quantParamCalcer failed: " << status
<< " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
} else {
node->quantType = schema::QuantType_AwareTraining;
@ -345,7 +371,8 @@ STATUS AwareQuantizer::DoQuantize() {
GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) {
auto inputIndexes = node->inputIndex;
if (inputIndexes.size() < 2) {
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
MS_LOG(ERROR) << node->name.c_str()
<< " node input has invalid inputs tensor count";
return RET_ERROR;
}
// quant weight
@ -362,7 +389,8 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_ERROR;
}
}
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
} else if (GetCNodeTType(*node) ==
schema::PrimitiveType_DetectionPostProcess) {
status = QuantDetectionPostProcessConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
@ -388,7 +416,8 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_OK;
}
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
schema::CNodeT *node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
for (size_t i = 0; i < node->inputIndex.size(); i++) {
@ -407,7 +436,8 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche
void *inData = inTensor->data.data();
auto *castedInData = static_cast<float *>(inData);
for (size_t j = 0; j < constTensorShapeSize; j++) {
qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
qDatas[j] =
QuantizeData<uint8_t>(castedInData[j], quantParam.get());
}
inTensor->data = std::move(qDatas);
inTensor->dataType = kNumberTypeUInt8;
@ -423,14 +453,17 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche
return RET_OK;
}
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(
const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
MS_ASSERT(subGraph != nullptr);
MS_ASSERT(node != nullptr);
auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]);
MS_ASSERT(constTensor != nullptr);
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
const auto *constData =
reinterpret_cast<const float *>(constTensor->data.data());
if (constTensor->refCount == 999 && constTensor->dataType == TypeId::kNumberTypeFloat) {
if (constTensor->refCount == 999 &&
constTensor->dataType == TypeId::kNumberTypeFloat) {
size_t constTensorShapeSize = GetShapeSize(*constTensor);
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
if (quantParam == nullptr) {
@ -448,7 +481,8 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr
return RET_OK;
}
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) {
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
mindspore::schema::CNodeT *node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
auto inputIndexes = node->inputIndex;
@ -507,7 +541,8 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
biasTensor->dataType = TypeId::kNumberTypeInt32;
biasTensor->data.clear();
biasTensor->data.resize(bShapeSize * sizeof(int32_t));
auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas, bShapeSize * sizeof(int32_t));
auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t),
qDatas, bShapeSize * sizeof(int32_t));
if (ret != EOK) {
// MS_LOGE("memcpy_s failed: %d", ret);
return RET_ERROR;
@ -516,10 +551,12 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
return RET_OK;
}
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph,
schema::CNodeT *node) {
MS_ASSERT(subGraph != nullptr);
MS_ASSERT(node != nullptr);
MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size());
MS_ASSERT(node->quantParam.size() ==
node->inputIndex.size() + node->outputIndex.size());
auto inputIndexes = node->inputIndex;
MS_ASSERT(inputIndexes.size() >= 2);
MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1));
@ -527,8 +564,11 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
if (weightTensor->dataType == TypeId::kNumberTypeInt8) {
return RET_OK;
}
if (weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) {
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 &&
weightTensor->dataType != TypeId::kNumberTypeFloat &&
weightTensor->dataType != TypeId::kNumberTypeUInt8) {
MS_LOG(ERROR) << "conv " << node->name.c_str()
<< "'s weight data is not float or uint8";
return RET_ERROR;
}
size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
@ -536,7 +576,8 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
vector<int8_t> qDatas(wShapeSize);
auto weightQauntParam = GetTensorQuantParam(weightTensor);
if (weightTensor->dataType == TypeId::kNumberTypeFloat) { // normal awareing quant
if (weightTensor->dataType ==
TypeId::kNumberTypeFloat) { // normal awareing quant
auto *weightData = static_cast<float *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
@ -564,7 +605,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
MS_ASSERT(graph->allTensors.size() > inTensorIdx);
auto &inTensor = graph->allTensors.at(inTensorIdx);
MS_ASSERT(inTensor != nullptr);
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr ||
if (inTensor->quantParams.empty() ||
inTensor->quantParams.front() == nullptr ||
!inTensor->quantParams.front()->inited) {
canQuant = false;
break;
@ -576,7 +618,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
MS_ASSERT(graph->allTensors.size() > outTensorIdx);
auto &outTensor = graph->allTensors.at(outTensorIdx);
MS_ASSERT(outTensor != nullptr);
if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
if (outTensor->quantParams.empty() ||
outTensor->quantParams.front() == nullptr ||
!outTensor->quantParams.front()->inited) {
canQuant = false;
break;