forked from mindspore-Ecosystem/mindspore

!3713 detect model fix bug
Merge pull request !3713 from ghzl/deconv-adapter
Commit: 9eb0a7697f
@@ -80,6 +80,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
       return new lite::Activation(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_Conv2D:
       return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim));
+    case schema::PrimitiveType_DeConv2D:
+      return new lite::DeConv2D(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_Reduce:
       return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_Pooling:
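For context: the new DeConv2D case mirrors the existing Conv2D branch, so a deconvolution primitive in the flatbuffer gets wrapped in its lite class instead of falling through the switch and leaving the caller with a null primitive. A minimal, self-contained sketch of the same dispatch pattern; the types below are simplified stand-ins, not the real schema/lite classes:

#include <memory>

// Simplified stand-ins for the schema/lite primitive types (assumption,
// not the real MindSpore Lite headers).
enum class PrimType { Conv2D, DeConv2D, Unknown };
struct Primitive { virtual ~Primitive() = default; };
struct Conv2D : Primitive {};
struct DeConv2D : Primitive {};

// Factory dispatch: every supported op type needs its own case; a missing
// case means the caller receives nullptr and import fails for that op.
std::unique_ptr<Primitive> CopyPrimitive(PrimType type) {
  switch (type) {
    case PrimType::Conv2D:   return std::make_unique<Conv2D>();
    case PrimType::DeConv2D: return std::make_unique<DeConv2D>();  // branch added by this commit
    default:                 return nullptr;
  }
}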
@@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat
   }
   auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get();
   MS_ASSERT(baNodeBiasTensor != nullptr);
-  if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) {
+  if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
     // dont fusion, return
     return RET_OK;
   }
@@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
                   << ". or bias tensor is a scaler";
     return RET_ERROR;
   }
-  if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
+  bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1;
+  if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
     MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)"
                   << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum;
     return RET_ERROR;
@@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
       MS_LOG(ERROR) << "memset_s newBiasData failed";
       return RET_ERROR;
     }
+  } else if (bias_const) {
+    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
+    for (size_t i = 0; i < kernelNum; i++) {
+      newBiasData[i] = *biasData;
+    }
   } else {
     if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) {
       MS_LOG(ERROR) << "memcpy_s newBiasData failed";
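The bias_const path above covers a BiasAdd whose bias is a single-element tensor: rather than rejecting it (the old behaviour) or memcpy-ing kernelNum floats out of a one-float buffer, the scalar is broadcast to every output channel of the fused conv bias. A standalone sketch of that broadcast under the same assumption (names here are illustrative, not the pass's actual variables):

#include <cstddef>
#include <vector>

// Build the fused conv bias: broadcast a 1-element BiasAdd bias to all
// kernelNum output channels, otherwise copy it channel by channel.
std::vector<float> MakeConvBias(const std::vector<float> &biasData, size_t kernelNum) {
  std::vector<float> newBias(kernelNum, 0.0f);
  const bool bias_const = (biasData.size() == 1);  // scalar bias, shape [1]
  for (size_t i = 0; i < kernelNum; i++) {
    newBias[i] = bias_const ? biasData[0] : biasData.at(i);
  }
  return newBias;
}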
@@ -153,6 +153,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
       weightTensor->format = schema::Format_KHWC;
     } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
       weightTensor->format = schema::Format_CHWK;
+    } else if (opType == schema::PrimitiveType_DeConv2D) {
+      weightTensor->format = schema::Format_KHWC;
     } else {
       MS_LOG(ERROR) << "unsupport format";
       return -1;
@@ -356,18 +358,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
         MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
         // todo(00445839): consider varible weight condition
       }
-    } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW
-      if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx
-        return 0;
-      } else if (weightTensor->format == schema::Format_HWKC) { // from tf
-        status = TransFilterFormat<float>(weightTensor.get(), kHWKC2KCHW);
+    } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
+      if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
+        status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
+      } else if (weightTensor->format == schema::Format_KHWC) { // from tf
+        status = RET_OK;
       } else {
         MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
         return -1;
       }
       if (status == 0) {
-        node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW;
-        weightTensor->format = schema::Format_KCHW;
+        node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
+        weightTensor->format = schema::Format_KHWC;
       } else {
         MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
         // todo(00445839): consider varible weight condition
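With this change the DeConv2D branch converts caffe/onnx-style KCHW filters to KHWC and accepts tf-style KHWC filters as-is, so the runtime always sees deconvolution weights in KHWC. For reference, a KCHW-to-KHWC conversion amounts to the index shuffle below (a generic sketch, independent of the actual TransFilterFormat helper):

#include <cstddef>
#include <vector>

// Reorder a dense filter from KCHW ([K][C][H][W]) to KHWC ([K][H][W][C]).
std::vector<float> TransposeKCHW2KHWC(const std::vector<float> &src,
                                      size_t K, size_t C, size_t H, size_t W) {
  std::vector<float> dst(src.size());
  for (size_t k = 0; k < K; k++) {
    for (size_t c = 0; c < C; c++) {
      for (size_t h = 0; h < H; h++) {
        for (size_t w = 0; w < W; w++) {
          const size_t srcIdx = ((k * C + c) * H + h) * W + w;
          const size_t dstIdx = ((k * H + h) * W + w) * C + c;
          dst[dstIdx] = src[srcIdx];
        }
      }
    }
  }
  return dst;
}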
@@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
                               schema::CNodeT *op,
                               TensorCache *tensor_cache,
                               bool quantizedModel) {
-  // MS_LOGD("parse TfliteAddParser");
+  MS_LOG(DEBUG) << "parse TfliteAddParser";
   std::unique_ptr<schema::AddT> attr(new schema::AddT());
+  auto weight_index = tfliteOp->inputs[1];
+  const auto &weight_tensor = tfliteTensors[weight_index];
+  std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
+
+  if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
+    return RET_ERROR;
+  }
+
   if (op != nullptr) {
     op->primitive = std::make_unique<schema::PrimitiveT>();
     op->primitive->value.type = schema::PrimitiveType_Add;
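The parser change assumes the second input of a tflite Add op is a constant weight tensor and registers it in the tensor cache (as KHWC) before the primitive is built. A rough sketch of that lookup-and-fail-early pattern; the struct and helper names below are hypothetical, only ParseWeight, TensorCache, and Format_KHWC come from the diff:

#include <memory>
#include <vector>

// Hypothetical mirror of the new guard: fetch the Add op's second input and
// bail out early if it cannot be used as a constant weight tensor.
struct OpT { std::vector<int> inputs; };

template <typename TensorT>
bool CollectAddWeight(const OpT &op, const std::vector<std::unique_ptr<TensorT>> &tensors,
                      std::vector<TensorT *> *weight_tensors) {
  if (op.inputs.size() < 2) return false;                    // no second operand
  const int weight_index = op.inputs[1];
  if (weight_index < 0 || static_cast<size_t>(weight_index) >= tensors.size()) return false;
  weight_tensors->push_back(tensors[weight_index].get());    // hand off to weight parsing
  return true;
}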