!6962 support mul int8 op

Merge pull request !6962 from cjh9368/support_int8_op
This commit is contained in:
mindspore-ci-bot 2020-09-28 15:34:43 +08:00 committed by Gitee
commit d71cd1e8e7
4 changed files with 58 additions and 41 deletions

View File

@ -79,7 +79,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul,
schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D};
schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_Scale};
static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,

View File

@ -106,7 +106,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
for (auto graphInputIndex : graph->inputIndex) {
auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetInputArrayQP failed";
MS_LOG(WARNING) << "SetInputArrayQP failed";
return status;
}
}
@ -121,8 +121,8 @@ STATUS AwareQuantizer::GenerateQuantParam() {
}
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
if (quantParamCalcer == nullptr) {
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
} else {
auto status = quantParamCalcer->Calc(graph, *node);
@ -154,7 +154,7 @@ STATUS AwareQuantizer::DoQuantize() {
GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
auto inputIndexes = node->inputIndex;
if (inputIndexes.size() < 2) {
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
MS_LOG(WARNING) << node->name.c_str() << " node input has invalid inputs tensor count";
return RET_ERROR;
}
// quant weight
@ -162,7 +162,7 @@ STATUS AwareQuantizer::DoQuantize() {
if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) {
status = QuantConvWeight(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvWeight failed!";
MS_LOG(WARNING) << "QuantConvWeight failed!";
return RET_ERROR;
}
}
@ -172,7 +172,7 @@ STATUS AwareQuantizer::DoQuantize() {
if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) {
status = QuantConvBias(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvBias failed!";
MS_LOG(WARNING) << "QuantConvBias failed!";
return RET_ERROR;
}
}
@ -180,13 +180,15 @@ STATUS AwareQuantizer::DoQuantize() {
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
status = QuantDetectionPostProcessConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
MS_LOG(WARNING) << "QuantDetectionPostProcessConstTensor failed!";
return RET_ERROR;
}
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) {
status = QuantAddConstTensor(graph, node.get());
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add ||
GetCNodeTType(*node) == schema::PrimitiveType_Scale ||
GetCNodeTType(*node) == schema::PrimitiveType_Mul) {
status = QuantArithmeticConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantAddConstTensor failed!";
MS_LOG(WARNING) << "QuantArithmeticConstTensor failed!";
return RET_ERROR;
}
}
@ -203,7 +205,7 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_OK;
}
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
for (size_t i = 0; i < node->inputIndex.size(); i++) {
@ -211,28 +213,40 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche
MS_ASSERT(graph->allTensors.size() > inTensorIdx);
auto &inTensor = graph->allTensors.at(inTensorIdx);
MS_ASSERT(inTensor != nullptr);
if (inTensor->refCount == 999) {
switch (inTensor->dataType) {
case TypeId::kNumberTypeFloat: {
auto quantParam = GetTensorQuantParam(inTensor);
MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited);
auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
vector<uint8_t> qDatas(constTensorShapeSize);
void *inData = inTensor->data.data();
auto *castedInData = static_cast<float *>(inData);
for (size_t j = 0; j < constTensorShapeSize; j++) {
qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
}
inTensor->data = std::move(qDatas);
inTensor->dataType = kNumberTypeUInt8;
} break;
case kNumberTypeUInt8:
break;
default:
MS_LOG(ERROR) << "Unsupported dataType: " << inTensor->dataType;
return RET_ERROR;
if (!inTensor->data.empty()) {
if (inTensor->dataType == TypeId::kNumberTypeInt8) {
continue;
}
if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat &&
inTensor->dataType != TypeId::kNumberTypeUInt8) {
MS_LOG(WARNING) << node->name.c_str() << "'s weight data is not float or uint8";
return RET_ERROR;
}
auto quantParam = GetTensorQuantParam(inTensor);
MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited);
auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
vector<int8_t> qDatas(constTensorShapeSize);
void *inData = inTensor->data.data();
if (inTensor->dataType == TypeId::kNumberTypeFloat ||
inTensor->dataType == TypeId::kNumberTypeFloat32) {  // normal aware quantization
auto *weightData = static_cast<float *>(inData);
for (size_t j = 0; j < constTensorShapeSize; j++) {
qDatas[j] = QuantizeData<int8_t>(weightData[j], quantParam.get());
}
} else {  // TFLite aware quantization
auto *weightData = static_cast<uint8_t *>(inData);
for (size_t j = 0; j < constTensorShapeSize; j++) {
qDatas[j] = (int32_t)weightData[j] - 128;
}
quantParam->zeroPoint -= 128;
inTensor->quantParams.clear();
inTensor->quantParams.emplace_back(quantParam.release());
}
::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize);
inTensor->dataType = TypeId::kNumberTypeInt8;
}
}
return RET_OK;
@ -245,21 +259,21 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr
MS_ASSERT(constTensor != nullptr);
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
if (constTensor->nodeType == schema::NodeType::NodeType_ValueNode &&
constTensor->dataType == TypeId::kNumberTypeFloat) {
if (!constTensor->data.empty() &&
(constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) {
size_t constTensorShapeSize = GetShapeSize(*constTensor);
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
if (quantParam == nullptr) {
MS_LOG(ERROR) << "new QuantParamT failed";
return RET_NULL_PTR;
}
vector<uint8_t> qDatas(constTensorShapeSize);
vector<int8_t> qDatas(constTensorShapeSize);
for (size_t j = 0; j < constTensorShapeSize; j++) {
float rawData = constData[j];
qDatas[j] = QuantizeData<uint8_t>(rawData, quantParam.get());
qDatas[j] = QuantizeData<int8_t>(rawData, quantParam.get());
}
constTensor->data = std::move(qDatas);
constTensor->dataType = TypeId::kNumberTypeUInt8;
::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize);
constTensor->dataType = TypeId::kNumberTypeInt8;
}
return RET_OK;
}
@ -340,13 +354,14 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
}
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat &&
weightTensor->dataType != TypeId::kNumberTypeUInt8) {
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
MS_LOG(WARNING) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
return RET_ERROR;
}
size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
void *oriWeightData = weightTensor->data.data();
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
vector<int8_t> qDatas(wShapeSize);
// TODO: support per-channel quantization
auto weightQauntParam = GetTensorQuantParam(weightTensor);
if (weightTensor->dataType == TypeId::kNumberTypeFloat ||
weightTensor->dataType == TypeId::kNumberTypeFloat32) {  // normal aware quantization

View File

@ -67,7 +67,7 @@ class AwareQuantizer : public FbQuantizer {
STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph);
STATUS QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
STATUS QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node);

View File

@ -474,6 +474,7 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
_registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>();
_registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>();
_registerMap[schema::PrimitiveType_Mul] = commonCalcer;
_registerMap[schema::PrimitiveType_Scale] = commonCalcer;
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
_registerMap[schema::PrimitiveType_DeConv2D] = commonCalcer;
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;