forked from mindspore-Ecosystem/mindspore

int32 tensors don't insert dtype_cast op

parent a161a3375a
commit c5db8e0a32
@@ -63,7 +63,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
   auto *weight_tensor = inputs.at(kWeightIndex);
   // data of second tensor of fc may be nullptr
   auto *restore_data = weight_tensor->data_c();
-  if (!weight_tensor->GetQuantParams().empty()) {
+  if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
     auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
     if (dequant_weight == nullptr) {
       MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -91,7 +91,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
     }
     return nullptr;
   }
-  if (!weight_tensor->GetQuantParams().empty()) {
+  if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
     weight_tensor->FreeData();
     weight_tensor->SetData(restore_data);
   }
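Note: the two hunks above harden the same dequantize-then-restore pattern. The second tensor of a FullConnection kernel may legitimately carry no data, so dequantization is now attempted only when quant params exist and the buffer is non-null, and the original buffer is put back once the kernel has been created. A minimal self-contained sketch of that pattern follows; the Tensor struct and the Dequantize/BuildKernel helpers are illustrative stand-ins, not the MindSpore Lite API.

#include <cstdlib>
#include <iostream>
#include <vector>

// Simplified stand-ins for the real tensor / kernel types.
struct QuantParam { float scale; int zero_point; };
struct Tensor {
  std::vector<QuantParam> quant_params;
  void *data = nullptr;  // may be nullptr, e.g. a weight with no constant data
};

// Hypothetical helpers: dequantize into a freshly allocated fp32 buffer.
float *Dequantize(const Tensor &) { return static_cast<float *>(std::malloc(16)); }
bool BuildKernel(const Tensor &) { return true; }

bool CreateFcKernel(Tensor *weight) {
  void *restore_data = weight->data;
  float *dequant_buf = nullptr;
  // Guard added by the commit: only dequantize when the weight actually has data.
  if (!weight->quant_params.empty() && restore_data != nullptr) {
    dequant_buf = Dequantize(*weight);
    if (dequant_buf == nullptr) {
      std::cerr << "dequant data is nullptr.\n";
      return false;
    }
    weight->data = dequant_buf;  // build the kernel against the fp32 copy
  }
  bool ok = BuildKernel(*weight);
  if (!weight->quant_params.empty() && restore_data != nullptr) {
    std::free(dequant_buf);      // drop the temporary fp32 copy ...
    weight->data = restore_data; // ... and restore the original buffer
  }
  return ok;
}

The fp32 copy only lives for the duration of kernel creation; the tensor keeps ownership of its original (possibly quantized) buffer, which is why the restore step mirrors the guard exactly.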
@@ -93,7 +93,7 @@ std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInp
 
 std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
 
-std::vector<schema::PrimitiveType> GetUint8OpList() { return int8OpList; }
+std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }
 
 STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
                               mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
@@ -42,7 +42,7 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList();
 
 std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
 
-std::vector<schema::PrimitiveType> GetUint8OpList();
+std::vector<schema::PrimitiveType> GetInt8OpList();
 
 class NodeUtils {
  public:
@@ -51,13 +51,7 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
 
 STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  // modify inputTensor first
   auto &graphInIdxes = graph->inputIndex;
-  for (auto graphInIdx : graphInIdxes) {
-    MS_ASSERT(graph->allTensors.size() > graphInIdx);
-    auto &graphInTensor = graph->allTensors.at(graphInIdx);
-    graphInTensor->dataType = TypeId::kNumberTypeInt8;
-  }
 
   if (this->inputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
@@ -70,7 +64,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
   for (auto graphInIdx : graphInIdxes) {
     MS_ASSERT(graphInIdx < graph->allTensors.size());
     auto &tensor = graph->allTensors.at(graphInIdx);
-    if (tensor->dims.size() != kNHWCDimNumber) {
+    if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) {
       continue;
     }
 
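Note: with the unconditional int8 retagging removed (hunk at -51,13) and the extra dtype check added above, a graph input now keeps its declared type, and only 4-dimensional (NHWC) int8 inputs remain candidates for an inserted Int8ToFP32 cast. A rough sketch of the resulting filter, using placeholder types rather than the real schema::MetaGraphT:

#include <cstddef>
#include <vector>

enum TypeId { kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat32 };

struct InputTensor {
  TypeId dataType;
  std::vector<int> dims;
};

constexpr std::size_t kNHWCDimNumber = 4;

// Returns true when an Int8ToFP32 cast should be inserted for a graph input.
bool InputNeedsCast(const InputTensor &t, TypeId wanted_input_dtype) {
  if (wanted_input_dtype == kNumberTypeInt8) {
    return false;  // the caller asked to feed int8 data directly
  }
  if (t.dims.size() != kNHWCDimNumber) {
    return false;  // only 4-D NHWC inputs are considered
  }
  if (t.dataType != kNumberTypeInt8) {
    return false;  // e.g. an int32 shape or index input keeps its dtype
  }
  return true;
}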
@@ -137,7 +131,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
   // insert transNode before and after existNode
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-    if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
+    if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
       continue;
     }
     if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
@@ -157,10 +151,16 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
     for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
       MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
       auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
+      if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) {
+        continue;
+      }
       auto &graphInIdxes = graph->inputIndex;
       if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
         continue;
       }
+      if (IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
+        continue;
+      }
       iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
@@ -170,6 +170,10 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
 
     if (needInsertPost) {
       for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
+        auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
+        if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) {
+          continue;
+        }
         iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
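Note: the two hunks above are the core of the commit. Before an Int8/FP32 QuantDTypeCast is inserted around a quantized node, every input and output tensor is inspected first, and tensors typed kNumberTypeInt/kNumberTypeInt32 (shape, index and similar integer tensors) are skipped, as are graph inputs and non-input tensors that already carry constant data. A compilable sketch of that early-continue rule, with illustrative types rather than the converter's schema::TensorT:

#include <cstdint>
#include <iostream>
#include <vector>

enum TypeId { kNumberTypeInt, kNumberTypeInt32, kNumberTypeInt8, kNumberTypeFloat32 };

struct TensorT {
  TypeId dataType;
  std::vector<uint8_t> data;  // non-empty only for constant tensors
};

// Mirrors the early-continue logic on a node's input side: int/int32 tensors,
// data-carrying constants, and graph inputs never get a dtype_cast node.
bool SkipDTypeCast(const TensorT &t, bool is_graph_input) {
  if (t.dataType == kNumberTypeInt || t.dataType == kNumberTypeInt32) return true;
  if (!t.data.empty() && !is_graph_input) return true;
  if (is_graph_input) return true;
  return false;
}

int main() {
  TensorT shape_tensor{kNumberTypeInt32, {1, 0, 0, 0}};  // e.g. a Reshape shape input
  TensorT activation{kNumberTypeInt8, {}};
  std::cout << SkipDTypeCast(shape_tensor, false) << '\n';  // 1: no cast inserted
  std::cout << SkipDTypeCast(activation, false) << '\n';    // 0: cast still inserted
  return 0;
}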
@@ -79,6 +79,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
   // change quant param min to 0 to fit ms-lite ops
   if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
     quant_param->zeroPoint = quant_param->zeroPoint - 128;
+    tensor->dataType = TypeId::kNumberTypeInt8;
   }
 
   if (!tflite_tensor->quantization->min.empty()) {
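Note: the added retagging is sound because of how affine quantization represents values. Assuming real ≈ scale * (q - zeroPoint), shifting both the stored value and the zero point down by 128 re-expresses a uint8 tensor as int8 without changing the real numbers it encodes. In the hunk above only the zero point needs adjusting, since the tensor->data.empty() condition restricts the change to non-constant tensors, which have no stored values to rewrite. A quick numeric check of the identity:

#include <cstdint>
#include <iostream>

int main() {
  // Assumed affine quantization: real = scale * (q - zero_point).
  const float scale = 0.05f;
  const int32_t zp_u8 = 140;  // uint8 zero point read from the TFLite model
  const uint8_t q_u8 = 200;   // one stored uint8 value

  const int32_t zp_i8 = zp_u8 - 128;                 // what the parser now stores
  const int8_t q_i8 = static_cast<int8_t>(q_u8 - 128);

  const float real_u8 = scale * (static_cast<int32_t>(q_u8) - zp_u8);
  const float real_i8 = scale * (static_cast<int32_t>(q_i8) - zp_i8);

  std::cout << real_u8 << " == " << real_i8 << '\n';  // both print 3
  return 0;
}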
@@ -164,11 +165,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
       MS_LOG(ERROR) << "obtain const tensor failed";
       return status;
     }
-  } else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) {
-    // set in/out tensor to int8 to fit ms-lite op
-    tensor->dataType = TypeId::kNumberTypeInt8;
   }
-
   // set tensor attr
   if (isInput || isConst) {
     tensor->nodeType = schema::NodeType::NodeType_ValueNode;
@@ -145,7 +145,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
 STATUS AwareQuantizer::DoQuantize() {
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
     auto &node = *iter;
-    if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
+    if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
      continue;
     }
     if (node->quantType != schema::QuantType_AwareTraining) {
@@ -388,7 +388,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
     }
   }
 
-  if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
+  if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
     node->quantType = schema::QuantType_AwareTraining;
   } else {
     node->quantType = schema::QuantType_QUANT_NONE;
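Note: GetUint8OpList() is renamed to GetInt8OpList() to match the int8OpList it already returned; the call sites in DoQuantize(), DetermineNodeQuantType() and the dtype-trans pass change in lockstep. The list acts as a plain whitelist lookup, roughly like the sketch below (IsContain here is a generic stand-in, not necessarily MindSpore's exact helper):

#include <algorithm>
#include <iostream>
#include <vector>

// Generic membership test in the spirit of the converter's IsContain helper.
template <typename T>
bool IsContain(const std::vector<T> &list, const T &value) {
  return std::find(list.begin(), list.end(), value) != list.end();
}

int main() {
  // Hypothetical op-type ids standing in for schema::PrimitiveType values.
  const std::vector<int> int8_op_list = {1 /*Conv2D*/, 2 /*FullConnection*/, 3 /*Add*/};
  const int node_type = 2;
  if (!IsContain(int8_op_list, node_type)) {
    std::cout << "skip node\n";       // DoQuantize() continues past such nodes
  } else {
    std::cout << "quantize node\n";
  }
  return 0;
}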