deconv adapter
This commit is contained in:
parent
d66e6b33bf
commit
668db1dd7d
|
@ -76,6 +76,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
|
|||
return new lite::Activation(const_cast<schema::Primitive *>(srcPrim));
|
||||
case schema::PrimitiveType_Conv2D:
|
||||
return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim));
|
||||
case schema::PrimitiveType_DeConv2D:
|
||||
return new lite::DeConv2D(const_cast<schema::Primitive *>(srcPrim));
|
||||
case schema::PrimitiveType_Reduce:
|
||||
return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim));
|
||||
case schema::PrimitiveType_Pooling:
|
||||
|
|
|
@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat
|
|||
}
|
||||
auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get();
|
||||
MS_ASSERT(baNodeBiasTensor != nullptr);
|
||||
if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) {
|
||||
if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
|
||||
// dont fusion, return
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
|
|||
<< ". or bias tensor is a scaler";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
|
||||
|
||||
bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1;
|
||||
if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
|
||||
MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)"
|
||||
<< biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum;
|
||||
return RET_ERROR;
|
||||
|
@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
|
|||
MS_LOG(ERROR) << "memset_s newBiasData failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (bias_const) {
|
||||
auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
|
||||
for (size_t i = 0; i < kernelNum; i++) {
|
||||
newBiasData[i] = *biasData;
|
||||
}
|
||||
} else {
|
||||
if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) {
|
||||
MS_LOG(ERROR) << "memcpy_s newBiasData failed";
|
||||
|
|
|
@ -152,6 +152,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
|
|||
weightTensor->format = schema::Format_KHWC;
|
||||
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
weightTensor->format = schema::Format_CHWK;
|
||||
} else if (opType == schema::PrimitiveType_DeConv2D) {
|
||||
weightTensor->format = schema::Format_KHWC;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupport format";
|
||||
return -1;
|
||||
|
@ -355,18 +357,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|||
MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
|
||||
// todo(00445839): consider varible weight condition
|
||||
}
|
||||
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW
|
||||
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx
|
||||
return 0;
|
||||
} else if (weightTensor->format == schema::Format_HWKC) { // from tf
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kHWKC2KCHW);
|
||||
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
|
||||
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
||||
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
|
||||
status = RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
||||
return -1;
|
||||
}
|
||||
if (status == 0) {
|
||||
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW;
|
||||
weightTensor->format = schema::Format_KCHW;
|
||||
node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
|
||||
weightTensor->format = schema::Format_KHWC;
|
||||
} else {
|
||||
MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
|
||||
// todo(00445839): consider varible weight condition
|
||||
|
|
|
@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
|
|||
schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) {
|
||||
// MS_LOGD("parse TfliteAddParser");
|
||||
MS_LOG(DEBUG) << "parse TfliteAddParser";
|
||||
std::unique_ptr<schema::AddT> attr(new schema::AddT());
|
||||
auto weight_index = tfliteOp->inputs[1];
|
||||
const auto &weight_tensor = tfliteTensors[weight_index];
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
|
||||
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Add;
|
||||
|
|
Loading…
Reference in New Issue