int32 tensors don't insert dtype_cast op

This commit is contained in:
cjh9368 2020-09-18 14:19:27 +08:00
parent a161a3375a
commit c5db8e0a32
6 changed files with 19 additions and 18 deletions

View File

@ -63,7 +63,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty()) {
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@ -91,7 +91,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty()) {
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}

View File

@ -93,7 +93,7 @@ std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInp
std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
std::vector<schema::PrimitiveType> GetUint8OpList() { return int8OpList; }
std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }
STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {

View File

@ -42,7 +42,7 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList();
std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
std::vector<schema::PrimitiveType> GetUint8OpList();
std::vector<schema::PrimitiveType> GetInt8OpList();
class NodeUtils {
public:

View File

@ -51,13 +51,7 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// modify inputTensor first
auto &graphInIdxes = graph->inputIndex;
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graph->allTensors.size() > graphInIdx);
auto &graphInTensor = graph->allTensors.at(graphInIdx);
graphInTensor->dataType = TypeId::kNumberTypeInt8;
}
if (this->inputDataDType == TypeId::kNumberTypeInt8) {
return RET_OK;
@ -70,7 +64,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graphInIdx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graphInIdx);
if (tensor->dims.size() != kNHWCDimNumber) {
if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) {
continue;
}
@ -137,7 +131,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert transNode before and after existNode
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
continue;
}
if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
@ -157,10 +151,16 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) {
continue;
}
auto &graphInIdxes = graph->inputIndex;
if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
continue;
}
if (IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
@ -170,6 +170,10 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
if (needInsertPost) {
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";

View File

@ -79,6 +79,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
// change quant param min to 0 to fit ms-lite ops
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
tensor->dataType = TypeId::kNumberTypeInt8;
}
if (!tflite_tensor->quantization->min.empty()) {
@ -164,11 +165,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
MS_LOG(ERROR) << "obtain const tensor failed";
return status;
}
} else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) {
// set in/out tensor to int8 to fit ms-lite op
tensor->dataType = TypeId::kNumberTypeInt8;
}
// set tensor attr
if (isInput || isConst) {
tensor->nodeType = schema::NodeType::NodeType_ValueNode;

View File

@ -145,7 +145,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
STATUS AwareQuantizer::DoQuantize() {
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
continue;
}
if (node->quantType != schema::QuantType_AwareTraining) {
@ -388,7 +388,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
}
}
if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
node->quantType = schema::QuantType_AwareTraining;
} else {
node->quantType = schema::QuantType_QUANT_NONE;