forked from mindspore-Ecosystem/mindspore
!5513 add aware quant converter
Merge pull request !5513 from cjh9368/aware_quant
This commit is contained in:
commit
25d5423640
|
@ -18,10 +18,8 @@
|
|||
#define __STDC_FORMAT_MACROS
|
||||
#include <cinttypes>
|
||||
#undef __STDC_FORMAT_MACROS
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <cfloat>
|
||||
#include "src/common/common.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/context.h"
|
||||
|
@ -167,71 +165,6 @@ int Benchmark::ReadCalibData() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// tensorData need to be converter first
|
||||
float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msShape, float *msTensorData) {
|
||||
auto iter = this->calibData.find(nodeName);
|
||||
if (iter != this->calibData.end()) {
|
||||
std::vector<size_t> castedMSShape;
|
||||
size_t shapeSize = 1;
|
||||
for (int64_t dim : msShape) {
|
||||
castedMSShape.push_back(size_t(dim));
|
||||
shapeSize *= dim;
|
||||
}
|
||||
|
||||
CheckTensor *calibTensor = iter->second;
|
||||
if (calibTensor->shape != castedMSShape) {
|
||||
std::ostringstream oss;
|
||||
oss << "Shape of mslite output(";
|
||||
for (auto dim : castedMSShape) {
|
||||
oss << dim << ",";
|
||||
}
|
||||
oss << ") and shape source model output(";
|
||||
for (auto dim : calibTensor->shape) {
|
||||
oss << dim << ",";
|
||||
}
|
||||
oss << ") are different";
|
||||
std::cerr << oss.str() << std::endl;
|
||||
MS_LOG(ERROR) << oss.str().c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t errorCount = 0;
|
||||
float meanError = 0;
|
||||
std::cout << "Data of node " << nodeName << " : ";
|
||||
for (size_t j = 0; j < shapeSize; j++) {
|
||||
if (j < 50) {
|
||||
std::cout << msTensorData[j] << " ";
|
||||
}
|
||||
|
||||
if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
|
||||
std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
|
||||
MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j));
|
||||
auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j));
|
||||
if (absoluteError > tolerance) {
|
||||
// just assume that atol = rtol
|
||||
meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
|
||||
errorCount++;
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
if (meanError > 0.0f) {
|
||||
meanError /= errorCount;
|
||||
}
|
||||
|
||||
if (meanError <= 0.0000001) {
|
||||
std::cout << "Mean bias of node " << nodeName << " : 0%" << std::endl;
|
||||
} else {
|
||||
std::cout << "Mean bias of node " << nodeName << " : " << meanError * 100 << "%" << std::endl;
|
||||
}
|
||||
return meanError;
|
||||
} else {
|
||||
MS_LOG(INFO) << "%s is not in Source Model output", nodeName.c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
int Benchmark::CompareOutput() {
|
||||
std::cout << "================ Comparing Output data ================" << std::endl;
|
||||
|
@ -255,7 +188,24 @@ int Benchmark::CompareOutput() {
|
|||
auto &tensor = tensors.front();
|
||||
MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT);
|
||||
MS_ASSERT(tensor->GetData() != nullptr);
|
||||
float bias = CompareData(nodeName, tensor->shape(), static_cast<float *>(tensor->MutableData()));
|
||||
float bias = 0;
|
||||
switch (msCalibDataType) {
|
||||
case TypeId::kNumberTypeFloat: {
|
||||
bias = CompareData<float>(nodeName, tensor->shape(), static_cast<float *>(tensor->MutableData()));
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt8: {
|
||||
bias = CompareData<int8_t>(nodeName, tensor->shape(), static_cast<int8_t *>(tensor->MutableData()));
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt32: {
|
||||
bias = CompareData<int32_t>(nodeName, tensor->shape(), static_cast<int32_t *>(tensor->MutableData()));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Datatype " << msCalibDataType << " is not supported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (bias >= 0) {
|
||||
totalBias += bias;
|
||||
totalSize++;
|
||||
|
@ -343,14 +293,26 @@ int Benchmark::MarkAccuracy() {
|
|||
MS_LOG(INFO) << "MarkAccuracy";
|
||||
std::cout << "MarkAccuracy" << std::endl;
|
||||
for (size_t i = 0; i < msInputs.size(); i++) {
|
||||
MS_ASSERT(msInputs.at(i) != nullptr);
|
||||
MS_ASSERT(msInputs.at(i)->data_type() == TypeId::kNumberTypeFloat32);
|
||||
auto inData = reinterpret_cast<float *>(msInputs.at(i)->MutableData());
|
||||
std::cout << "InData" << i << ": ";
|
||||
for (size_t j = 0; j < 20; j++) {
|
||||
std::cout << inData[j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
switch (msInputs.at(i)->data_type()) {
|
||||
case TypeId::kNumberTypeFloat:
|
||||
PrintInputData<float>(msInputs.at(i));
|
||||
break;
|
||||
case TypeId::kNumberTypeFloat32:
|
||||
PrintInputData<float>(msInputs.at(i));
|
||||
break;
|
||||
case TypeId::kNumberTypeInt8:
|
||||
PrintInputData<int8_t>(msInputs.at(i));
|
||||
break;
|
||||
case TypeId::kNumberTypeUInt8:
|
||||
PrintInputData<uint8_t>(msInputs.at(i));
|
||||
break;
|
||||
case TypeId::kNumberTypeInt32:
|
||||
PrintInputData<int>(msInputs.at(i));
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Datatype " << msInputs.at(i)->data_type() << " is not supported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto status = session->RunGraph();
|
||||
if (status != RET_OK) {
|
||||
|
@ -555,6 +517,16 @@ int Benchmark::Init() {
|
|||
|
||||
this->_flags->inDataType = this->_flags->inDataTypeIn == "img" ? kImage : kBinary;
|
||||
|
||||
if (!_flags->calibDataType.empty()) {
|
||||
if (dataTypeMap.find(_flags->calibDataType) == dataTypeMap.end()) {
|
||||
MS_LOG(ERROR) << "CalibDataType not supported: " << _flags->calibDataType.c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
msCalibDataType = dataTypeMap.at(_flags->calibDataType);
|
||||
MS_LOG(INFO) << "CalibDataType = " << _flags->calibDataType.c_str();
|
||||
std::cout << "CalibDataType = " << _flags->calibDataType.c_str() << std::endl;
|
||||
}
|
||||
|
||||
if (_flags->modelPath.empty()) {
|
||||
MS_LOG(ERROR) << "modelPath is required";
|
||||
std::cerr << "modelPath is required" << std::endl;
|
||||
|
|
|
@ -23,9 +23,11 @@
|
|||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <cfloat>
|
||||
#include "include/model.h"
|
||||
#include "tools/common/flag_parser.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
@ -64,6 +66,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3);
|
||||
// MarkAccuracy
|
||||
AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", "");
|
||||
AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8", "FLOAT");
|
||||
AddFlag(&BenchmarkFlags::accuracyThreshold, "accuracyThreshold", "Threshold of accuracy", 0.5);
|
||||
}
|
||||
|
||||
|
@ -88,6 +91,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
int warmUpLoopCount;
|
||||
// MarkAccuracy
|
||||
std::string calibDataPath;
|
||||
std::string calibDataType;
|
||||
float accuracyThreshold;
|
||||
// Resize
|
||||
std::string resizeDimsIn = "";
|
||||
|
@ -121,7 +125,85 @@ class MS_API Benchmark {
|
|||
|
||||
int CompareOutput();
|
||||
|
||||
float CompareData(const std::string &nodeName, std::vector<int> msShape, float *msTensorData);
|
||||
template <typename T>
|
||||
void PrintInputData(tensor::MSTensor *input) {
|
||||
MS_ASSERT(input != nullptr);
|
||||
static int i = 0;
|
||||
auto inData = reinterpret_cast<T *>(input->MutableData());
|
||||
std::cout << "InData" << i++ << ": ";
|
||||
// int printSize = std::min(20, input->ElementsNum());
|
||||
for (size_t j = 0; j < 20; j++) {
|
||||
std::cout << static_cast<float >(inData[j]) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
// tensorData need to be converter first
|
||||
template <typename T>
|
||||
float CompareData(const std::string &nodeName, std::vector<int> msShape, T *msTensorData) {
|
||||
auto iter = this->calibData.find(nodeName);
|
||||
if (iter != this->calibData.end()) {
|
||||
std::vector<size_t> castedMSShape;
|
||||
size_t shapeSize = 1;
|
||||
for (int64_t dim : msShape) {
|
||||
castedMSShape.push_back(size_t(dim));
|
||||
shapeSize *= dim;
|
||||
}
|
||||
|
||||
CheckTensor *calibTensor = iter->second;
|
||||
if (calibTensor->shape != castedMSShape) {
|
||||
std::ostringstream oss;
|
||||
oss << "Shape of mslite output(";
|
||||
for (auto dim : castedMSShape) {
|
||||
oss << dim << ",";
|
||||
}
|
||||
oss << ") and shape source model output(";
|
||||
for (auto dim : calibTensor->shape) {
|
||||
oss << dim << ",";
|
||||
}
|
||||
oss << ") are different";
|
||||
std::cerr << oss.str() << std::endl;
|
||||
MS_LOG(ERROR) << oss.str().c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t errorCount = 0;
|
||||
float meanError = 0;
|
||||
std::cout << "Data of node " << nodeName << " : ";
|
||||
for (size_t j = 0; j < shapeSize; j++) {
|
||||
if (j < 50) {
|
||||
std::cout << static_cast<float>(msTensorData[j]) << " ";
|
||||
}
|
||||
|
||||
if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
|
||||
std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
|
||||
MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j));
|
||||
auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j));
|
||||
if (absoluteError > tolerance) {
|
||||
// just assume that atol = rtol
|
||||
meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
|
||||
errorCount++;
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
if (meanError > 0.0f) {
|
||||
meanError /= errorCount;
|
||||
}
|
||||
|
||||
if (meanError <= 0.0000001) {
|
||||
std::cout << "Mean bias of node " << nodeName << " : 0%" << std::endl;
|
||||
} else {
|
||||
std::cout << "Mean bias of node " << nodeName << " : " << meanError * 100 << "%" << std::endl;
|
||||
}
|
||||
return meanError;
|
||||
} else {
|
||||
MS_LOG(INFO) << "%s is not in Source Model output", nodeName.c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
int MarkPerformance();
|
||||
|
||||
|
@ -133,6 +215,10 @@ class MS_API Benchmark {
|
|||
std::vector<mindspore::tensor::MSTensor *> msInputs;
|
||||
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> msOutputs;
|
||||
std::unordered_map<std::string, CheckTensor *> calibData;
|
||||
std::unordered_map<std::string, TypeId> dataTypeMap{
|
||||
{"FLOAT", TypeId::kNumberTypeFloat}, {"INT8", TypeId::kNumberTypeInt8}, {"INT32", TypeId::kNumberTypeInt32}};
|
||||
// TypeId msInputBinDataType = TypeId::kNumberTypeFloat;
|
||||
TypeId msCalibDataType = TypeId::kNumberTypeFloat;
|
||||
};
|
||||
|
||||
int MS_API RunBenchmark(int argc, const char **argv);
|
||||
|
|
|
@ -104,7 +104,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
|
|||
if (outputDataDType == TypeId::kNumberTypeInt8) {
|
||||
return RET_OK;
|
||||
}
|
||||
MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat);
|
||||
MS_ASSERT(outputDataDType == TypeId::kNumberTypeFloat);
|
||||
auto &graphOutIdxes = graph->outputIndex;
|
||||
for (auto graphOutIdx : graphOutIdxes) {
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
|
@ -115,11 +115,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
|
|||
if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) {
|
||||
// insert transNode
|
||||
STATUS status = RET_OK;
|
||||
if (inputDataDType == TypeId::kNumberTypeFloat) {
|
||||
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
|
||||
} else {
|
||||
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
|
||||
}
|
||||
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
|
||||
return status;
|
||||
|
|
Loading…
Reference in New Issue