forked from mindspore-Ecosystem/mindspore
!6298 [MSLITE]deconv weight quant fix
Merge pull request !6298 from wangchangkai/master
commit 8f0c863efe
@@ -231,9 +231,23 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
                                                 const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->MutableData();
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
+    if (dequant_weight == nullptr) {
+      MS_LOG(ERROR) << "dequant data is nullptr.";
+      return nullptr;
+    }
+    weight_tensor->SetData(dequant_weight);
+  }
   auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
   auto ret = kernel->Init();

@@ -241,8 +255,18 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
+
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
+
   return kernel;
 }
@@ -199,10 +199,24 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
                                                   const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->MutableData();
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
+    if (dequant_weight == nullptr) {
+      MS_LOG(ERROR) << "dequant data is nullptr.";
+      return nullptr;
+    }
+    weight_tensor->SetData(dequant_weight);
+  }
   auto kernel =
     new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
   auto ret = kernel->Init();

@@ -210,8 +224,16 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
   return kernel;
 }
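Both deconvolution kernel creators above apply the same dequantize-then-restore pattern: a temporary float copy of the weights is swapped into the weight tensor before the kernel is constructed and initialized, and the original quantized buffer is swapped back on every exit path. Below is a minimal standalone sketch of that control flow, using hypothetical Tensor/Kernel/DequantWeight stand-ins rather than the actual MindSpore Lite API:

```cpp
#include <new>

// Hypothetical stand-ins for lite::Tensor, LiteKernelUtil::DequantWeight and
// the CPU kernel class; only the control flow mirrors the patch.
struct Tensor {
  void *data = nullptr;
  void *MutableData() { return data; }
  void SetData(void *d) { data = d; }
  void FreeData() { delete[] static_cast<float *>(data); data = nullptr; }
};

inline float *DequantWeight(Tensor *t, int element_num) {
  // The real helper expands quantized weights to float using the tensor's
  // quant params; failure is signalled with nullptr.
  (void)t;
  return new (std::nothrow) float[element_num]();
}

struct Kernel {
  int Init() { return 0; }  // 0 stands in for RET_OK
};

Kernel *CreateKernelWithWeightQuant(Tensor *weight_tensor, int element_num, bool weight_quant) {
  void *restore_data = weight_tensor->MutableData();  // remember the original (quantized) buffer
  if (weight_quant) {
    float *dequant_weight = DequantWeight(weight_tensor, element_num);
    if (dequant_weight == nullptr) {
      return nullptr;  // nothing swapped yet, tensor still owns its original data
    }
    weight_tensor->SetData(dequant_weight);  // kernel construction/Init sees float weights
  }

  auto *kernel = new (std::nothrow) Kernel();
  bool ok = kernel != nullptr && kernel->Init() == 0;

  if (weight_quant) {
    weight_tensor->FreeData();             // drop the temporary float copy
    weight_tensor->SetData(restore_data);  // put the quantized weights back
  }
  if (!ok) {
    delete kernel;
    return nullptr;
  }
  return kernel;
}
```

Restoring restore_data on the failure paths as well keeps ownership unambiguous: the tensor always ends up holding its original quantized data, and the temporary float buffer never outlives kernel creation.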
@@ -53,16 +53,19 @@ STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) {
     MS_ASSERT(node != nullptr);
     MS_ASSERT(node->primitive != nullptr);
     auto opType = node->primitive->value.type;
-    if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D) {
+    if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D &&
+        opType != PrimitiveType_DeConv2D && opType != PrimitiveType_DeDepthwiseConv2D) {
       continue;
     }
     MS_ASSERT(node->inputIndex.size() >= 2);
     auto weightIndex = node->inputIndex.at(1);
     MS_ASSERT(subGraph->allTensors.size() > weightIndex);
     auto &weightTensor = graph->allTensors[weightIndex];
-    MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT);
+    MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT ||
+              weightTensor->dataType == DataType_DT_INT8);
     STATUS status;
-    if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) {  // weight should be HWCK
+    if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D ||
+        opType == PrimitiveType_DeConv2D || opType == PrimitiveType_DeDepthwiseConv2D) {  // weight should be HWCK
       Format curDstFormat;
       if (this->dstFormat == Format_NUM_OF_FORMAT) {
         curDstFormat = Format_KHWC;
@@ -80,7 +80,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
     return nullptr;
   }

-  status = ParseLayer(proto, weight, &tensorCache, metaGraph.get());
+  status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "ParseLayer failed " << status;
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

@@ -177,7 +177,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, T
 }

 STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight,
-                                    TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) {
+                                    TensorCache *tensorCache, schema::MetaGraphT *subGraphDef,
+                                    const QuantType &quantType) {
   for (int i = 0; i < proto.layer_size(); i++) {
     auto layer = proto.layer(i);

@@ -214,7 +215,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff

     std::unique_ptr<schema::CNodeT> op = std::make_unique<schema::CNodeT>();
     op->name = layer.name();
-
+    op->quantType = quantType;
     if (layer.type() == "Split") {
       for (int j = 0; j < layer.top_size(); ++j) {
         splitLayer.emplace(layer.top(j), layer.bottom(0));
@@ -50,7 +50,7 @@ class CaffeModelParser : public ModelParser {
                     schema::MetaGraphT *subGraphDef);

   STATUS ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, TensorCache *tensorCache,
-                    schema::MetaGraphT *subGraphDef);
+                    schema::MetaGraphT *subGraphDef, const QuantType &quantType);

   STATUS GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache);
@@ -247,9 +247,10 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,

 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
-                                             TensorCache *tensor_cache) {
+                                             TensorCache *tensor_cache, const QuantType &quantType) {
   // change op_type() to name(), that is unique
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
+  dst_op->quantType = quantType;
   // dst_op->fmkType = FmkType_ONNX;
   MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
                 << onnx_node.input_size();

@@ -520,7 +521,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con

     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
     std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
-    status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache);
+    status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -61,7 +61,8 @@ class OnnxModelParser : public ModelParser {
                          TensorCache *tensor_cache, int *index);

   STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
+                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
+                              const QuantType &quantType);

   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                          schema::MetaGraphT *graph, TensorCache *tensor_cache);
@@ -32,22 +32,24 @@ using std::vector;
 namespace mindspore {
 namespace lite {
 namespace quant {
-const std::array<std::string, 4> QuantStrategy::mConvTypes = {
-    {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}};
-const std::array<std::string, 4> QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}};
-
+const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
+    schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
+    schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
+const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {
+    schema::PrimitiveType_Mul, schema::PrimitiveType_MatMul, schema::PrimitiveType_FullConnection};
 QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
     : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}

 bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
-  size_t i = 0;
-  for (i = 0; i < mConvTypes.size(); i++) {
-    if (node->fullname_with_scope().find(mConvTypes[i]) == 0) {
-      break;
-    }
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr";
+    return false;
   }
-
-  if ((i == mConvTypes.size()) || (node->size() < 3)) {
+  if (!IsContain(conv_types, (schema::PrimitiveType)primitive_c->Type())) {
     return false;
   }
+  if (node->size() < 3) {
+    return false;
+  }

@@ -107,13 +109,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
 }

 bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
-  size_t i = 0;
-  for (i = 0; i < mMulTypes.size(); i++) {
-    if (node->fullname_with_scope().find(mMulTypes[i]) == 0) {
-      break;
-    }
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr";
+    return false;
   }
-  if (i == mMulTypes.size()) {
+
+  if (!IsContain(mul_types, (schema::PrimitiveType)primitive_c->Type())) {
     return false;
   }
@@ -57,9 +57,8 @@ class QuantStrategy {
  private:
   size_t mWeightSize;
   size_t mConvWeightQuantChannelThreshold;
-
-  static const std::array<std::string, 4> mConvTypes;
-  static const std::array<std::string, 4> mMulTypes;
+  static const std::vector<schema::PrimitiveType> conv_types;
+  static const std::vector<schema::PrimitiveType> mul_types;
 };

 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
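The QuantStrategy changes above replace prefix matching on the node name (the mConvTypes / mMulTypes string arrays) with a lookup of the node's primitive type, so deconvolution nodes are recognized no matter how they are named. A rough sketch of that check under simplified, hypothetical types (the real code uses schema::PrimitiveType, GetValueNode and MindSpore's IsContain helper):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Simplified stand-ins for the schema primitive types.
enum class PrimType { Conv2D, DeConv2D, DepthwiseConv2D, DeDepthwiseConv2D, Mul, MatMul, FullConnection };

// Local helper with the same shape as the IsContain utility used in the patch.
template <typename T>
bool IsContain(const std::vector<T> &vec, const T &value) {
  return std::find(vec.begin(), vec.end(), value) != vec.end();
}

// Type-based check replacing the old fullname_with_scope() prefix match.
bool CanConvOpQuantized(PrimType node_type, std::size_t node_input_size) {
  static const std::vector<PrimType> conv_types = {PrimType::DeConv2D, PrimType::DeDepthwiseConv2D,
                                                   PrimType::Conv2D, PrimType::DepthwiseConv2D};
  if (!IsContain(conv_types, node_type)) {
    return false;
  }
  // node->size() < 3 in the original: the CNode needs at least the primitive,
  // an input and a weight entry before weight quantization makes sense.
  return node_input_size >= 3;
}
```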
@@ -69,13 +69,9 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {

     std::vector<schema::QuantParamT> quant_params;
     primitive_c->AddInputQuantParam(quant_params);
-
-    auto op_type = (schema::PrimitiveType)primitive_c->Type();
-    bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false;
-
     auto status =
       QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant,
-                          quant_max, quant_min, bitNum, true, depthwise);
+                          quant_max, quant_min, bitNum, true, false);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
       return status;