forked from mindspore-Ecosystem/mindspore
fix onnx pool and tflite pad
This commit is contained in:
parent
e60e4920f9
commit
c9db2ed81f
|
@ -131,7 +131,7 @@ int CastOpenCLKernel::Run() {
|
|||
|
||||
kernel::LiteKernel *OpenCLCastKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||
const lite::Context *ctx, const kernel::KernelKey &desc,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *kernel = new (std::nothrow) CastOpenCLKernel(opParameter, inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
|
|
|
@ -67,9 +67,8 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
|
|||
}
|
||||
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
|
||||
if (node->inputIndex.at(inputIndexIdx) == inputIdx) {
|
||||
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
|
||||
if ((*iter)->inputIndex.at(inputIndexIdx) == inputIdx) {
|
||||
STATUS status = RET_OK;
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status);
|
||||
if (status != RET_OK) {
|
||||
|
@ -89,7 +88,6 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
|
|||
graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]};
|
||||
transed = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,9 +83,9 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
if (onnx_node_attr.ints_size() == 4) {
|
||||
attr->padMode = schema::PadMode_CAFFE;
|
||||
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->padDown = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->padDown = static_cast<int32_t>(onnx_node_attr.ints(2));
|
||||
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
|
||||
}
|
||||
}
|
||||
if (attribute_name == "ceil_mode") {
|
||||
|
|
|
@ -74,8 +74,6 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
|
|||
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
|
||||
return RET_INVALID_OP_ATTR;
|
||||
}
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported";
|
||||
return RET_NOT_SUPPORT;
|
||||
|
@ -86,6 +84,10 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
|
|||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
if (std::strcmp(node_name, "MirrorPad") == 0) {
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue