forked from mindspore-Ecosystem/mindspore
!6207 [MSLITE] weight quant fix
Merge pull request !6207 from wangchangkai/master
This commit is contained in:
commit
0a9938ddc8
|
@ -58,13 +58,21 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
||||||
|
if (param_value == nullptr) {
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
|
||||||
|
MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<schema::QuantParamT> quant_params;
|
std::vector<schema::QuantParamT> quant_params;
|
||||||
primitive_c->AddInputQuantParam(quant_params);
|
primitive_c->AddInputQuantParam(quant_params);
|
||||||
|
|
||||||
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
||||||
bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false;
|
bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false;
|
||||||
|
|
||||||
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
|
||||||
auto status =
|
auto status =
|
||||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant,
|
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant,
|
||||||
quant_max, quant_min, bitNum, true, depthwise);
|
quant_max, quant_min, bitNum, true, depthwise);
|
||||||
|
|
Loading…
Reference in New Issue