forked from mindspore-Ecosystem/mindspore
fix DeConv DeDepthwiseConve for post training quantization
This commit is contained in:
parent
c962ccbe07
commit
b68835f4a5
|
@ -92,6 +92,9 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
|
|||
return has_trans_count >= half_count;
|
||||
}
|
||||
}
|
||||
if (GetCNodeTType(*node) == schema::PrimitiveType_Split) {
|
||||
return has_trans_count >= half_count;
|
||||
}
|
||||
can_fusion = has_trans_count > half_count;
|
||||
return can_fusion;
|
||||
}
|
||||
|
|
|
@ -798,6 +798,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|||
primitive_c->SetQuantType(schema::QuantType_PostTraining);
|
||||
continue;
|
||||
} else if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
|
||||
op_type != PrimitiveType_DeConv2D && op_type != PrimitiveType_DeDepthwiseConv2D &&
|
||||
op_type != PrimitiveType_FullConnection) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
auto input_node = cnode->input(i);
|
||||
|
@ -847,7 +848,8 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|||
// do weight quant
|
||||
auto weight = cnode->input(2);
|
||||
bool perchannel = per_channel_;
|
||||
if (op_type == PrimitiveType_FullConnection) {
|
||||
if (op_type == PrimitiveType_FullConnection || op_type == PrimitiveType_DeConv2D ||
|
||||
op_type == PrimitiveType_DeDepthwiseConv2D) {
|
||||
perchannel = false;
|
||||
}
|
||||
DoWeightQuant(weight, primitive_c, perchannel);
|
||||
|
|
|
@ -65,7 +65,7 @@ class PostTrainingQuantizer : public Quantizer {
|
|||
int quant_min{INT8_MIN};
|
||||
|
||||
private:
|
||||
bool per_channel_;
|
||||
bool per_channel_{true};
|
||||
|
||||
TypeId target_type_{kNumberTypeInt8};
|
||||
|
||||
|
|
|
@ -108,7 +108,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
|
|||
schema::PrimitiveType_DeDepthwiseConv2D,
|
||||
schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_Activation,
|
||||
schema::PrimitiveType_TupleGetItem,
|
||||
schema::PrimitiveType_Eltwise,
|
||||
};
|
||||
bool contain = IsContain(int8OpList, type);
|
||||
if (!contain) {
|
||||
|
|
Loading…
Reference in New Issue