forked from OSSInnovation/mindspore
!6962 support mul int8 op
Merge pull request !6962 from cjh9368/support_int8_op
This commit is contained in:
commit
d71cd1e8e7
|
@ -79,7 +79,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
|
|||
schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
|
||||
schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
|
||||
schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul,
|
||||
schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D};
|
||||
schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_Scale};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
|
|
|
@ -106,7 +106,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
|
|||
for (auto graphInputIndex : graph->inputIndex) {
|
||||
auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetInputArrayQP failed";
|
||||
MS_LOG(WARNING) << "SetInputArrayQP failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -121,8 +121,8 @@ STATUS AwareQuantizer::GenerateQuantParam() {
|
|||
}
|
||||
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
|
||||
if (quantParamCalcer == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
||||
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
||||
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
|
||||
} else {
|
||||
auto status = quantParamCalcer->Calc(graph, *node);
|
||||
|
@ -154,7 +154,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
|
||||
auto inputIndexes = node->inputIndex;
|
||||
if (inputIndexes.size() < 2) {
|
||||
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
|
||||
MS_LOG(WARNING) << node->name.c_str() << " node input has invalid inputs tensor count";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// quant weight
|
||||
|
@ -162,7 +162,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) {
|
||||
status = QuantConvWeight(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantConvWeight failed!";
|
||||
MS_LOG(WARNING) << "QuantConvWeight failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -172,7 +172,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) {
|
||||
status = QuantConvBias(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantConvBias failed!";
|
||||
MS_LOG(WARNING) << "QuantConvBias failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -180,13 +180,15 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
|
||||
status = QuantDetectionPostProcessConstTensor(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
|
||||
MS_LOG(WARNING) << "QuantDetectionPostProcessConstTensor failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) {
|
||||
status = QuantAddConstTensor(graph, node.get());
|
||||
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add ||
|
||||
GetCNodeTType(*node) == schema::PrimitiveType_Scale ||
|
||||
GetCNodeTType(*node) == schema::PrimitiveType_Mul) {
|
||||
status = QuantArithmeticConstTensor(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantAddConstTensor failed!";
|
||||
MS_LOG(WARNING) << "QuantArithmeticConstTensor failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -203,7 +205,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
|
||||
STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
for (size_t i = 0; i < node->inputIndex.size(); i++) {
|
||||
|
@ -211,28 +213,40 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche
|
|||
MS_ASSERT(graph->allTensors.size() > inTensorIdx);
|
||||
auto &inTensor = graph->allTensors.at(inTensorIdx);
|
||||
MS_ASSERT(inTensor != nullptr);
|
||||
if (inTensor->refCount == 999) {
|
||||
switch (inTensor->dataType) {
|
||||
case TypeId::kNumberTypeFloat: {
|
||||
auto quantParam = GetTensorQuantParam(inTensor);
|
||||
MS_ASSERT(quantParam != nullptr);
|
||||
MS_ASSERT(quantParam->inited);
|
||||
auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
|
||||
vector<uint8_t> qDatas(constTensorShapeSize);
|
||||
void *inData = inTensor->data.data();
|
||||
auto *castedInData = static_cast<float *>(inData);
|
||||
for (size_t j = 0; j < constTensorShapeSize; j++) {
|
||||
qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
|
||||
}
|
||||
inTensor->data = std::move(qDatas);
|
||||
inTensor->dataType = kNumberTypeUInt8;
|
||||
} break;
|
||||
case kNumberTypeUInt8:
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported dataType: " << inTensor->dataType;
|
||||
return RET_ERROR;
|
||||
if (!inTensor->data.empty()) {
|
||||
if (inTensor->dataType == TypeId::kNumberTypeInt8) {
|
||||
continue;
|
||||
}
|
||||
if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat &&
|
||||
inTensor->dataType != TypeId::kNumberTypeUInt8) {
|
||||
MS_LOG(WARNING) << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto quantParam = GetTensorQuantParam(inTensor);
|
||||
MS_ASSERT(quantParam != nullptr);
|
||||
MS_ASSERT(quantParam->inited);
|
||||
auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
|
||||
vector<int8_t> qDatas(constTensorShapeSize);
|
||||
void *inData = inTensor->data.data();
|
||||
if (inTensor->dataType == TypeId::kNumberTypeFloat ||
|
||||
inTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant
|
||||
auto *weightData = static_cast<float *>(inData);
|
||||
for (size_t j = 0; j < constTensorShapeSize; j++) {
|
||||
qDatas[j] = QuantizeData<int8_t>(weightData[j], quantParam.get());
|
||||
}
|
||||
} else { // tflite awareing quant
|
||||
auto *weightData = static_cast<uint8_t *>(inData);
|
||||
for (size_t j = 0; j < constTensorShapeSize; j++) {
|
||||
qDatas[j] = (int32_t)weightData[j] - 128;
|
||||
}
|
||||
quantParam->zeroPoint -= 128;
|
||||
inTensor->quantParams.clear();
|
||||
inTensor->quantParams.emplace_back(quantParam.release());
|
||||
}
|
||||
|
||||
::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize);
|
||||
inTensor->dataType = TypeId::kNumberTypeInt8;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -245,21 +259,21 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr
|
|||
MS_ASSERT(constTensor != nullptr);
|
||||
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
|
||||
|
||||
if (constTensor->nodeType == schema::NodeType::NodeType_ValueNode &&
|
||||
constTensor->dataType == TypeId::kNumberTypeFloat) {
|
||||
if (!constTensor->data.empty() &&
|
||||
(constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) {
|
||||
size_t constTensorShapeSize = GetShapeSize(*constTensor);
|
||||
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
|
||||
if (quantParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new QuantParamT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
vector<uint8_t> qDatas(constTensorShapeSize);
|
||||
vector<int8_t> qDatas(constTensorShapeSize);
|
||||
for (size_t j = 0; j < constTensorShapeSize; j++) {
|
||||
float rawData = constData[j];
|
||||
qDatas[j] = QuantizeData<uint8_t>(rawData, quantParam.get());
|
||||
qDatas[j] = QuantizeData<int8_t>(rawData, quantParam.get());
|
||||
}
|
||||
constTensor->data = std::move(qDatas);
|
||||
constTensor->dataType = TypeId::kNumberTypeUInt8;
|
||||
::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize);
|
||||
constTensor->dataType = TypeId::kNumberTypeInt8;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -340,13 +354,14 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
|
|||
}
|
||||
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat &&
|
||||
weightTensor->dataType != TypeId::kNumberTypeUInt8) {
|
||||
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
MS_LOG(WARNING) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
|
||||
void *oriWeightData = weightTensor->data.data();
|
||||
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
|
||||
vector<int8_t> qDatas(wShapeSize);
|
||||
// todo support perchannel
|
||||
auto weightQauntParam = GetTensorQuantParam(weightTensor);
|
||||
if (weightTensor->dataType == TypeId::kNumberTypeFloat ||
|
||||
weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant
|
||||
|
|
|
@ -67,7 +67,7 @@ class AwareQuantizer : public FbQuantizer {
|
|||
|
||||
STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph);
|
||||
|
||||
STATUS QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
|
||||
STATUS QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
|
||||
|
||||
STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
|
||||
|
||||
|
|
|
@ -474,6 +474,7 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
|
|||
_registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>();
|
||||
_registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>();
|
||||
_registerMap[schema::PrimitiveType_Mul] = commonCalcer;
|
||||
_registerMap[schema::PrimitiveType_Scale] = commonCalcer;
|
||||
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
|
||||
_registerMap[schema::PrimitiveType_DeConv2D] = commonCalcer;
|
||||
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
|
||||
|
|
Loading…
Reference in New Issue