deconv adapter

This commit is contained in:
guohongzilong 2020-07-30 20:21:22 +08:00
parent d66e6b33bf
commit 668db1dd7d
4 changed files with 29 additions and 10 deletions

View File

@ -76,6 +76,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Activation(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Conv2D:
return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_DeConv2D:
return new lite::DeConv2D(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Reduce:
return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Pooling:

View File

@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat
}
auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get();
MS_ASSERT(baNodeBiasTensor != nullptr);
if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) {
if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
// dont fusion, return
return RET_OK;
}
@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
<< ". or bias tensor is a scaler";
return RET_ERROR;
}
if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1;
if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)"
<< biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum;
return RET_ERROR;
@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
MS_LOG(ERROR) << "memset_s newBiasData failed";
return RET_ERROR;
}
} else if (bias_const) {
auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
for (size_t i = 0; i < kernelNum; i++) {
newBiasData[i] = *biasData;
}
} else {
if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s newBiasData failed";

View File

@ -152,6 +152,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
weightTensor->format = schema::Format_KHWC;
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
weightTensor->format = schema::Format_CHWK;
} else if (opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(ERROR) << "unsupport format";
return -1;
@ -355,18 +357,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx
return 0;
} else if (weightTensor->format == schema::Format_HWKC) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kHWKC2KCHW);
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
status = RET_OK;
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
}
if (status == 0) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW;
weightTensor->format = schema::Format_KCHW;
node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition

View File

@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
// MS_LOGD("parse TfliteAddParser");
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Add;