fix DeConv DeDepthwiseConve for post training quantization

This commit is contained in:
xutianchun 2020-10-24 12:01:01 +08:00
parent c962ccbe07
commit b68835f4a5
4 changed files with 8 additions and 3 deletions

View File

@ -92,6 +92,9 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
return has_trans_count >= half_count;
}
}
if (GetCNodeTType(*node) == schema::PrimitiveType_Split) {
return has_trans_count >= half_count;
}
can_fusion = has_trans_count > half_count;
return can_fusion;
}

View File

@ -798,6 +798,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
primitive_c->SetQuantType(schema::QuantType_PostTraining);
continue;
} else if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
op_type != PrimitiveType_DeConv2D && op_type != PrimitiveType_DeDepthwiseConv2D &&
op_type != PrimitiveType_FullConnection) {
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
@ -847,7 +848,8 @@ STATUS PostTrainingQuantizer::QuantNode() {
// do weight quant
auto weight = cnode->input(2);
bool perchannel = per_channel_;
if (op_type == PrimitiveType_FullConnection) {
if (op_type == PrimitiveType_FullConnection || op_type == PrimitiveType_DeConv2D ||
op_type == PrimitiveType_DeDepthwiseConv2D) {
perchannel = false;
}
DoWeightQuant(weight, primitive_c, perchannel);

View File

@ -65,7 +65,7 @@ class PostTrainingQuantizer : public Quantizer {
int quant_min{INT8_MIN};
private:
bool per_channel_;
bool per_channel_{true};
TypeId target_type_{kNumberTypeInt8};

View File

@ -108,7 +108,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_Activation,
schema::PrimitiveType_TupleGetItem,
schema::PrimitiveType_Eltwise,
};
bool contain = IsContain(int8OpList, type);
if (!contain) {