!3815 compute threshold only once in post training quantization
Merge pull request !3815 from xutianchun/quant_0731
Commit: 6e759cd487
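In short: the exporter no longer clears the PrimitiveT pointer after moving it into the exported node, because PrimitiveTValue's destructor no longer frees it; nodes with empty input or output quant params now get a warning without aborting the rest of their export; the weight-format passes treat QuantType_PostTraining like QuantType_AwareTrainning; and Calibrator::ComputeThreshold reuses an already-computed output threshold for the input it feeds instead of recomputing it.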
@@ -98,7 +98,6 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
     }

     node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
-    primitiveT_value->SetPrimitiveT(nullptr);
     std::vector<schema::TensorT *> outputs;
     SetOpInputNode(cnode, metaGraphT.get(), node.get());
     SetOpOutputNode(outputs, metaGraphT.get(), node.get());
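The deleted SetPrimitiveT(nullptr) only made sense while ~PrimitiveTValue deleted the pointer (see the PrimitiveTValue hunk further down): clearing the wrapper prevented a double free once node->primitive took ownership. A minimal sketch of that pattern, using invented stand-in types rather than the real MindSpore classes:

#include <memory>

struct PrimitiveT { int op = 0; };   // hypothetical stand-in for schema::PrimitiveT

struct OwningValue {                 // models the OLD PrimitiveTValue behavior
  PrimitiveT *prim;
  explicit OwningValue(PrimitiveT *p) : prim(p) {}
  ~OwningValue() { delete prim; }    // old destructor: frees the primitive
  PrimitiveT *Get() const { return prim; }
  void Set(PrimitiveT *p) { prim = p; }
};

int main() {
  auto *value = new OwningValue(new PrimitiveT{});
  std::unique_ptr<PrimitiveT> node_prim(value->Get());  // exporter takes ownership
  value->Set(nullptr);  // needed with the old destructor, else a double free
  delete value;         // safe only because the pointer was cleared above
  // With the defaulted destructor introduced in this PR, the Set(nullptr)
  // step is dead code: the wrapper never frees, only node_prim deletes.
  return 0;
}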
@@ -113,24 +112,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
     auto input_quant_params = primitiveT_value->GetInputQuantParams();
     if (input_quant_params.empty()) {
       MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty";
-      continue;
+    } else {
+      std::unique_ptr<schema::QuantParamT> input_quant_param =
+        std::make_unique<schema::QuantParamT>(input_quant_params[0]);
+      tensor_input->quantParams.emplace_back(std::move(input_quant_param));
     }
-    std::unique_ptr<schema::QuantParamT> input_quant_param =
-      std::make_unique<schema::QuantParamT>(input_quant_params[0]);
-    tensor_input->quantParams.emplace_back(std::move(input_quant_param));
-
     // output
     auto output_index = node->outputIndex[0];
     auto tensor_output = metaGraphT->allTensors[output_index].get();
     auto output_quant_params = primitiveT_value->GetOutputQuantParams();
     if (output_quant_params.empty()) {
       MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
-      continue;
+    } else {
+      std::unique_ptr<schema::QuantParamT> output_quant_param =
+        std::make_unique<schema::QuantParamT>(output_quant_params[0]);
+      tensor_output->quantParams.emplace_back(std::move(output_quant_param));
     }
-    std::unique_ptr<schema::QuantParamT> output_quant_param =
-      std::make_unique<schema::QuantParamT>(output_quant_params[0]);
-    tensor_output->quantParams.emplace_back(std::move(output_quant_param));
-
     // // TensorType
     // valuePtr = primitive->GetAttr(kInputTensorDataType);
     // if (valuePtr != nullptr) {
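Swapping continue for an else branch is a behavioral fix, not a style one: this code runs inside the exporter's per-node loop, so the old continue abandoned the whole node whenever its input quant params were empty, and the output quant params below were never attached. With if/else, the input and output sides are handled independently. A reduced sketch of the two control flows (toy types, not the exporter's):

#include <iostream>
#include <vector>

struct Node { std::vector<int> in_params, out_params; };  // invented toy type

int main() {
  std::vector<Node> nodes = {{{}, {42}}};  // empty input params, valid output params
  for (const auto &n : nodes) {
    if (n.in_params.empty()) {
      std::cout << "warn: input quant params empty\n";
      // old code: continue;  (would also skip the output handling below)
    } else {
      std::cout << "attach input quant param " << n.in_params[0] << "\n";
    }
    if (n.out_params.empty()) {
      std::cout << "warn: output quant params empty\n";
    } else {
      std::cout << "attach output quant param " << n.out_params[0] << "\n";
    }
  }
  return 0;
}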
@@ -26,8 +26,8 @@ namespace mindspore::lite {
 class PrimitiveTValue : public Value {
  public:
   explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {}
-
-  ~PrimitiveTValue() override { delete this->primitive; }
+  // not responsible to free primitive, the one created the dynamic memory is responsible to free it.
+  ~PrimitiveTValue() override = default;

   MS_DECLARE_PARENT(PrimitiveTValue, Value)

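This is the other half of the first hunk: ownership of the schema::PrimitiveT moves out of PrimitiveTValue entirely. The new comment states the contract: whoever allocated the primitive frees it, and the wrapper's defaulted destructor never touches it, which is exactly what made the exporter's SetPrimitiveT(nullptr) call removable.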
@@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
     MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status;
     return status;
   }
-  if (this->quantType == QuantType_AwareTrainning) {
+  if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) {
     status = QuantDataFormatTrans(graphNode);
     if (status != 0) {
       MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
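Post-training quantization produces quantized graphs just as aware training does, so QuantDataFormatTrans now runs for QuantType_PostTraining as well; the gate widens from one quant type to either.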
@@ -147,7 +147,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
   } else if (fmkType == converter::FmkType_TFLITE) {
     switch (node->quantType) {
       case QuantType_QUANT_NONE:
-      case QuantType_AwareTrainning: {
+      case QuantType_AwareTrainning:
+      case QuantType_PostTraining: {
         if (opType == schema::PrimitiveType_Conv2D) {
           weightTensor->format = schema::Format_KHWC;
         } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
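The same widening in the TFLite branch of ShapeFormatTrans, done with stacked case labels so QuantType_PostTraining shares the aware-training block. Stacked labels are the idiomatic way to give several enum values one handler without duplicating the body; a toy illustration with invented names:

#include <cstdio>

enum QuantType { kQuantNone, kAwareTraining, kPostTraining };  // invented names

const char *WeightFormatFor(QuantType q) {
  switch (q) {
    case kQuantNone:
    case kAwareTraining:
    case kPostTraining:  // stacked labels: all three share the same handler
      return "KHWC";
  }
  return "unchanged";
}

int main() {
  std::printf("%s\n", WeightFormatFor(kPostTraining));  // prints KHWC
  return 0;
}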
@@ -292,14 +292,33 @@ STATUS Calibrator::RecordMaxValue(std::string opName, vector<float> data,
 }

 STATUS Calibrator::ComputeThreshold() {
-  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
-    DivergInfo *info = iter->second.get();
-    info->ComputeThreshold();
-  }
-
   for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) {
     DivergInfo *info = iter->second.get();
     info->ComputeThreshold();
   }
+  // node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as
+  // the node B's output quant param
+  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
+    DivergInfo *info = iter->second.get();
+    auto cnode = info->cnode;
+
+    bool already_computed = false;
+    auto input = cnode->input(1);
+    if (input->isa<mindspore::CNode>()) {
+      auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input);
+      for (const auto &output_diverg_info : output_diverg_info_) {
+        auto output_diverg_cnode = output_diverg_info.second->cnode;
+        if (output_diverg_cnode == input_cnode) {
+          *info = *(output_diverg_info.second);
+          info->cnode = cnode;
+          already_computed = true;
+          break;
+        }
+      }
+    }
+    if (!already_computed) {
+      info->ComputeThreshold();
+    }
+  }
   return RET_OK;
 }
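This is the change the commit title describes. Output thresholds are computed first; then, for each node input, the calibrator checks whether that input is produced by a node whose output DivergInfo was already calibrated, and if so copies that result (rebinding cnode) instead of calling DivergInfo::ComputeThreshold() again. This is sound because a node's input distribution is, by construction, its producer's output distribution. A stripped-down sketch of the same reuse logic over plain maps (invented names, not the MindSpore API):

#include <iostream>
#include <map>
#include <string>

// Toy stand-in for DivergInfo: which node produces the tensor, plus its threshold.
struct Info {
  std::string producer;
  double threshold = 0.0;
  void Compute() { threshold = 6.0; }  // placeholder for the real histogram work
};

int main() {
  std::map<std::string, Info> output_infos = {{"conv1", {"conv1"}}};
  std::map<std::string, Info> input_infos  = {{"relu1", {"conv1"}}};  // relu1 consumes conv1's output

  // Pass 1: compute thresholds for every node output.
  for (auto &kv : output_infos) kv.second.Compute();

  // Pass 2: an input fed by an already-calibrated output reuses that threshold.
  for (auto &kv : input_infos) {
    auto it = output_infos.find(kv.second.producer);
    if (it != output_infos.end()) {
      kv.second.threshold = it->second.threshold;  // copy instead of recomputing
    } else {
      kv.second.Compute();
    }
  }
  std::cout << "relu1 input threshold: " << input_infos["relu1"].threshold << "\n";
  return 0;
}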