add CI tests for ONNX and MindIR models.

wang_shaocong 2020-11-26 11:06:19 +08:00
parent eb696440d0
commit 33b1bae821
5 changed files with 14 additions and 7 deletions

@@ -111,8 +111,10 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
     return RET_ERROR;
   }
-  auto quant_arg = out_tensors_.front()->quant_params().front().inited ? out_tensors_.front()->quant_params().front()
-                                                                       : in_tensors_.front()->quant_params().front();
+  auto quant_arg =
+    (!out_tensors_.front()->quant_params().empty() && out_tensors_.front()->quant_params().front().inited)
+      ? out_tensors_.front()->quant_params().front()
+      : in_tensors_.front()->quant_params().front();
   int ret = RET_OK;
   if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
     ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
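
The new guard matters because calling front() on an empty quant_params vector is undefined behavior: an output tensor that carries no quantization parameters would have made the old ternary read a non-existent element. Below is a minimal standalone sketch of the guarded fallback, using simplified illustrative types (QuantArg, Tensor, SelectQuantArg are stand-ins here, not the MindSpore Lite runtime classes):

// Minimal sketch of the guarded selection with simplified stand-in types.
#include <iostream>
#include <vector>

struct QuantArg {
  double scale = 1.0;
  int zero_point = 0;
  bool inited = false;
};

struct Tensor {
  std::vector<QuantArg> quant_params;
};

// Use the output tensor's quant arg only when it exists and is initialized;
// otherwise fall back to the input tensor's quant arg, mirroring the fix above.
QuantArg SelectQuantArg(const Tensor &in, const Tensor &out) {
  return (!out.quant_params.empty() && out.quant_params.front().inited)
           ? out.quant_params.front()
           : in.quant_params.front();
}

int main() {
  Tensor in{{{0.5, 3, true}}};
  Tensor out{};  // output tensor without quant params, as some imported models produce
  std::cout << SelectQuantArg(in, out).scale << std::endl;  // prints 0.5
  return 0;
}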

@@ -12,3 +12,6 @@ ocr_mobilenetV2.mindir 1.5
 mobilenet_quant.mindir 5
 mindspore_ghostnet_ssd_13x.mindir 1.5
 mindspore_ghost-nose-pets-811.mindir 0.5
+mindspore_ghost-pets-8244.mindir 1.5
+mindspore_ghostnet600M-pets.mindir 1.5
+mindspore_ghostnet_1x_pets_int8.mindir 12

@@ -27,3 +27,6 @@ mtk_transformer_encoder.tflite
 mtk_transformer_decoder_joint.tflite
 ml_ei_facedetection.onnx
 mobilebert_1_default_1.tflite
+quant_aware_bank_card_detection_inception.onnx
+quant_aware_bank_card_recognition_fcny.onnx
+quant_aware_identify_card_detect.onnx

@@ -158,7 +158,7 @@ lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_no
       MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
       return lite::RET_ERROR;
     }
-    new_parameter->set_name(input_node->fullname_with_scope());
+    new_parameter->set_name("constfold_" + input_node->fullname_with_scope());
     manager->Replace(input_node, new_parameter);
   }
   return lite::RET_OK;
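
The "constfold_" prefix does double duty: it keeps the original scope name readable for debugging and acts as a marker that later passes can use to recognize weights produced by constant folding. A hedged sketch of the idea, with MakeConstFoldName as an illustrative helper rather than part of the converter code base:

// Sketch of tagging a folded parameter by name.
#include <iostream>
#include <string>

std::string MakeConstFoldName(const std::string &original_scope_name) {
  // The prefix marks the parameter as a constant-folding output while keeping
  // the original name visible.
  return "constfold_" + original_scope_name;
}

int main() {
  std::cout << MakeConstFoldName("conv1/weight") << std::endl;  // constfold_conv1/weight
  return 0;
}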

@@ -103,6 +103,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
                                                   const ParamValueLitePtr &param_value) const {
   MS_ASSERT(conv_cnode != nullptr);
   MS_ASSERT(param_value != nullptr);
+  auto weight_node = conv_node->cast<CNodePtr>()->input(kConvWeightIndex);
   auto op_type = GetCNodeType(conv_node);
   switch (this->quant_type) {
     case QuantType_AwareTraining: {
@@ -121,10 +122,8 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
       if (op_type == schema::PrimitiveType_Conv2D) {
         param_value->set_format(schema::Format::Format_KCHW);
       } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
-        // the format is initialized to NUM_OF_FORMAT, and set to NHWC in const folding.
-        if (param_value->format() == schema::Format::Format_NHWC) {
-          param_value->set_format(schema::Format::Format_KCHW);
-        } else {
+        // the format should be set to KCHW when the weight is the output of const folding.
+        if (weight_node->fullname_with_scope().find("constfold") == weight_node->fullname_with_scope().npos) {
           param_value->set_format(schema::Format::Format_CKHW);
         }
       } else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
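
Downstream, the weight-format pass keys off that prefix: only depthwise weights whose names lack the "constfold" tag get the hard-coded CKHW layout, while tagged weights are left untouched in this branch. A standalone sketch of the name check, where DecideDepthwiseWeightFormat and the Format enum are illustrative stand-ins for the converter's actual types:

// Sketch of the name-based format decision.
#include <iostream>
#include <string>

enum class Format { KCHW, CKHW, NHWC, NUM_OF_FORMAT };

Format DecideDepthwiseWeightFormat(const std::string &weight_name, Format current) {
  // A weight whose name lacks the "constfold" tag did not come from the
  // const-folding pass, so hard-code it to CKHW; otherwise keep the current format.
  if (weight_name.find("constfold") == std::string::npos) {
    return Format::CKHW;
  }
  return current;
}

int main() {
  bool plain = DecideDepthwiseWeightFormat("fc/weight", Format::NUM_OF_FORMAT) == Format::CKHW;
  bool folded = DecideDepthwiseWeightFormat("constfold_conv/weight", Format::NHWC) == Format::NHWC;
  std::cout << plain << " " << folded << std::endl;  // prints "1 1"
  return 0;
}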