diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc index 44a5589b854..57d36cc54da 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc @@ -125,9 +125,18 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni return RET_NULL_PTR; } auto type = node->primitive->value.type; - if (graph->allTensors.at(node->inputIndex[0])->dims.size() != 4) { - MS_LOG(ERROR) << "change op axis only support 4 dims"; - return RET_NOT_SUPPORT; + auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size(); + if (input1_ndim != 4) { + if (node->inputIndex.size() > 1) { + auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size(); + if (input2_ndim != 4 && input2_ndim != 0) { + MS_LOG(ERROR) << "change op axis only support 4 dims"; + return RET_NOT_SUPPORT; + } + } else { + MS_LOG(ERROR) << "change op axis only support 4 dims"; + return RET_NOT_SUPPORT; + } } if (type == PrimitiveType_Concat) { auto origin_axis = node->primitive->value.AsConcat()->axis; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc new file mode 100644 index 00000000000..6eb1887e2a7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_constant_of_shape_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx ConstantOfShapeParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "value") { + attr->value = static_cast(onnx_node_attr.i()); + } + } + + op->primitive->value.type = schema::PrimitiveType_ConstantOfShape; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.h new file mode 100644 index 00000000000..b76421c30a2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_OF_SHAPE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_OF_SHAPE_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxConstantOfShapeParser : public OnnxNodeParser { + public: + OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_OF_SHAPE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc new file mode 100644 index 00000000000..6d4c205fc0a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_lstm_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx LstmParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "direction") { + auto direction = onnx_node_attr.s(); + attr->bidirection = direction == "bidirectional"; + } + } + + op->primitive->value.type = schema::PrimitiveType_Lstm; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.h new file mode 100644 index 00000000000..ecf0261f0cb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LSTM_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LSTM_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxLstmParser : public OnnxNodeParser { + public: + OnnxLstmParser() : OnnxNodeParser("LSTM") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LSTM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index ccd6170669d..59f0a6b5c60 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -365,13 +365,15 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { for (const auto &onnx_node_input : node_inputs) { - auto index = tensor_cache->FindTensor(onnx_node_input); - if (index < 0) { - MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found"; - return RET_ERROR; + if (onnx_node_input != "") { + auto index = tensor_cache->FindTensor(onnx_node_input); + if (index < 0) { + MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found"; + return RET_ERROR; + } + MS_LOG(DEBUG) << "node: " << onnx_node_input << ", input index: " << index; + dst_op->inputIndex.emplace_back(index); } - MS_LOG(DEBUG) << "node: " << onnx_node_input << ", input index: " << index; - dst_op->inputIndex.emplace_back(index); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc new file mode 100644 index 00000000000..c9655773a25 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_split_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx SplitParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->splitDim = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "split") { + for (auto sizeSplit : onnx_node_attr.ints()) { + attr->sizeSplits.emplace_back(sizeSplit); + } + attr->numberSplit = onnx_node_attr.ints_size(); + } + } + + op->primitive->value.type = schema::PrimitiveType_Split; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.h new file mode 100644 index 00000000000..bb017966aa3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLIT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLIT_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSplitParser : public OnnxNodeParser { + public: + OnnxSplitParser() : OnnxNodeParser("Split") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLIT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc new file mode 100644 index 00000000000..ef19e9623c0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_topk_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx TopKParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "k") { + attr->k = static_cast(onnx_node_attr.i()); + } + } + // attr->sorted; + + op->primitive->value.type = schema::PrimitiveType_TopK; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.h new file mode 100644 index 00000000000..0b23471a771 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TOPK_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TOPK_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxTopkParser : public OnnxNodeParser { + public: + OnnxTopkParser() : OnnxNodeParser("TopK") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TOPK_PARSER_H