!3723 Insert quant cast node after post training quantization
Merge pull request !3723 from xutianchun/quant_0730
This commit is contained in:
commit
6f6c8ccb1a
|
@ -32,6 +32,7 @@
|
||||||
#include "tools/converter/parser/onnx/onnx.pb.h"
|
#include "tools/converter/parser/onnx/onnx.pb.h"
|
||||||
#include "tools/converter/quantizer/weight_quantizer.h"
|
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||||
#include "tools/converter/quantizer/post_training.h"
|
#include "tools/converter/quantizer/post_training.h"
|
||||||
|
#include "tools/converter/quantizer/quant_cast.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -90,7 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// auto newGraph = anfTransform->Transform(graph);
|
// auto newGraph = anfTransform->Transform(graph);
|
||||||
/*
|
|
||||||
CreateQuantizer(graph, flag);
|
CreateQuantizer(graph, flag);
|
||||||
if (mQuantizer != nullptr) {
|
if (mQuantizer != nullptr) {
|
||||||
auto status = mQuantizer->DoQuantize(graph);
|
auto status = mQuantizer->DoQuantize(graph);
|
||||||
|
@ -98,8 +99,15 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
||||||
MS_LOG(ERROR) << "Quant failed " << status;
|
MS_LOG(ERROR) << "Quant failed " << status;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
quant::QuantCast quant_cast;
|
||||||
|
quant_cast.SetInputDataDType(kNumberTypeFloat32);
|
||||||
|
status = quant_cast.Run(graph);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "add QuantCast error";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
// anf -- fb
|
// anf -- fb
|
||||||
auto meta_graph = Export(graph);
|
auto meta_graph = Export(graph);
|
||||||
if (meta_graph == nullptr) {
|
if (meta_graph == nullptr) {
|
||||||
|
|
|
@ -11,6 +11,7 @@ add_library(quantizer_mid OBJECT
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
|
||||||
#${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc
|
#${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -732,9 +732,10 @@ STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &nodeName,
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
tensor::Tensor *tensor = tensorVec[0];
|
tensor::Tensor *tensor = tensorVec[0];
|
||||||
if (tensor->data_type() != kNumberTypeFloat) {
|
if (tensor->data_type() != kNumberTypeFloat32) {
|
||||||
//&& tensor->RefCount() != MSCONST_WEIGHT_REFCOUNT
|
//&& tensor->RefCount() != MSCONST_WEIGHT_REFCOUNT
|
||||||
MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize";
|
MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize" << " tensor data_type: " << tensor->data_type();
|
||||||
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
#include "mindspore/lite/tools/converter/quantizer/quant_cast.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "mindspore/lite/src/ir/primitive_t_value.h"
|
||||||
|
|
||||||
|
namespace mindspore::lite::quant {
|
||||||
|
|
||||||
|
// Creates a value node that wraps a QuantDTypeCast primitive.
//
// src_type is written to the primitive's srcT field and dst_type to its dstT
// field, unchanged. The callers in this file pass kNumberTypeFloat32 /
// kNumberTypeUInt8 in one order or the other.
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type) {
  // Describe the cast first, then attach it to a fresh primitive.
  schema::QuantDTypeCastT cast_attr;
  cast_attr.srcT = src_type;
  cast_attr.dstT = dst_type;

  auto prim = std::make_unique<schema::PrimitiveT>();
  prim->value.Set(cast_attr);

  // release() gives up unique_ptr ownership; the raw pointer goes to
  // PrimitiveTValue (presumably it takes ownership -- TODO confirm).
  std::shared_ptr<PrimitiveTValue> prim_value = std::make_shared<PrimitiveTValue>(prim.release());
  return NewValueNode(prim_value);
}
|
||||||
|
|
||||||
|
// Inserts QuantDTypeCast nodes into `graph` wherever the quant type of a node
// and the quant type of one of its CNode inputs disagree.
//
// The nodes returned by graph->GetOrderedCnodes() are visited in order:
//   * First node: if it is post-training quantized and inputDataDType is
//     float32, a float32 -> uint8 cast is placed in front of its input 1.
//     No other input of the first node is examined.
//   * Every later node: for each input that is itself a CNode carrying a
//     PrimitiveTValue, a cast is inserted when the two quant types differ:
//       PostTraining node <- QUANT_NONE input   : float32 -> uint8
//       QUANT_NONE node   <- PostTraining input : uint8   -> float32
//     Any other mismatch is logged as a warning and left untouched.
//
// A node whose own PrimitiveTValue is missing is treated as QUANT_NONE.
//
// Returns RET_OK on completion, RET_ERROR if `graph` is null.
//
// NOTE(review): every cast inserted for a given node gets the same name
// (<node name> + "_quant_cast"), so a node with several cast inputs yields
// duplicate names -- confirm that later passes do not rely on unique names.
STATUS QuantCast::Run(FuncGraphPtr graph) {
  MS_ASSERT(graph != nullptr);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr";
    return RET_ERROR;
  }

  // Re-routes input `index` of `node` through a new CNode built from `cast_value_node`.
  auto attach_cast = [&graph](const auto &node, size_t index, const ValueNodePtr &cast_value_node) {
    std::vector<AnfNodePtr> op_inputs = {cast_value_node, node->input(index)};
    auto quant_cast_cnode = graph->NewCNode(op_inputs);
    quant_cast_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_quant_cast");
    node->set_input(index, quant_cast_cnode);
  };

  auto cnodes = graph->GetOrderedCnodes();
  bool first = true;

  for (auto &cnode : cnodes) {
    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
    auto curnode_quant_type = schema::QuantType_QUANT_NONE;
    if (primitiveT_value == nullptr) {
      MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
    } else {
      curnode_quant_type = primitiveT_value->GetQuantType();
    }

    if (first) {
      first = false;
      if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
        if (cnode->inputs().size() < 2) {
          // Only the primitive is present: there is no input 1 to wrap.
          MS_LOG(WARNING) << "first node has no data input, skip quant cast: " << cnode->fullname_with_scope();
        } else {
          attach_cast(cnode, 1, NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8));
          MS_LOG(DEBUG) << "Add quant cast at front. "
                        << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type;
        }
      }
      continue;
    }

    // Input 0 is the primitive itself, so start at input 1.
    for (size_t i = 1; i < cnode->inputs().size(); i++) {
      auto input_node = cnode->input(i);
      if (!input_node->isa<CNode>()) {
        continue;
      }
      auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
      auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(input_cnode->input(0));
      if (input_cnode_primitiveT_value == nullptr) {
        MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
                      << " PrimitiveTValue is null";
        continue;
      }
      auto input_cnode_quant_type = input_cnode_primitiveT_value->GetQuantType();

      if (curnode_quant_type == input_cnode_quant_type) {
        MS_LOG(DEBUG) << "No need to add quant cast. "
                      << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
                      << " input_" << i << ": " << input_cnode->fullname_with_scope()
                      << " quant_type:" << input_cnode_quant_type;
        continue;
      }

      ValueNodePtr value_node = nullptr;
      if (curnode_quant_type == schema::QuantType_PostTraining &&
          input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
        value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8);
      } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
                 input_cnode_quant_type == schema::QuantType_PostTraining) {
        value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32);
      }
      if (value_node == nullptr) {
        // Mismatch that neither rule above covers; report both quant types.
        MS_LOG(WARNING) << "value_node is null! "
                        << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
                        << " input_" << i << ": " << input_cnode->fullname_with_scope()
                        << " quant_type:" << input_cnode_quant_type;
        continue;
      }

      attach_cast(cnode, i, value_node);
      MS_LOG(DEBUG) << "Add quant cast. "
                    << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
                    << " input_" << i << ": " << input_cnode->fullname_with_scope()
                    << " quant_type:" << input_cnode_quant_type;
    }
  }
  return RET_OK;
}
|
||||||
|
|
||||||
|
} // namespace mindspore::lite::quant
|
|
@ -0,0 +1,39 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef LITE_QUANT_CAST_H
|
||||||
|
#define LITE_QUANT_CAST_H
|
||||||
|
|
||||||
|
#include "mindspore/core/ir/anf.h"
|
||||||
|
#include "mindspore/lite/include/errorcode.h"
|
||||||
|
#include "mindspore/core/ir/dtype/type_id.h"
|
||||||
|
#include "mindspore/core/ir/func_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore::lite::quant {
|
||||||
|
|
||||||
|
class QuantCast {
|
||||||
|
public:
|
||||||
|
QuantCast() = default;
|
||||||
|
STATUS Run(FuncGraphPtr graph);
|
||||||
|
void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
TypeId inputDataDType = kNumberTypeFloat32;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mindspore::lite::quant
|
||||||
|
|
||||||
|
#endif // LITE_QUANT_CAST_H
|
|
@ -88,7 +88,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
|
||||||
|
|
||||||
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
|
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
|
||||||
if (primitiveT_value == nullptr) {
|
if (primitiveT_value == nullptr) {
|
||||||
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
|
MS_LOG(ERROR) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue