fix onnx pool and tflite pad

This commit is contained in:
xuanyue 2020-09-16 17:24:49 +08:00
parent e60e4920f9
commit c9db2ed81f
4 changed files with 10 additions and 10 deletions

View File

@ -131,7 +131,7 @@ int CastOpenCLKernel::Run() {
kernel::LiteKernel *OpenCLCastKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) CastOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {

View File

@ -67,9 +67,8 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
if (node->inputIndex.at(inputIndexIdx) == inputIdx) {
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
if ((*iter)->inputIndex.at(inputIndexIdx) == inputIdx) {
STATUS status = RET_OK;
iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status);
if (status != RET_OK) {
@ -89,7 +88,6 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]};
transed = true;
}
break;
}
}
}

View File

@ -83,9 +83,9 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
if (onnx_node_attr.ints_size() == 4) {
attr->padMode = schema::PadMode_CAFFE;
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->padDown = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->padDown = static_cast<int32_t>(onnx_node_attr.ints(2));
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
}
}
if (attribute_name == "ceil_mode") {

View File

@ -74,8 +74,6 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
return RET_INVALID_OP_ATTR;
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
} else {
MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported";
return RET_NOT_SUPPORT;
@ -86,6 +84,10 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
if (std::strcmp(node_name, "MirrorPad") == 0) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
return RET_OK;