anf exporter fixed

This commit is contained in:
kai00 2020-08-05 17:26:05 +08:00
parent c552e8d9f4
commit cf7d399c68
3 changed files with 38 additions and 2 deletions

View File

@ -24,7 +24,44 @@
namespace mindspore::lite {
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(p->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pads"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;

View File

@ -34,5 +34,5 @@ int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, sch
node->primitive->value.value = attr.release();
return 0;
}
AnfNodePopulaterRegistrar anfMatmulParser("Matmul", new AnfMatmulPopulater());
AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater());
} // namespace mindspore::lite

View File

@ -32,5 +32,4 @@ int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema
return 0;
}
AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater());
AnfNodePopulaterRegistrar anfMatMulParser("MatMul", new AnfMulPopulater());
} // namespace mindspore::lite