add ci test for onnx and mindir models.

This commit is contained in:
wang_shaocong 2020-11-26 11:06:19 +08:00
parent eb696440d0
commit 33b1bae821
5 changed files with 14 additions and 7 deletions

View File

@ -111,8 +111,10 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found."; MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
return RET_ERROR; return RET_ERROR;
} }
auto quant_arg = out_tensors_.front()->quant_params().front().inited ? out_tensors_.front()->quant_params().front() auto quant_arg =
: in_tensors_.front()->quant_params().front(); (!out_tensors_.front()->quant_params().empty() && out_tensors_.front()->quant_params().front().inited)
? out_tensors_.front()->quant_params().front()
: in_tensors_.front()->quant_params().front();
int ret = RET_OK; int ret = RET_OK;
if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,

View File

@ -12,3 +12,6 @@ ocr_mobilenetV2.mindir 1.5
mobilenet_quant.mindir 5 mobilenet_quant.mindir 5
mindspore_ghostnet_ssd_13x.mindir 1.5 mindspore_ghostnet_ssd_13x.mindir 1.5
mindspore_ghost-nose-pets-811.mindir 0.5 mindspore_ghost-nose-pets-811.mindir 0.5
mindspore_ghost-pets-8244.mindir 1.5
mindspore_ghostnet600M-pets.mindir 1.5
mindspore_ghostnet_1x_pets_int8.mindir 12

View File

@ -27,3 +27,6 @@ mtk_transformer_encoder.tflite
mtk_transformer_decoder_joint.tflite mtk_transformer_decoder_joint.tflite
ml_ei_facedetection.onnx ml_ei_facedetection.onnx
mobilebert_1_default_1.tflite mobilebert_1_default_1.tflite
quant_aware_bank_card_detection_inception.onnx
quant_aware_bank_card_recognition_fcny.onnx
quant_aware_identify_card_detect.onnx

View File

@ -158,7 +158,7 @@ lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_no
MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope(); MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
return lite::RET_ERROR; return lite::RET_ERROR;
} }
new_parameter->set_name(input_node->fullname_with_scope()); new_parameter->set_name("constfold_" + input_node->fullname_with_scope());
manager->Replace(input_node, new_parameter); manager->Replace(input_node, new_parameter);
} }
return lite::RET_OK; return lite::RET_OK;

View File

@ -103,6 +103,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
const ParamValueLitePtr &param_value) const { const ParamValueLitePtr &param_value) const {
MS_ASSERT(conv_cnode != nullptr); MS_ASSERT(conv_cnode != nullptr);
MS_ASSERT(param_value != nullptr); MS_ASSERT(param_value != nullptr);
auto weight_node = conv_node->cast<CNodePtr>()->input(kConvWeightIndex);
auto op_type = GetCNodeType(conv_node); auto op_type = GetCNodeType(conv_node);
switch (this->quant_type) { switch (this->quant_type) {
case QuantType_AwareTraining: { case QuantType_AwareTraining: {
@ -121,10 +122,8 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
if (op_type == schema::PrimitiveType_Conv2D) { if (op_type == schema::PrimitiveType_Conv2D) {
param_value->set_format(schema::Format::Format_KCHW); param_value->set_format(schema::Format::Format_KCHW);
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
// the format is initialized to NUM_OF_FORMAT, and set to NHWC in const folding. // the format should be set to KCHW while the weight is output of constfolding .
if (param_value->format() == schema::Format::Format_NHWC) { if (weight_node->fullname_with_scope().find("constfold") == weight_node->fullname_with_scope().npos) {
param_value->set_format(schema::Format::Format_KCHW);
} else {
param_value->set_format(schema::Format::Format_CKHW); param_value->set_format(schema::Format::Format_CKHW);
} }
} else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) { } else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) {