From cf7d399c6886df16c0daff7cf7fde550eb852c6d Mon Sep 17 00:00:00 2001 From: kai00 Date: Wed, 5 Aug 2020 17:26:05 +0800 Subject: [PATCH] anf exporter fixed --- .../anf_depthwiseconv2d_populater.cc | 37 +++++++++++++++++++ .../anf_populater/anf_matmul_populater.cc | 2 +- .../anf_populater/anf_mul_populater.cc | 1 - 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc index 6970df4ada4..5aa9ab1b900 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc @@ -24,7 +24,44 @@ namespace mindspore::lite { int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); auto attr = std::make_unique(); + + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(p->GetAttr("pads")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(p->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(p->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + auto pad_mode = GetValue(p->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + node->nodeType = schema::NodeType_CNode; node->primitive = std::make_unique(); node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc index 3bde1465959..909ceec01a9 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc @@ -34,5 +34,5 @@ int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, sch node->primitive->value.value = attr.release(); return 0; } -AnfNodePopulaterRegistrar anfMatmulParser("Matmul", new AnfMatmulPopulater()); +AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc index d86d95b05ef..4f5c3beec82 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc @@ -32,5 +32,4 @@ int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema return 0; } AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater()); -AnfNodePopulaterRegistrar anfMatMulParser("MatMul", new AnfMulPopulater()); } // namespace mindspore::lite