From fb129986e512dab6bb63942d7592b041bfb6c5fd Mon Sep 17 00:00:00 2001 From: kai00 Date: Mon, 14 Sep 2020 20:24:28 +0800 Subject: [PATCH] fix weight quant --- .../lite/tools/converter/quantizer/weight_quantizer.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index e6c1110d7e4..de07c4ef2b0 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -58,13 +58,21 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { return RET_ERROR; } + ParamValueLitePtr param_value = std::static_pointer_cast(param_node->default_param()); + if (param_value == nullptr) { + return RET_ERROR; + } + if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { + MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); + return RET_ERROR; + } + std::vector quant_params; primitive_c->AddInputQuantParam(quant_params); auto op_type = (schema::PrimitiveType)primitive_c->Type(); bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false; - ParamValueLitePtr param_value = std::static_pointer_cast(param_node->default_param()); auto status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true, depthwise);