diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt index d075e73d37b..ee3f5d6355b 100644 --- a/.jenkins/check/config/filter_cppcheck.txt +++ b/.jenkins/check/config/filter_cppcheck.txt @@ -46,6 +46,7 @@ "mindspore/mindspore/lite/tools/common/flag_parser.cc" "useStlAlgorithm" "mindspore/mindspore/lite/tools/common/tensor_util.cc" "useStlAlgorithm" "mindspore/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc" "useStlAlgorithm" +"mindspore/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.cc" "variableScope" "mindspore/mindspore/lite/tools/converter/quantizer/quantize_util.cc" "useStlAlgorithm" "mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable" "mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable" diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 81cdf4bb30f..73a7782be2b 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -525,7 +525,16 @@ if(PLATFORM_ARM64) install(FILES ${opencv_LIBPATH}/libopencv_imgproc.so.4.5.2 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libopencv_imgproc.so.4.5 COMPONENT ${RUNTIME_COMPONENT_NAME}) - + if(ENABLE_CONVERT_PYTORCH_MODEL AND NOT WIN32) + install(FILES ${LIB_TORCH_PATH}/lib/libtorch.so DESTINATION ${CONVERTER_ROOT_DIR}/lib + COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${LIB_TORCH_PATH}/lib/libtorch_cpu.so DESTINATION ${CONVERTER_ROOT_DIR}/lib + COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${LIB_TORCH_PATH}/lib/libc10.so DESTINATION ${CONVERTER_ROOT_DIR}/lib + COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${LIB_TORCH_PATH}/lib/libgomp.so DESTINATION ${CONVERTER_ROOT_DIR}/lib + COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() if(MSLITE_ENABLE_ACL) set(LITE_ACL_DIR ${TOP_DIR}/mindspore/lite/build/tools/converter/adapter/acl) install(FILES ${LITE_ACL_DIR}/mindspore_shared_lib/libmindspore_shared_lib.so diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 2a50376ef57..c52e1346c8f 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -55,6 +55,7 @@ option(MSLITE_ENABLE_PARALLEL_INFERENCE "enable parallel inference interface" of option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off) option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on) option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off) +option(MSLITE_ENABLE_CONVERT_PYTORCH_MODEL "enable to convert pytorch model" off) #Option that can be configured through manually option(ENABLE_VERBOSE "" off) @@ -175,6 +176,11 @@ if(DEFINED ENV{MSLITE_ENABLE_SERVING}) set(MSLITE_ENABLE_SERVING $ENV{MSLITE_ENABLE_SERVING}) endif() +if(DEFINED ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL} AND DEFINED ENV{LIB_TORCH_PATH}) + set(ENABLE_CONVERT_PYTORCH_MODEL $ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL}) + set(LIB_TORCH_PATH $ENV{LIB_TORCH_PATH}) +endif() + if(MACHINE_LINUX_ARM64) add_compile_definitions(MACHINE_LINUX_ARM64) add_compile_definitions(LINUX_RUNTIME) diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h index db5ac063f5e..91d32968d10 100644 --- a/mindspore/lite/include/registry/converter_context.h +++ b/mindspore/lite/include/registry/converter_context.h @@ -32,6 +32,7 @@ enum MS_API FmkType : int { kFmkTypeOnnx = 2, kFmkTypeMs = 3, kFmkTypeTflite = 4, + kFmkTypePytorch = 5, }; /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 274bb29658c..f0895dbcd9d 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -79,6 +79,10 @@ add_subdirectory(parser/caffe) add_subdirectory(parser/tflite) add_subdirectory(parser/onnx) add_subdirectory(parser/tf) +if(ENABLE_CONVERT_PYTORCH_MODEL AND NOT WIN32) + add_subdirectory(parser/pytorch) +endif() + add_subdirectory(legacy_optimizer) add_subdirectory(quantizer) add_subdirectory(registry) @@ -310,6 +314,10 @@ if(SUPPORT_TRAIN) target_link_libraries(converter_lite PRIVATE train_cpu_kernel_mid) endif() +if(ENABLE_CONVERT_PYTORCH_MODEL) + target_link_libraries(converter_lite PRIVATE pytorch_parser_mid) +endif() + if(NOT ENABLE_CLOUD_AND_LITE) target_link_libraries(converter_lite PRIVATE ccsrc_debug_common_mid_ diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index b8673eac2ff..8c651ec4da3 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -69,6 +69,7 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) { } else { model_parser_ = registry::ModelParserRegistry::GetModelParser(flag.fmk); if (model_parser_ == nullptr) { + MS_LOG(ERROR) << "Unsupported to converter models with fmk: " << flag.fmkIn; return nullptr; } converter::ConverterParameters converter_parameters; diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 6caa37919a9..f182e4ed966 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -147,6 +147,8 @@ int Flags::InitFmk() { this->fmk = kFmkTypeOnnx; } else if (this->fmkIn == "TF") { this->fmk = kFmkTypeTf; + } else if (this->fmkIn == "PYTORCH") { + this->fmk = kFmkTypePytorch; } else { std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl; return RET_INPUT_PARAM_INVALID; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc index 1162c19dc60..6a72f224d69 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc @@ -105,7 +105,7 @@ STATUS ReplaceTypeParameterNode(const FuncGraphPtr &func_graph, const ParameterP MS_LOG(ERROR) << "default data is not tensor::Tensor."; return lite::RET_NULL_PTR; } - auto param_node_new = opt::BuildParameterNode(func_graph, param_node, tensor_info); + auto param_node_new = opt::BuildParameterNode(func_graph, tensor_info, param_node->fullname_with_scope()); if (param_node_new == nullptr) { MS_LOG(ERROR) << "BuildParameterNode failed."; return lite::RET_NULL_PTR; @@ -166,7 +166,7 @@ STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_LOG(ERROR) << "valueptr is not tensor::Tensorptr."; return lite::RET_ERROR; } - auto param_node = opt::BuildParameterNode(func_graph, cnode, tensor_info_ptr); + auto param_node = opt::BuildParameterNode(func_graph, tensor_info_ptr, cnode->fullname_with_scope()); if (param_node == nullptr) { MS_LOG(ERROR) << "convert constant to param node failed."; return lite::RET_ERROR; @@ -413,7 +413,7 @@ STATUS AdjustRandomNormal(const FuncGraphPtr &func_graph, const CNodePtr &cnode) auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type); free(data); MS_CHECK_TRUE_RET(tensor_info != nullptr, RET_ERROR); - auto parameter = opt::BuildParameterNode(func_graph, cnode, tensor_info); + auto parameter = opt::BuildParameterNode(func_graph, tensor_info, cnode->fullname_with_scope()); if (parameter == nullptr) { MS_LOG(ERROR) << "BuildParameterNode failed."; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/pytorch/CMakeLists.txt b/mindspore/lite/tools/converter/parser/pytorch/CMakeLists.txt new file mode 100644 index 00000000000..1f66e6e605e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/CMakeLists.txt @@ -0,0 +1,22 @@ +include(${TOP_DIR}/mindspore/lite/cmake/merge.cmake) +merge_parser(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_BINARY_DIR}/tools/converter/parser/pytorch/pytorch_op_parser.cc) +file(GLOB_RECURSE PYTORCH_SRC_LIST ${CMAKE_BINARY_DIR}/tools/converter/parser/pytorch/pytorch_op_parser.cc) +set_property(SOURCE ${PYTORCH_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) +add_library(pytorch_parser_mid OBJECT + ${PYTORCH_SRC_LIST} + ) + +add_compile_definitions(C10_USE_GLOG) +if(NOT EXISTS ${LIB_TORCH_PATH}) + message(FATAL_ERROR "Path of libtorch is invalid.") +endif() +find_package(Torch REQUIRED PATHS ${LIB_TORCH_PATH}) +if(TORCH_FOUND) + target_link_libraries(pytorch_parser_mid PRIVATE ${TORCH_LIBRARIES}) + target_include_directories(pytorch_parser_mid PRIVATE ${TORCH_INCLUDE_DIRS}) +else() + message(FATAL_ERROR "Torch is not found") +endif() + +add_dependencies(pytorch_parser_mid fbs_src) +add_dependencies(pytorch_parser_mid fbs_inner_src) diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc b/mindspore/lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc new file mode 100644 index 00000000000..0e82ef30cd0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/pytorch/pytorch_conv_parser.h" +#include +#include +#include +#include +#include "ops/fusion/conv2d_fusion.h" +#include "nnacl/op_base.h" + +namespace mindspore::lite { +PrimitiveCPtr PytorchConvParser::Parse(const torch::jit::Node *torch_node, std::vector *input_indices) { + MS_ASSERT(torch_node != nullptr && input_indices != nullptr); + auto prim = std::make_unique(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + auto bias = torch_node->input(kBiasIndex); + MS_CHECK_TRUE_RET(bias != nullptr, nullptr); + // the bias is noneType + size_t input_size = bias->isCompleteTensor() ? kInputSize2 : kInputSize1; + input_indices->resize(input_size); + std::iota(input_indices->begin(), input_indices->end(), 0); + + auto stride = PytorchNodeParser::GetValueFromConstNode>(torch_node->input(FOURTH_INPUT)); + prim->set_stride(stride); + auto dilation = PytorchNodeParser::GetValueFromConstNode>(torch_node->input(SIXTH_INPUT)); + prim->set_dilation(dilation); + + auto padding = PytorchNodeParser::GetValueFromConstNode>(torch_node->input(FIFTH_INPUT)); + if (padding.size() == DIMENSION_2D) { + padding.push_back(padding.at(1)); + padding.insert(padding.begin(), padding.at(0)); + } + prim->set_pad_list(padding); + + auto group = PytorchNodeParser::GetValueFromConstNode(torch_node->input(8)); + prim->set_group(group); + + prim->set_pad({0, 0, 0, 0}); + mindspore::PadMode pad_mode = mindspore::PadMode::PAD; + prim->set_pad_mode(pad_mode); + auto prim_c = prim->GetPrim(); + MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr); + mindspore::Format format = mindspore::Format::NCHW; + prim_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue(format)); + bool conv1d = stride.size() == 1; + if (conv1d) { + prim_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue(NCW)); + } + + // parse activationType + prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION); + + return prim->GetPrim(); +} + +PytorchNodeRegistrar g_pytorchConvParser("conv2d", new PytorchConvParser()); +PytorchNodeRegistrar g_pytorchConvolutionParser("convolution", new PytorchConvParser()); +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_conv_parser.h b/mindspore/lite/tools/converter/parser/pytorch/pytorch_conv_parser.h new file mode 100644 index 00000000000..719fa20f863 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_conv_parser.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H + +#include +#include "tools/converter/parser/pytorch/pytorch_node_parser.h" +#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class PytorchConvParser : public PytorchNodeParser { + public: + PytorchConvParser() : PytorchNodeParser("Conv") {} + ~PytorchConvParser() override = default; + + PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector *input_indices) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.cc b/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.cc new file mode 100644 index 00000000000..e4c3c239ade --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.cc @@ -0,0 +1,508 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/pytorch/pytorch_model_parser.h" +#include +#include +#include +#include "torch/csrc/jit/passes/freeze_module.h" +#include "torch/csrc/jit/passes/inliner.h" +#include "torch/csrc/jit/passes/normalize_ops.h" +#include "include/registry/node_parser_registry.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" +#include "tools/converter/converter_context.h" +#include "tools/converter/parser/parser_utils.h" +#include "tools/converter/parser/unify_format.h" +#include "tools/converter/parser/lite_model_parser_creator.h" +#include "src/common/file_utils.h" +#include "src/common/log_util.h" +#include "nnacl/op_base.h" +#include "ops/op_utils.h" +#include "ops/make_tuple.h" +#include "ops/return.h" +#include "ops/tuple_get_item.h" + +using mindspore::converter::kFmkTypePytorch; +namespace mindspore { +namespace lite { +api::FuncGraphPtr PytorchModelParser::Parse(const converter::ConverterParameters &flag) { + auto model_file = flag.model_file; + NotSupportOp::GetInstance()->set_fmk_type("PYTORCH"); + auto anf_graph = std::make_shared(); + MS_CHECK_TRUE_MSG(anf_graph != nullptr, nullptr, "create FuncGraph failed"); + res_graph_ = api::MakeShared(anf_graph); + MS_CHECK_TRUE_MSG(res_graph_ != nullptr, nullptr, "create FuncGraph failed"); + auto status = InitOriginModel(model_file); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "init origin model failed."; + return nullptr; + } + + status = ConvertTorchGraph(anf_graph); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert pytorch graph failed."; + return nullptr; + } + static auto root_func_manager = Manage(anf_graph); + MS_ASSERT(root_func_manager != nullptr); + for (auto &subgraph : all_subgraphs_) { + MS_ASSERT(subgraph != nullptr); + subgraph->set_manager(root_func_manager); + subgraph->set_attr("fmk", MakeValue(static_cast(converter::kFmkTypePytorch))); + } + anf_graph->set_attr("graph_name", MakeValue("main_graph")); + anf_graph->set_attr("fmk", MakeValue(static_cast(converter::kFmkTypePytorch))); + if ((status = CommonAnfAdjust(anf_graph)) != RET_OK) { + MS_LOG(ERROR) << "AdjustForAnf failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + + auto unify_format = std::make_shared(kFmkTypePytorch, false); + MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr"); + if (!unify_format->Run(anf_graph)) { + MS_LOG(ERROR) << "Run insert transpose failed."; + return nullptr; + } + return res_graph_; +} +STATUS PytorchModelParser::InitOriginModel(const std::string &model_file) { + if (ValidateFileStr(model_file, ".pt") != RET_OK && ValidateFileStr(model_file, ".pth") != RET_OK) { + MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pt or *.pth"; + return RET_ERROR; + } + std::string model_path = RealPath(model_file.c_str()); + if (model_path.empty()) { + MS_LOG(ERROR) << "Binary proto file path " << model_file << " is not valid"; + return RET_ERROR; + } + // only linux supports to convert pytorch model. + if (access(model_path.c_str(), F_OK) != 0 || access(model_path.c_str(), R_OK) != 0) { + MS_LOG(ERROR) << "The pytorch model file is not exist or can't be read."; + return RET_ERROR; + } + + auto module = torch::jit::load(model_path); + module.eval(); // eval to expand function call + module = torch::jit::freeze_module(module); // freeze module + torch_model_ = module.get_method("forward").graph(); + CHECK_NULL_RETURN(torch_model_); + // parse submodules in graph + torch::jit::Inline(*torch_model_); + torch::jit::NormalizeOps(torch_model_); + return RET_OK; +} + +STATUS PytorchModelParser::ConvertTorchGraph(const FuncGraphPtr &anf_graph) { + MS_ASSERT(torch_graph != nullptr && anf_graph != nullptr && anf_nodes_map != nullptr && + extra_subgraph_inputs != nullptr); + STATUS status = ConvertGraphInputs(anf_graph); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert graph inputs failed."; + return RET_OK; + } + + status = ConvertNodes(anf_graph); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert nodes failed."; + return RET_ERROR; + } + + status = ConvertGraphOutputs(anf_graph); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert graph outputs failed."; + return RET_ERROR; + } + return status; +} + +STATUS PytorchModelParser::ConvertGraphInputs(const FuncGraphPtr &anf_graph) { + MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr); + for (auto &input : torch_model_->inputs()) { + auto input_name = input->debugName(); + if (anf_nodes_map_.find(input_name) != anf_nodes_map_.end()) { + continue; + } + auto type = input->type(); + MS_CHECK_TRUE_RET(type != nullptr, RET_ERROR); + auto tensor_type = type->cast(); + if (tensor_type == nullptr) { + MS_LOG(DEBUG) << "The input is not a tensor, but a: " << c10::typeKindToString(type->kind()); + continue; + } + auto scalar_type = tensor_type->scalarType().value_or(at::ScalarType::Float); + auto data_type = PytorchNodeParser::GetDataTypeFromTorch(scalar_type); + if (data_type == kTypeUnknown) { + MS_LOG(ERROR) << "not support pytorch data type " << scalar_type; + return RET_ERROR; + } + std::vector input_shape = ConverterInnerContext::GetInstance()->GetGraphInputTensorShape(input_name); + if (input_shape.empty()) { + if (tensor_type->sizes().isComplete()) { + input_shape = tensor_type->sizes().concrete_sizes().value(); + } else { + MS_LOG(WARNING) << "The input shape is empty."; + } + } + auto abstract_tensor = CreateTensorAbstract(input_shape, data_type); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + auto parameter = anf_graph->add_parameter(); + MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr"); + parameter->set_abstract(abstract_tensor); + parameter->set_name(input_name); + anf_nodes_map_.emplace(input_name, parameter); + } + return RET_OK; +} + +STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector &return_inputs) { + MS_ASSERT(anf_graph != nullptr); + auto return_prim_ptr = std::make_shared(); + if (return_prim_ptr == nullptr) { + MS_LOG(ERROR) << "new Return failed"; + return RET_NULL_PTR; + } + auto return_prim = return_prim_ptr->GetPrim(); + MS_CHECK_TRUE_RET(return_prim != nullptr, RET_ERROR); + auto return_cnode = anf_graph->NewCNode(return_prim, return_inputs); + if (return_cnode == nullptr) { + MS_LOG(ERROR) << "new cnode error"; + return RET_ERROR; + } + return_cnode->set_fullname_with_scope("Return"); + anf_graph->set_return(return_cnode); + return RET_OK; +} + +STATUS PytorchModelParser::ConvertGraphOutputs(const FuncGraphPtr &anf_graph) { + MS_ASSERT(anf_graph != nullptr); + std::vector return_inputs; + if (torch_model_->outputs().size() == 0) { + MS_LOG(ERROR) << "pytorch graph has no output"; + return RET_ERROR; + } + if (torch_model_->outputs().size() > 1) { + std::vector make_tuple_inputs; + auto make_tuple_prim_ptr = std::make_shared(); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(ERROR) << "new MakeTuple failed"; + return RET_NULL_PTR; + } + for (const auto &output : torch_model_->outputs()) { + auto output_name = output->debugName(); + if (anf_nodes_map_.find(output_name) == anf_nodes_map_.end()) { + MS_LOG(ERROR) << "graph output get failed."; + return RET_ERROR; + } + auto cnode = anf_nodes_map_.at(output_name); + if (cnode == nullptr) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_NOT_FIND_OP; + } + make_tuple_inputs.emplace_back(cnode); + } + auto make_tuple_prim = make_tuple_prim_ptr->GetPrim(); + MS_CHECK_TRUE_RET(make_tuple_prim != nullptr, RET_ERROR); + auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim, make_tuple_inputs); + if (make_tuple_cnode == nullptr) { + MS_LOG(ERROR) << "new cnode error"; + return RET_ERROR; + } + make_tuple_cnode->set_fullname_with_scope("return tuple"); + return_inputs.emplace_back(make_tuple_cnode); + } else { + const auto &output = torch_model_->outputs().front(); + if (anf_nodes_map_.find(output->debugName()) == anf_nodes_map_.end()) { + MS_LOG(ERROR) << "graph output get failed."; + return RET_ERROR; + } + auto cnode = anf_nodes_map_.at(output->debugName()); + if (cnode == nullptr) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_NOT_FIND_OP; + } + return_inputs.emplace_back(cnode); + } + if (BuildReturnNode(anf_graph, return_inputs) != RET_OK) { + MS_LOG(ERROR) << "build return node failed."; + return RET_ERROR; + } + return RET_OK; +} + +STATUS CopyDataFromTorchTensor(char *dst_data, const at::Tensor &torch_tensor, TypeId data_type) { + auto ele_size = abstract::TypeIdSize(data_type); + MS_CHECK_TRUE_RET(ele_size > 0, RET_ERROR); + auto data_shape = torch_tensor.sizes().vec(); + auto stride = torch_tensor.strides().vec(); + if (data_shape.empty()) { + auto data_size = torch_tensor.numel() * ele_size; + data_shape.push_back(data_size); + stride.push_back(1); + } + char *data_ptr = reinterpret_cast(torch_tensor.data_ptr()); + if (data_ptr == nullptr) { + MS_LOG(ERROR) << "The tensor data is nullptr."; + return RET_ERROR; + } + size_t idx = 0; + std::function copy_data = [&](size_t dim, size_t offset) { + if (dim == data_shape.size() - 1) { + for (int i = 0; i < data_shape[dim]; i++) { + auto src_ptr = data_ptr + offset + i * stride[dim] * ele_size; + auto dst_ptr = dst_data + (idx++) * ele_size; + memcpy(dst_ptr, src_ptr, ele_size); + } + } else { + for (int i = 0; i < data_shape[dim]; i++) { + copy_data(dim + 1, offset + i * stride[dim] * ele_size); + } + } + }; + copy_data(0, 0); + return RET_OK; +} + +STATUS ConvertConstNode(const torch::jit::Node *torch_node, const FuncGraphPtr &anf_graph, + std::unordered_map *anf_nodes_map) { + ParameterPtr parameter = nullptr; + auto output = torch_node->output(); + auto type_kind = output->type()->kind(); + auto value = torch::jit::toIValue(output); + MS_CHECK_TRUE_RET(value.has_value(), RET_ERROR); + switch (type_kind) { + case c10::TypeKind::BoolType: { + auto data = static_cast(value.value().toBool()); + parameter = opt::BuildIntValueParameterNode(anf_graph, data, output->debugName()); + } break; + case c10::TypeKind::IntType: { + auto data = static_cast(value.value().toInt()); + parameter = opt::BuildIntValueParameterNode(anf_graph, data, output->debugName()); + } break; + case c10::TypeKind::FloatType: { + auto data = static_cast(value.value().toDouble()); + parameter = opt::BuildFloatValueParameterNode(anf_graph, data, output->debugName()); + } break; + case c10::TypeKind::ListType: { + auto element_type = value->toList().elementType()->kind(); + switch (element_type) { + case c10::TypeKind::IntType: { + auto ori_data = value.value().toIntVector(); + std::vector data; + std::transform(ori_data.begin(), ori_data.end(), std::back_inserter(data), + [](int64_t ele) { return static_cast(ele); }); + parameter = opt::BuildIntVecParameterNode(anf_graph, data, output->debugName()); + } break; + case c10::TypeKind::FloatType: { + auto ori_data = value.value().toDoubleVector(); + std::vector data; + std::transform(ori_data.begin(), ori_data.end(), std::back_inserter(data), + [](double ele) { return static_cast(ele); }); + parameter = opt::BuildFloatVecParameterNode(anf_graph, data, output->debugName()); + } break; + default: + MS_LOG(ERROR) << "Unsupported data type: " << c10::typeKindToString(element_type); + return RET_ERROR; + } + } break; + case c10::TypeKind::TensorType: { + auto torch_tensor = value.value().toTensor(); + auto data_type = PytorchNodeParser::GetDataTypeFromTorch(torch_tensor.scalar_type()); + auto data_size = torch_tensor.numel() * abstract::TypeIdSize(data_type); + char *data_ptr = reinterpret_cast(malloc(data_size)); + if (data_ptr == nullptr) { + MS_LOG(ERROR) << "malloc data failed."; + return RET_ERROR; + } + if (CopyDataFromTorchTensor(data_ptr, torch_tensor, data_type) != RET_OK) { + MS_LOG(ERROR) << "Copy data from torch tensor failed."; + free(data_ptr); + return RET_ERROR; + } + auto data_shape = torch_tensor.sizes().vec(); + auto tensor_info = CreateTensorInfo(data_ptr, data_size, data_shape, data_type); + free(data_ptr); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensorInfo failed."; + return RET_ERROR; + } + parameter = opt::BuildParameterNode(anf_graph, tensor_info, output->debugName()); + } break; + case c10::TypeKind::NoneType: + MS_LOG(DEBUG) << "The const node is none."; + return RET_OK; + default: + MS_LOG(ERROR) << "Unsupported data type: " << c10::typeKindToString(type_kind); + return RET_ERROR; + } + if (parameter == nullptr) { + MS_LOG(ERROR) << "The parameter is nullptr."; + return RET_ERROR; + } + anf_nodes_map->emplace(output->debugName(), parameter); + return RET_OK; +} + +STATUS BuildOpInputs(const torch::jit::Node *torch_node, std::vector *op_inputs, + const std::vector &input_indices, + const std::unordered_map &anf_nodes_map) { + MS_ASSERT(torch_node != nullptr && op_inputs != nullptr); + for (size_t idx : input_indices) { + auto input = torch_node->input(idx); + MS_CHECK_TRUE_RET(input != nullptr, RET_ERROR); + auto input_name = input->debugName(); + if (input_name.empty()) { + continue; + } + + if (anf_nodes_map.find(input_name) != anf_nodes_map.end()) { + op_inputs->push_back(anf_nodes_map.at(input_name)); + } else { + MS_LOG(ERROR) << "could not find input node: " << input_name; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS BuildOpOutputs(const torch::jit::Node *torch_node, const FuncGraphPtr &anf_graph, + std::unordered_map *anf_nodes_map, const CNodePtr &cnode) { + MS_ASSERT(torch_node != nullptr && anf_graph != nullptr && cnode != nullptr && anf_nodes_map != nullptr); + if (torch_node->outputs().size() == 1) { + auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + cnode->set_abstract(abstract_tensor); + anf_nodes_map->emplace(torch_node->output()->debugName(), cnode); + } else { + AbstractBasePtrList abstract_list; + int op_idx = 0; + for (const auto &output : torch_node->outputs()) { + auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + abstract_list.emplace_back(abstract_tensor); + auto tuple_get_item_prim_ptr = std::make_shared(); + if (tuple_get_item_prim_ptr == nullptr) { + MS_LOG(ERROR) << "new TupleGetItem failed"; + return RET_NULL_PTR; + } + auto tuple_get_item_prim = tuple_get_item_prim_ptr->GetPrim(); + MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, RET_NULL_PTR, "get prim return nullptr"); + auto tuple_get_item = NewValueNode(tuple_get_item_prim); + MS_CHECK_TRUE_MSG(tuple_get_item != nullptr, RET_NULL_PTR, "create ValueNode return nullptr"); + auto get_item_value = NewValueNode(MakeValue(op_idx)); + MS_CHECK_TRUE_MSG(get_item_value != nullptr, RET_NULL_PTR, "create ValueNode return nullptr"); + std::vector inputs{tuple_get_item, cnode, get_item_value}; + CNodePtr get_item_cnode = anf_graph->NewCNode(inputs); + if (get_item_cnode == nullptr) { + MS_LOG(ERROR) << "new cnode error"; + return RET_ERROR; + } + get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); + auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32); + if (get_item_abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + get_item_cnode->set_abstract(get_item_abstract); + anf_nodes_map->emplace(output->debugName(), get_item_cnode); + op_idx++; + } + auto new_abstract_list = std::make_shared(abstract_list); + CHECK_NULL_RETURN(new_abstract_list); + cnode->set_abstract(new_abstract_list); + } + anf_nodes_map->emplace(torch_node->kind().toUnqualString(), cnode); + return RET_OK; +} + +STATUS PytorchModelParser::ConvertNodes(const FuncGraphPtr &anf_graph) { + MS_ASSERT(anf_graph != nullptr); + STATUS status = RET_OK; + for (const auto &torch_node : torch_model_->nodes()) { + ops::PrimitiveCPtr primitive_c = nullptr; + auto node_type = PytorchNodeParser::GetTorchNodeType(torch_node); + MS_CHECK_TRUE_RET(!node_type.empty(), RET_ERROR); + // convert constant node. + if (node_type == "Constant") { + if (ConvertConstNode(torch_node, anf_graph, &anf_nodes_map_) != RET_OK) { + MS_LOG(ERROR) << "Convert constant node failed."; + return RET_ERROR; + } + continue; + } + + std::vector op_inputs; + std::vector input_indices; + auto node_parser_builtin = PytorchNodeParserRegistry::GetInstance().GetNodeParser(node_type); + if (node_parser_builtin == nullptr) { + NotSupportOp::GetInstance()->InsertOp(node_type); + status = status == RET_OK ? RET_NOT_FIND_OP : status; + MS_LOG(ERROR) << "not support pytorch op type " << node_type; + continue; + } + MS_LOG(INFO) << "parse op:" << node_type; + primitive_c = node_parser_builtin->Parse(torch_node, &input_indices); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "parse node " << node_type << " failed."; + status = RET_ERROR; + continue; + } + // set default format and input indices. + if (primitive_c->GetAttr(ops::kOriginalFormat) == nullptr) { + primitive_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue(NCHW)); + } + if (input_indices.empty()) { + input_indices.resize(torch_node->inputs().size()); + std::iota(input_indices.begin(), input_indices.end(), 0); + } + + if (BuildOpInputs(torch_node, &op_inputs, input_indices, anf_nodes_map_) != RET_OK) { + MS_LOG(ERROR) << "BuildOpInputs failed."; + return RET_ERROR; + } + auto new_cnode = anf_graph->NewCNode(primitive_c, op_inputs); + if (new_cnode == nullptr) { + MS_LOG(ERROR) << "new cnode error"; + return RET_ERROR; + } + new_cnode->set_fullname_with_scope(std::string(torch_node->kind().toUnqualString()) + "_" + + torch_node->output(0)->debugName()); + if (BuildOpOutputs(torch_node, anf_graph, &anf_nodes_map_, new_cnode) != RET_OK) { + MS_LOG(ERROR) << "BuildOpOutputs failed."; + return RET_ERROR; + } + } + return status; +} + +REG_MODEL_PARSER(kFmkTypePytorch, LiteModelParserCreator) +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.h b/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.h new file mode 100644 index 00000000000..c5352a79ddb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.h @@ -0,0 +1,55 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H + +#include +#include +#include +#include +#include "torch/script.h" +#include "securec/include/securec.h" +#include "include/registry/model_parser.h" +#include "include/registry/model_parser_registry.h" +#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +class PytorchModelParser : public converter::ModelParser { + public: + PytorchModelParser() = default; + + ~PytorchModelParser() override = default; + + api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; + + private: + STATUS InitOriginModel(const std::string &model_file); + STATUS ConvertTorchGraph(const FuncGraphPtr &anf_graph); + STATUS ConvertGraphInputs(const FuncGraphPtr &anf_graph); + STATUS ConvertGraphOutputs(const FuncGraphPtr &anf_graph); + STATUS ConvertNodes(const FuncGraphPtr &anf_graph); + + std::shared_ptr torch_model_; + std::vector all_subgraphs_{}; + std::unordered_map anf_nodes_map_{}; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.cc b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.cc new file mode 100644 index 00000000000..ff4b0a2650a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/pytorch/pytorch_node_parser.h" +#include + +namespace mindspore { +namespace lite { +namespace { +static std::unordered_map kTorchDataTypeTransferMap = { + {at::ScalarType::Bool, kNumberTypeBool}, {at::ScalarType::Byte, kNumberTypeUInt8}, + {at::ScalarType::Char, kNumberTypeInt8}, {at::ScalarType::Int, kNumberTypeInt}, + {at::ScalarType::Long, kNumberTypeInt}, {at::ScalarType::Half, kNumberTypeFloat16}, + {at::ScalarType::Float, kNumberTypeFloat32}, {at::ScalarType::Double, kNumberTypeFloat32}}; +} // namespace + +std::string PytorchNodeParser::GetTorchNodeType(const torch::jit::Node *torch_node) { + const auto &kind = torch_node->kind(); + std::string node_type = kind.toUnqualString(); + if (node_type.empty()) { + return node_type; + } + node_type = node_type.at(0) == '_' ? node_type.substr(1) : node_type; + node_type = node_type.at(node_type.size() - 1) == '_' ? node_type.substr(0, node_type.size() - 1) : node_type; + return node_type; +} + +TypeId PytorchNodeParser::GetDataTypeFromTorch(const at::ScalarType torch_data_type) { + auto iter = kTorchDataTypeTransferMap.find(torch_data_type); + if (iter == kTorchDataTypeTransferMap.end()) { + MS_LOG(ERROR) << "Unsupported torch data type: " << torch_data_type; + return kTypeUnknown; + } + return iter->second; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.h b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.h new file mode 100644 index 00000000000..672e5ffa254 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.h @@ -0,0 +1,66 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H + +#include +#include +#include +#include +#include "torch/script.h" +#include "include/errorcode.h" +#include "ir/dtype/type_id.h" +#include "ops/primitive_c.h" +#include "ops/op_name.h" +#include "src/common/log_adapter.h" +#include "tools/common/tensor_util.h" +#include "tools/converter/parser/parser_utils.h" +#include "nnacl/op_base.h" + +namespace mindspore { +namespace lite { +class PytorchNodeParser { + public: + explicit PytorchNodeParser(std::string node_name) : name_(std::move(node_name)) {} + + virtual ~PytorchNodeParser() = default; + + virtual PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector *input_indices) { + return nullptr; + } + + static std::string GetTorchNodeType(const torch::jit::Node *torch_node); + + static TypeId GetDataTypeFromTorch(const at::ScalarType torch_data_type); + + template + static T GetValueFromConstNode(const torch::jit::Value *value_node) { + T data{}; + auto ivalue = torch::jit::toIValue(value_node); + MS_CHECK_TRUE_RET(ivalue.has_value(), data); + auto value = ivalue.value(); + auto optional_value = value.toOptional(); + MS_CHECK_TRUE_RET(optional_value, data); + return optional_value.value(); + } + + protected: + const std::string name_{}; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser_registry.cc new file mode 100644 index 00000000000..127d81bfd0c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser_registry.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h" +#include + +namespace mindspore { +namespace lite { +PytorchNodeParserRegistry::PytorchNodeParserRegistry() = default; + +PytorchNodeParserRegistry::~PytorchNodeParserRegistry() { + for (auto ite : parsers) { + if (ite.second != nullptr) { + delete ite.second; + ite.second = nullptr; + } + } +} + +PytorchNodeParserRegistry &PytorchNodeParserRegistry::GetInstance() { + static PytorchNodeParserRegistry instance; + return instance; +} + +PytorchNodeParser *PytorchNodeParserRegistry::GetNodeParser(const std::string &name) const { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + return nullptr; +} + +void PytorchNodeParserRegistry::RegNodeParser(const std::string &name, PytorchNodeParser *parser) { + if (parser == nullptr) { + MS_LOG(WARNING) << "Input PytorchNodeParser is nullptr"; + return; + } + if (this->parsers.find(name) != this->parsers.end()) { + MS_LOG(WARNING) << "PytorchNodeParser " << name << " is already exist"; + return; + } + this->parsers[name] = parser; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser_registry.h b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser_registry.h new file mode 100644 index 00000000000..98658215147 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser_registry.h @@ -0,0 +1,53 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H + +#include +#include +#include "tools/converter/parser/pytorch/pytorch_node_parser.h" + +namespace mindspore { +namespace lite { +class PytorchNodeParserRegistry { + public: + virtual ~PytorchNodeParserRegistry(); + + static PytorchNodeParserRegistry &GetInstance(); + + PytorchNodeParser *GetNodeParser(const std::string &name) const; + + void RegNodeParser(const std::string &name, PytorchNodeParser *parser); + + private: + PytorchNodeParserRegistry(); + + private: + std::unordered_map parsers{}; +}; + +class PytorchNodeRegistrar { + public: + PytorchNodeRegistrar(const std::string &name, PytorchNodeParser *parser) { + PytorchNodeParserRegistry::GetInstance().RegNodeParser(name, parser); + } + ~PytorchNodeRegistrar() = default; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc b/mindspore/lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc new file mode 100644 index 00000000000..f6cb96e8fcf --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/pytorch/pytorch_permute_parser.h" +#include +#include +#include "ops/transpose.h" +#include "nnacl/op_base.h" + +namespace mindspore { +namespace lite { +PrimitiveCPtr PytorchPermuteParser::Parse(const torch::jit::Node *torch_node, std::vector *input_indices) { + MS_ASSERT(torch_node != nullptr && input_indices != nullptr); + auto prim = std::make_unique(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + return prim->GetPrim(); +} + +PytorchNodeRegistrar g_pytorchPermuteParser("permute", new PytorchPermuteParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/pytorch/pytorch_permute_parser.h b/mindspore/lite/tools/converter/parser/pytorch/pytorch_permute_parser.h new file mode 100644 index 00000000000..c59c764016d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/pytorch/pytorch_permute_parser.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H + +#include +#include "tools/converter/parser/pytorch/pytorch_node_parser.h" +#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class PytorchPermuteParser : public PytorchNodeParser { + public: + PytorchPermuteParser() : PytorchNodeParser("Permute") {} + ~PytorchPermuteParser() override = default; + + PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector *input_indices) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc index a2c697acf4e..de5d207fd97 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc @@ -121,7 +121,7 @@ STATUS TfliteInputsAdjust::ReplaceInt64ParameterNode(const FuncGraphPtr &func_gr MS_LOG(ERROR) << "default data is not tensor::Tensor."; return lite::RET_NULL_PTR; } - auto param_node_new = opt::BuildParameterNode(func_graph, param_node, tensor_info); + auto param_node_new = opt::BuildParameterNode(func_graph, tensor_info, param_node->fullname_with_scope()); if (!manager->Replace(param_node, param_node_new)) { MS_LOG(ERROR) << "Replace param node failed."; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/unify_format.cc b/mindspore/lite/tools/converter/parser/unify_format.cc index abac618213f..0bc2a90363a 100644 --- a/mindspore/lite/tools/converter/parser/unify_format.cc +++ b/mindspore/lite/tools/converter/parser/unify_format.cc @@ -208,7 +208,8 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, {converter::kFmkTypeTf, DecideTFConvWeightSrcFormat}, {converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat}, {converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat}, - {converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}}; + {converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}, + {converter::kFmkTypePytorch, DecideONNXConvWeightSrcFormat}}; auto iter = decide_functions.find(fmk_type_); if (iter == decide_functions.end()) { MS_LOG(ERROR) << "current fmk don't support, please check."; diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc index cff1580bf51..bbdafd9689e 100644 --- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc +++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc @@ -26,7 +26,7 @@ std::map model_parser_room; } // namespace ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { - if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) { + if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; return; } @@ -34,7 +34,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator } converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) { - if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) { + if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; return nullptr; } diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index fc46af6ce12..89c00aee015 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -728,9 +728,9 @@ STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_for return iter->second(tensor, src_format, dst_format); } -ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const tensor::TensorPtr &tensor_info) { - if (func_graph == nullptr || node == nullptr || tensor_info == nullptr) { +ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info, + const std::string &node_name) { + if (func_graph == nullptr || tensor_info == nullptr) { MS_LOG(ERROR) << "input parameter is nullptr."; return nullptr; } @@ -746,7 +746,7 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr } else if (tensor_info->data_type() == kNumberTypeFloat64) { data_type = kNumberTypeFloat32; } - param_node->set_name(node->fullname_with_scope()); + param_node->set_name(node_name); auto tensor_info_new = std::make_shared(data_type, shape_vector); if (tensor_info_new == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed."; diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 50487f13e55..122c1519036 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -98,8 +98,8 @@ AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index); STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format); -ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const tensor::TensorPtr &tensor_info); +ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info, + const std::string &node_name); ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data, const std::string &node_name);