forked from mindspore-Ecosystem/mindspore
!31199 [MSLITE] Support to convert pytorch model.
Merge pull request !31199 from wangshaocong/torch_converter
This commit is contained in:
commit
b1023addba
|
@ -46,6 +46,7 @@
|
|||
"mindspore/mindspore/lite/tools/common/flag_parser.cc" "useStlAlgorithm"
|
||||
"mindspore/mindspore/lite/tools/common/tensor_util.cc" "useStlAlgorithm"
|
||||
"mindspore/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc" "useStlAlgorithm"
|
||||
"mindspore/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.cc" "variableScope"
|
||||
"mindspore/mindspore/lite/tools/converter/quantizer/quantize_util.cc" "useStlAlgorithm"
|
||||
"mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
|
||||
"mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"
|
||||
|
|
|
@ -525,7 +525,16 @@ if(PLATFORM_ARM64)
|
|||
install(FILES ${opencv_LIBPATH}/libopencv_imgproc.so.4.5.2
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libopencv_imgproc.so.4.5
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
||||
if(ENABLE_CONVERT_PYTORCH_MODEL AND NOT WIN32)
|
||||
install(FILES ${LIB_TORCH_PATH}/lib/libtorch.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${LIB_TORCH_PATH}/lib/libtorch_cpu.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${LIB_TORCH_PATH}/lib/libc10.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${LIB_TORCH_PATH}/lib/libgomp.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_ENABLE_ACL)
|
||||
set(LITE_ACL_DIR ${TOP_DIR}/mindspore/lite/build/tools/converter/adapter/acl)
|
||||
install(FILES ${LITE_ACL_DIR}/mindspore_shared_lib/libmindspore_shared_lib.so
|
||||
|
|
|
@ -55,6 +55,7 @@ option(MSLITE_ENABLE_PARALLEL_INFERENCE "enable parallel inference interface" of
|
|||
option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off)
|
||||
option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on)
|
||||
option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off)
|
||||
option(MSLITE_ENABLE_CONVERT_PYTORCH_MODEL "enable to convert pytorch model" off)
|
||||
|
||||
#Option that can be configured through manually
|
||||
option(ENABLE_VERBOSE "" off)
|
||||
|
@ -175,6 +176,11 @@ if(DEFINED ENV{MSLITE_ENABLE_SERVING})
|
|||
set(MSLITE_ENABLE_SERVING $ENV{MSLITE_ENABLE_SERVING})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL} AND DEFINED ENV{LIB_TORCH_PATH})
|
||||
set(ENABLE_CONVERT_PYTORCH_MODEL $ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL})
|
||||
set(LIB_TORCH_PATH $ENV{LIB_TORCH_PATH})
|
||||
endif()
|
||||
|
||||
if(MACHINE_LINUX_ARM64)
|
||||
add_compile_definitions(MACHINE_LINUX_ARM64)
|
||||
add_compile_definitions(LINUX_RUNTIME)
|
||||
|
|
|
@ -32,6 +32,7 @@ enum MS_API FmkType : int {
|
|||
kFmkTypeOnnx = 2,
|
||||
kFmkTypeMs = 3,
|
||||
kFmkTypeTflite = 4,
|
||||
kFmkTypePytorch = 5,
|
||||
};
|
||||
|
||||
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
|
||||
|
|
|
@ -79,6 +79,10 @@ add_subdirectory(parser/caffe)
|
|||
add_subdirectory(parser/tflite)
|
||||
add_subdirectory(parser/onnx)
|
||||
add_subdirectory(parser/tf)
|
||||
if(ENABLE_CONVERT_PYTORCH_MODEL AND NOT WIN32)
|
||||
add_subdirectory(parser/pytorch)
|
||||
endif()
|
||||
|
||||
add_subdirectory(legacy_optimizer)
|
||||
add_subdirectory(quantizer)
|
||||
add_subdirectory(registry)
|
||||
|
@ -317,6 +321,10 @@ if(SUPPORT_TRAIN)
|
|||
target_link_libraries(converter_lite PRIVATE train_cpu_kernel_mid)
|
||||
endif()
|
||||
|
||||
if(ENABLE_CONVERT_PYTORCH_MODEL)
|
||||
target_link_libraries(converter_lite PRIVATE pytorch_parser_mid)
|
||||
endif()
|
||||
|
||||
if(NOT ENABLE_CLOUD_AND_LITE)
|
||||
target_link_libraries(converter_lite PRIVATE
|
||||
ccsrc_debug_common_mid_
|
||||
|
|
|
@ -69,6 +69,7 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
|||
} else {
|
||||
model_parser_ = registry::ModelParserRegistry::GetModelParser(flag.fmk);
|
||||
if (model_parser_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Unsupported to converter models with fmk: " << flag.fmkIn;
|
||||
return nullptr;
|
||||
}
|
||||
converter::ConverterParameters converter_parameters;
|
||||
|
|
|
@ -147,6 +147,8 @@ int Flags::InitFmk() {
|
|||
this->fmk = kFmkTypeOnnx;
|
||||
} else if (this->fmkIn == "TF") {
|
||||
this->fmk = kFmkTypeTf;
|
||||
} else if (this->fmkIn == "PYTORCH") {
|
||||
this->fmk = kFmkTypePytorch;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
|
|
|
@ -105,7 +105,7 @@ STATUS ReplaceTypeParameterNode(const FuncGraphPtr &func_graph, const ParameterP
|
|||
MS_LOG(ERROR) << "default data is not tensor::Tensor.";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
auto param_node_new = opt::BuildParameterNode(func_graph, param_node, tensor_info);
|
||||
auto param_node_new = opt::BuildParameterNode(func_graph, tensor_info, param_node->fullname_with_scope());
|
||||
if (param_node_new == nullptr) {
|
||||
MS_LOG(ERROR) << "BuildParameterNode failed.";
|
||||
return lite::RET_NULL_PTR;
|
||||
|
@ -166,7 +166,7 @@ STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
MS_LOG(ERROR) << "valueptr is not tensor::Tensorptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto param_node = opt::BuildParameterNode(func_graph, cnode, tensor_info_ptr);
|
||||
auto param_node = opt::BuildParameterNode(func_graph, tensor_info_ptr, cnode->fullname_with_scope());
|
||||
if (param_node == nullptr) {
|
||||
MS_LOG(ERROR) << "convert constant to param node failed.";
|
||||
return lite::RET_ERROR;
|
||||
|
@ -413,7 +413,7 @@ STATUS AdjustRandomNormal(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
|
|||
auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type);
|
||||
free(data);
|
||||
MS_CHECK_TRUE_RET(tensor_info != nullptr, RET_ERROR);
|
||||
auto parameter = opt::BuildParameterNode(func_graph, cnode, tensor_info);
|
||||
auto parameter = opt::BuildParameterNode(func_graph, tensor_info, cnode->fullname_with_scope());
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "BuildParameterNode failed.";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
include(${TOP_DIR}/mindspore/lite/cmake/merge.cmake)
|
||||
merge_parser(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_BINARY_DIR}/tools/converter/parser/pytorch/pytorch_op_parser.cc)
|
||||
file(GLOB_RECURSE PYTORCH_SRC_LIST ${CMAKE_BINARY_DIR}/tools/converter/parser/pytorch/pytorch_op_parser.cc)
|
||||
set_property(SOURCE ${PYTORCH_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(pytorch_parser_mid OBJECT
|
||||
${PYTORCH_SRC_LIST}
|
||||
)
|
||||
|
||||
add_compile_definitions(C10_USE_GLOG)
|
||||
if(NOT EXISTS ${LIB_TORCH_PATH})
|
||||
message(FATAL_ERROR "Path of libtorch is invalid.")
|
||||
endif()
|
||||
find_package(Torch REQUIRED PATHS ${LIB_TORCH_PATH})
|
||||
if(TORCH_FOUND)
|
||||
target_link_libraries(pytorch_parser_mid PRIVATE ${TORCH_LIBRARIES})
|
||||
target_include_directories(pytorch_parser_mid PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
else()
|
||||
message(FATAL_ERROR "Torch is not found")
|
||||
endif()
|
||||
|
||||
add_dependencies(pytorch_parser_mid fbs_src)
|
||||
add_dependencies(pytorch_parser_mid fbs_inner_src)
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/pytorch/pytorch_conv_parser.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
PrimitiveCPtr PytorchConvParser::Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) {
|
||||
MS_ASSERT(torch_node != nullptr && input_indices != nullptr);
|
||||
auto prim = std::make_unique<ops::Conv2DFusion>();
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
auto bias = torch_node->input(kBiasIndex);
|
||||
MS_CHECK_TRUE_RET(bias != nullptr, nullptr);
|
||||
// the bias is noneType
|
||||
size_t input_size = bias->isCompleteTensor() ? kInputSize2 : kInputSize1;
|
||||
input_indices->resize(input_size);
|
||||
std::iota(input_indices->begin(), input_indices->end(), 0);
|
||||
|
||||
auto stride = PytorchNodeParser::GetValueFromConstNode<std::vector<int64_t>>(torch_node->input(FOURTH_INPUT));
|
||||
prim->set_stride(stride);
|
||||
auto dilation = PytorchNodeParser::GetValueFromConstNode<std::vector<int64_t>>(torch_node->input(SIXTH_INPUT));
|
||||
prim->set_dilation(dilation);
|
||||
|
||||
auto padding = PytorchNodeParser::GetValueFromConstNode<std::vector<int64_t>>(torch_node->input(FIFTH_INPUT));
|
||||
if (padding.size() == DIMENSION_2D) {
|
||||
padding.push_back(padding.at(1));
|
||||
padding.insert(padding.begin(), padding.at(0));
|
||||
}
|
||||
prim->set_pad_list(padding);
|
||||
|
||||
auto group = PytorchNodeParser::GetValueFromConstNode<int64_t>(torch_node->input(8));
|
||||
prim->set_group(group);
|
||||
|
||||
prim->set_pad({0, 0, 0, 0});
|
||||
mindspore::PadMode pad_mode = mindspore::PadMode::PAD;
|
||||
prim->set_pad_mode(pad_mode);
|
||||
auto prim_c = prim->GetPrim();
|
||||
MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr);
|
||||
mindspore::Format format = mindspore::Format::NCHW;
|
||||
prim_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(format));
|
||||
bool conv1d = stride.size() == 1;
|
||||
if (conv1d) {
|
||||
prim_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCW));
|
||||
}
|
||||
|
||||
// parse activationType
|
||||
prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);
|
||||
|
||||
return prim->GetPrim();
|
||||
}
|
||||
|
||||
PytorchNodeRegistrar g_pytorchConvParser("conv2d", new PytorchConvParser());
|
||||
PytorchNodeRegistrar g_pytorchConvolutionParser("convolution", new PytorchConvParser());
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PytorchConvParser : public PytorchNodeParser {
|
||||
public:
|
||||
PytorchConvParser() : PytorchNodeParser("Conv") {}
|
||||
~PytorchConvParser() override = default;
|
||||
|
||||
PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H
|
|
@ -0,0 +1,508 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/pytorch/pytorch_model_parser.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "torch/csrc/jit/passes/freeze_module.h"
|
||||
#include "torch/csrc/jit/passes/inliner.h"
|
||||
#include "torch/csrc/jit/passes/normalize_ops.h"
|
||||
#include "include/registry/node_parser_registry.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
#include "tools/converter/parser/lite_model_parser_creator.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/make_tuple.h"
|
||||
#include "ops/return.h"
|
||||
#include "ops/tuple_get_item.h"
|
||||
|
||||
using mindspore::converter::kFmkTypePytorch;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
api::FuncGraphPtr PytorchModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file;
|
||||
NotSupportOp::GetInstance()->set_fmk_type("PYTORCH");
|
||||
auto anf_graph = std::make_shared<FuncGraph>();
|
||||
MS_CHECK_TRUE_MSG(anf_graph != nullptr, nullptr, "create FuncGraph failed");
|
||||
res_graph_ = api::MakeShared<api::FuncGraph>(anf_graph);
|
||||
MS_CHECK_TRUE_MSG(res_graph_ != nullptr, nullptr, "create FuncGraph failed");
|
||||
auto status = InitOriginModel(model_file);
|
||||
if (RET_OK != status) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
MS_LOG(ERROR) << "init origin model failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ConvertTorchGraph(anf_graph);
|
||||
if (RET_OK != status) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
MS_LOG(ERROR) << "convert pytorch graph failed.";
|
||||
return nullptr;
|
||||
}
|
||||
static auto root_func_manager = Manage(anf_graph);
|
||||
MS_ASSERT(root_func_manager != nullptr);
|
||||
for (auto &subgraph : all_subgraphs_) {
|
||||
MS_ASSERT(subgraph != nullptr);
|
||||
subgraph->set_manager(root_func_manager);
|
||||
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypePytorch)));
|
||||
}
|
||||
anf_graph->set_attr("graph_name", MakeValue("main_graph"));
|
||||
anf_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypePytorch)));
|
||||
if ((status = CommonAnfAdjust(anf_graph)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypePytorch, false);
|
||||
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
|
||||
if (!unify_format->Run(anf_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return res_graph_;
|
||||
}
|
||||
STATUS PytorchModelParser::InitOriginModel(const std::string &model_file) {
|
||||
if (ValidateFileStr(model_file, ".pt") != RET_OK && ValidateFileStr(model_file, ".pth") != RET_OK) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pt or *.pth";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::string model_path = RealPath(model_file.c_str());
|
||||
if (model_path.empty()) {
|
||||
MS_LOG(ERROR) << "Binary proto file path " << model_file << " is not valid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// only linux supports to convert pytorch model.
|
||||
if (access(model_path.c_str(), F_OK) != 0 || access(model_path.c_str(), R_OK) != 0) {
|
||||
MS_LOG(ERROR) << "The pytorch model file is not exist or can't be read.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto module = torch::jit::load(model_path);
|
||||
module.eval(); // eval to expand function call
|
||||
module = torch::jit::freeze_module(module); // freeze module
|
||||
torch_model_ = module.get_method("forward").graph();
|
||||
CHECK_NULL_RETURN(torch_model_);
|
||||
// parse submodules in graph
|
||||
torch::jit::Inline(*torch_model_);
|
||||
torch::jit::NormalizeOps(torch_model_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PytorchModelParser::ConvertTorchGraph(const FuncGraphPtr &anf_graph) {
|
||||
MS_ASSERT(torch_graph != nullptr && anf_graph != nullptr && anf_nodes_map != nullptr &&
|
||||
extra_subgraph_inputs != nullptr);
|
||||
STATUS status = ConvertGraphInputs(anf_graph);
|
||||
if (RET_OK != status) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
MS_LOG(ERROR) << "convert graph inputs failed.";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
status = ConvertNodes(anf_graph);
|
||||
if (RET_OK != status) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
MS_LOG(ERROR) << "convert nodes failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
status = ConvertGraphOutputs(anf_graph);
|
||||
if (RET_OK != status) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
MS_LOG(ERROR) << "convert graph outputs failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
STATUS PytorchModelParser::ConvertGraphInputs(const FuncGraphPtr &anf_graph) {
|
||||
MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr);
|
||||
for (auto &input : torch_model_->inputs()) {
|
||||
auto input_name = input->debugName();
|
||||
if (anf_nodes_map_.find(input_name) != anf_nodes_map_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto type = input->type();
|
||||
MS_CHECK_TRUE_RET(type != nullptr, RET_ERROR);
|
||||
auto tensor_type = type->cast<at::TensorType>();
|
||||
if (tensor_type == nullptr) {
|
||||
MS_LOG(DEBUG) << "The input is not a tensor, but a: " << c10::typeKindToString(type->kind());
|
||||
continue;
|
||||
}
|
||||
auto scalar_type = tensor_type->scalarType().value_or(at::ScalarType::Float);
|
||||
auto data_type = PytorchNodeParser::GetDataTypeFromTorch(scalar_type);
|
||||
if (data_type == kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "not support pytorch data type " << scalar_type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int64_t> input_shape = ConverterInnerContext::GetInstance()->GetGraphInputTensorShape(input_name);
|
||||
if (input_shape.empty()) {
|
||||
if (tensor_type->sizes().isComplete()) {
|
||||
input_shape = tensor_type->sizes().concrete_sizes().value();
|
||||
} else {
|
||||
MS_LOG(WARNING) << "The input shape is empty.";
|
||||
}
|
||||
}
|
||||
auto abstract_tensor = CreateTensorAbstract(input_shape, data_type);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto parameter = anf_graph->add_parameter();
|
||||
MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
|
||||
parameter->set_abstract(abstract_tensor);
|
||||
parameter->set_name(input_name);
|
||||
anf_nodes_map_.emplace(input_name, parameter);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
|
||||
MS_ASSERT(anf_graph != nullptr);
|
||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||
if (return_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "new Return failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto return_prim = return_prim_ptr->GetPrim();
|
||||
MS_CHECK_TRUE_RET(return_prim != nullptr, RET_ERROR);
|
||||
auto return_cnode = anf_graph->NewCNode(return_prim, return_inputs);
|
||||
if (return_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "new cnode error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return_cnode->set_fullname_with_scope("Return");
|
||||
anf_graph->set_return(return_cnode);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PytorchModelParser::ConvertGraphOutputs(const FuncGraphPtr &anf_graph) {
|
||||
MS_ASSERT(anf_graph != nullptr);
|
||||
std::vector<AnfNodePtr> return_inputs;
|
||||
if (torch_model_->outputs().size() == 0) {
|
||||
MS_LOG(ERROR) << "pytorch graph has no output";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (torch_model_->outputs().size() > 1) {
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||
if (make_tuple_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
for (const auto &output : torch_model_->outputs()) {
|
||||
auto output_name = output->debugName();
|
||||
if (anf_nodes_map_.find(output_name) == anf_nodes_map_.end()) {
|
||||
MS_LOG(ERROR) << "graph output get failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto cnode = anf_nodes_map_.at(output_name);
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
make_tuple_inputs.emplace_back(cnode);
|
||||
}
|
||||
auto make_tuple_prim = make_tuple_prim_ptr->GetPrim();
|
||||
MS_CHECK_TRUE_RET(make_tuple_prim != nullptr, RET_ERROR);
|
||||
auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim, make_tuple_inputs);
|
||||
if (make_tuple_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "new cnode error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||
return_inputs.emplace_back(make_tuple_cnode);
|
||||
} else {
|
||||
const auto &output = torch_model_->outputs().front();
|
||||
if (anf_nodes_map_.find(output->debugName()) == anf_nodes_map_.end()) {
|
||||
MS_LOG(ERROR) << "graph output get failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto cnode = anf_nodes_map_.at(output->debugName());
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
return_inputs.emplace_back(cnode);
|
||||
}
|
||||
if (BuildReturnNode(anf_graph, return_inputs) != RET_OK) {
|
||||
MS_LOG(ERROR) << "build return node failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS CopyDataFromTorchTensor(char *dst_data, const at::Tensor &torch_tensor, TypeId data_type) {
|
||||
auto ele_size = abstract::TypeIdSize(data_type);
|
||||
MS_CHECK_TRUE_RET(ele_size > 0, RET_ERROR);
|
||||
auto data_shape = torch_tensor.sizes().vec();
|
||||
auto stride = torch_tensor.strides().vec();
|
||||
if (data_shape.empty()) {
|
||||
auto data_size = torch_tensor.numel() * ele_size;
|
||||
data_shape.push_back(data_size);
|
||||
stride.push_back(1);
|
||||
}
|
||||
char *data_ptr = reinterpret_cast<char *>(torch_tensor.data_ptr());
|
||||
if (data_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "The tensor data is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t idx = 0;
|
||||
std::function<void(size_t, size_t)> copy_data = [&](size_t dim, size_t offset) {
|
||||
if (dim == data_shape.size() - 1) {
|
||||
for (int i = 0; i < data_shape[dim]; i++) {
|
||||
auto src_ptr = data_ptr + offset + i * stride[dim] * ele_size;
|
||||
auto dst_ptr = dst_data + (idx++) * ele_size;
|
||||
memcpy(dst_ptr, src_ptr, ele_size);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < data_shape[dim]; i++) {
|
||||
copy_data(dim + 1, offset + i * stride[dim] * ele_size);
|
||||
}
|
||||
}
|
||||
};
|
||||
copy_data(0, 0);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS ConvertConstNode(const torch::jit::Node *torch_node, const FuncGraphPtr &anf_graph,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
|
||||
ParameterPtr parameter = nullptr;
|
||||
auto output = torch_node->output();
|
||||
auto type_kind = output->type()->kind();
|
||||
auto value = torch::jit::toIValue(output);
|
||||
MS_CHECK_TRUE_RET(value.has_value(), RET_ERROR);
|
||||
switch (type_kind) {
|
||||
case c10::TypeKind::BoolType: {
|
||||
auto data = static_cast<int>(value.value().toBool());
|
||||
parameter = opt::BuildIntValueParameterNode(anf_graph, data, output->debugName());
|
||||
} break;
|
||||
case c10::TypeKind::IntType: {
|
||||
auto data = static_cast<int>(value.value().toInt());
|
||||
parameter = opt::BuildIntValueParameterNode(anf_graph, data, output->debugName());
|
||||
} break;
|
||||
case c10::TypeKind::FloatType: {
|
||||
auto data = static_cast<float>(value.value().toDouble());
|
||||
parameter = opt::BuildFloatValueParameterNode(anf_graph, data, output->debugName());
|
||||
} break;
|
||||
case c10::TypeKind::ListType: {
|
||||
auto element_type = value->toList().elementType()->kind();
|
||||
switch (element_type) {
|
||||
case c10::TypeKind::IntType: {
|
||||
auto ori_data = value.value().toIntVector();
|
||||
std::vector<int> data;
|
||||
std::transform(ori_data.begin(), ori_data.end(), std::back_inserter(data),
|
||||
[](int64_t ele) { return static_cast<int>(ele); });
|
||||
parameter = opt::BuildIntVecParameterNode(anf_graph, data, output->debugName());
|
||||
} break;
|
||||
case c10::TypeKind::FloatType: {
|
||||
auto ori_data = value.value().toDoubleVector();
|
||||
std::vector<float> data;
|
||||
std::transform(ori_data.begin(), ori_data.end(), std::back_inserter(data),
|
||||
[](double ele) { return static_cast<float>(ele); });
|
||||
parameter = opt::BuildFloatVecParameterNode(anf_graph, data, output->debugName());
|
||||
} break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported data type: " << c10::typeKindToString(element_type);
|
||||
return RET_ERROR;
|
||||
}
|
||||
} break;
|
||||
case c10::TypeKind::TensorType: {
|
||||
auto torch_tensor = value.value().toTensor();
|
||||
auto data_type = PytorchNodeParser::GetDataTypeFromTorch(torch_tensor.scalar_type());
|
||||
auto data_size = torch_tensor.numel() * abstract::TypeIdSize(data_type);
|
||||
char *data_ptr = reinterpret_cast<char *>(malloc(data_size));
|
||||
if (data_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc data failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (CopyDataFromTorchTensor(data_ptr, torch_tensor, data_type) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Copy data from torch tensor failed.";
|
||||
free(data_ptr);
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto data_shape = torch_tensor.sizes().vec();
|
||||
auto tensor_info = CreateTensorInfo(data_ptr, data_size, data_shape, data_type);
|
||||
free(data_ptr);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensorInfo failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
parameter = opt::BuildParameterNode(anf_graph, tensor_info, output->debugName());
|
||||
} break;
|
||||
case c10::TypeKind::NoneType:
|
||||
MS_LOG(DEBUG) << "The const node is none.";
|
||||
return RET_OK;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported data type: " << c10::typeKindToString(type_kind);
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "The parameter is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
anf_nodes_map->emplace(output->debugName(), parameter);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS BuildOpInputs(const torch::jit::Node *torch_node, std::vector<AnfNodePtr> *op_inputs,
|
||||
const std::vector<size_t> &input_indices,
|
||||
const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map) {
|
||||
MS_ASSERT(torch_node != nullptr && op_inputs != nullptr);
|
||||
for (size_t idx : input_indices) {
|
||||
auto input = torch_node->input(idx);
|
||||
MS_CHECK_TRUE_RET(input != nullptr, RET_ERROR);
|
||||
auto input_name = input->debugName();
|
||||
if (input_name.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (anf_nodes_map.find(input_name) != anf_nodes_map.end()) {
|
||||
op_inputs->push_back(anf_nodes_map.at(input_name));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "could not find input node: " << input_name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS BuildOpOutputs(const torch::jit::Node *torch_node, const FuncGraphPtr &anf_graph,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode) {
|
||||
MS_ASSERT(torch_node != nullptr && anf_graph != nullptr && cnode != nullptr && anf_nodes_map != nullptr);
|
||||
if (torch_node->outputs().size() == 1) {
|
||||
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->set_abstract(abstract_tensor);
|
||||
anf_nodes_map->emplace(torch_node->output()->debugName(), cnode);
|
||||
} else {
|
||||
AbstractBasePtrList abstract_list;
|
||||
int op_idx = 0;
|
||||
for (const auto &output : torch_node->outputs()) {
|
||||
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
abstract_list.emplace_back(abstract_tensor);
|
||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||
if (tuple_get_item_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto tuple_get_item_prim = tuple_get_item_prim_ptr->GetPrim();
|
||||
MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, RET_NULL_PTR, "get prim return nullptr");
|
||||
auto tuple_get_item = NewValueNode(tuple_get_item_prim);
|
||||
MS_CHECK_TRUE_MSG(tuple_get_item != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
|
||||
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
|
||||
MS_CHECK_TRUE_MSG(get_item_value != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
|
||||
std::vector<AnfNodePtr> inputs{tuple_get_item, cnode, get_item_value};
|
||||
CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
|
||||
if (get_item_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "new cnode error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
|
||||
auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||
if (get_item_abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
get_item_cnode->set_abstract(get_item_abstract);
|
||||
anf_nodes_map->emplace(output->debugName(), get_item_cnode);
|
||||
op_idx++;
|
||||
}
|
||||
auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
CHECK_NULL_RETURN(new_abstract_list);
|
||||
cnode->set_abstract(new_abstract_list);
|
||||
}
|
||||
anf_nodes_map->emplace(torch_node->kind().toUnqualString(), cnode);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PytorchModelParser::ConvertNodes(const FuncGraphPtr &anf_graph) {
|
||||
MS_ASSERT(anf_graph != nullptr);
|
||||
STATUS status = RET_OK;
|
||||
for (const auto &torch_node : torch_model_->nodes()) {
|
||||
ops::PrimitiveCPtr primitive_c = nullptr;
|
||||
auto node_type = PytorchNodeParser::GetTorchNodeType(torch_node);
|
||||
MS_CHECK_TRUE_RET(!node_type.empty(), RET_ERROR);
|
||||
// convert constant node.
|
||||
if (node_type == "Constant") {
|
||||
if (ConvertConstNode(torch_node, anf_graph, &anf_nodes_map_) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert constant node failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
std::vector<size_t> input_indices;
|
||||
auto node_parser_builtin = PytorchNodeParserRegistry::GetInstance().GetNodeParser(node_type);
|
||||
if (node_parser_builtin == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(node_type);
|
||||
status = status == RET_OK ? RET_NOT_FIND_OP : status;
|
||||
MS_LOG(ERROR) << "not support pytorch op type " << node_type;
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "parse op:" << node_type;
|
||||
primitive_c = node_parser_builtin->Parse(torch_node, &input_indices);
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << node_type << " failed.";
|
||||
status = RET_ERROR;
|
||||
continue;
|
||||
}
|
||||
// set default format and input indices.
|
||||
if (primitive_c->GetAttr(ops::kOriginalFormat) == nullptr) {
|
||||
primitive_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCHW));
|
||||
}
|
||||
if (input_indices.empty()) {
|
||||
input_indices.resize(torch_node->inputs().size());
|
||||
std::iota(input_indices.begin(), input_indices.end(), 0);
|
||||
}
|
||||
|
||||
if (BuildOpInputs(torch_node, &op_inputs, input_indices, anf_nodes_map_) != RET_OK) {
|
||||
MS_LOG(ERROR) << "BuildOpInputs failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto new_cnode = anf_graph->NewCNode(primitive_c, op_inputs);
|
||||
if (new_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "new cnode error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
new_cnode->set_fullname_with_scope(std::string(torch_node->kind().toUnqualString()) + "_" +
|
||||
torch_node->output(0)->debugName());
|
||||
if (BuildOpOutputs(torch_node, anf_graph, &anf_nodes_map_, new_cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "BuildOpOutputs failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(kFmkTypePytorch, LiteModelParserCreator<PytorchModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "torch/script.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PytorchModelParser : public converter::ModelParser {
|
||||
public:
|
||||
PytorchModelParser() = default;
|
||||
|
||||
~PytorchModelParser() override = default;
|
||||
|
||||
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
private:
|
||||
STATUS InitOriginModel(const std::string &model_file);
|
||||
STATUS ConvertTorchGraph(const FuncGraphPtr &anf_graph);
|
||||
STATUS ConvertGraphInputs(const FuncGraphPtr &anf_graph);
|
||||
STATUS ConvertGraphOutputs(const FuncGraphPtr &anf_graph);
|
||||
STATUS ConvertNodes(const FuncGraphPtr &anf_graph);
|
||||
|
||||
std::shared_ptr<torch::jit::Graph> torch_model_;
|
||||
std::vector<FuncGraphPtr> all_subgraphs_{};
|
||||
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_{};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
static std::unordered_map<at::ScalarType, TypeId> kTorchDataTypeTransferMap = {
|
||||
{at::ScalarType::Bool, kNumberTypeBool}, {at::ScalarType::Byte, kNumberTypeUInt8},
|
||||
{at::ScalarType::Char, kNumberTypeInt8}, {at::ScalarType::Int, kNumberTypeInt},
|
||||
{at::ScalarType::Long, kNumberTypeInt}, {at::ScalarType::Half, kNumberTypeFloat16},
|
||||
{at::ScalarType::Float, kNumberTypeFloat32}, {at::ScalarType::Double, kNumberTypeFloat32}};
|
||||
} // namespace
|
||||
|
||||
std::string PytorchNodeParser::GetTorchNodeType(const torch::jit::Node *torch_node) {
|
||||
const auto &kind = torch_node->kind();
|
||||
std::string node_type = kind.toUnqualString();
|
||||
if (node_type.empty()) {
|
||||
return node_type;
|
||||
}
|
||||
node_type = node_type.at(0) == '_' ? node_type.substr(1) : node_type;
|
||||
node_type = node_type.at(node_type.size() - 1) == '_' ? node_type.substr(0, node_type.size() - 1) : node_type;
|
||||
return node_type;
|
||||
}
|
||||
|
||||
TypeId PytorchNodeParser::GetDataTypeFromTorch(const at::ScalarType torch_data_type) {
|
||||
auto iter = kTorchDataTypeTransferMap.find(torch_data_type);
|
||||
if (iter == kTorchDataTypeTransferMap.end()) {
|
||||
MS_LOG(ERROR) << "Unsupported torch data type: " << torch_data_type;
|
||||
return kTypeUnknown;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include "torch/script.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_name.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PytorchNodeParser {
|
||||
public:
|
||||
explicit PytorchNodeParser(std::string node_name) : name_(std::move(node_name)) {}
|
||||
|
||||
virtual ~PytorchNodeParser() = default;
|
||||
|
||||
virtual PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static std::string GetTorchNodeType(const torch::jit::Node *torch_node);
|
||||
|
||||
static TypeId GetDataTypeFromTorch(const at::ScalarType torch_data_type);
|
||||
|
||||
template <typename T>
|
||||
static T GetValueFromConstNode(const torch::jit::Value *value_node) {
|
||||
T data{};
|
||||
auto ivalue = torch::jit::toIValue(value_node);
|
||||
MS_CHECK_TRUE_RET(ivalue.has_value(), data);
|
||||
auto value = ivalue.value();
|
||||
auto optional_value = value.toOptional<T>();
|
||||
MS_CHECK_TRUE_RET(optional_value, data);
|
||||
return optional_value.value();
|
||||
}
|
||||
|
||||
protected:
|
||||
const std::string name_{};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
PytorchNodeParserRegistry::PytorchNodeParserRegistry() = default;
|
||||
|
||||
PytorchNodeParserRegistry::~PytorchNodeParserRegistry() {
|
||||
for (auto ite : parsers) {
|
||||
if (ite.second != nullptr) {
|
||||
delete ite.second;
|
||||
ite.second = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PytorchNodeParserRegistry &PytorchNodeParserRegistry::GetInstance() {
|
||||
static PytorchNodeParserRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
PytorchNodeParser *PytorchNodeParserRegistry::GetNodeParser(const std::string &name) const {
|
||||
auto it = parsers.find(name);
|
||||
if (it != parsers.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void PytorchNodeParserRegistry::RegNodeParser(const std::string &name, PytorchNodeParser *parser) {
|
||||
if (parser == nullptr) {
|
||||
MS_LOG(WARNING) << "Input PytorchNodeParser is nullptr";
|
||||
return;
|
||||
}
|
||||
if (this->parsers.find(name) != this->parsers.end()) {
|
||||
MS_LOG(WARNING) << "PytorchNodeParser " << name << " is already exist";
|
||||
return;
|
||||
}
|
||||
this->parsers[name] = parser;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PytorchNodeParserRegistry {
|
||||
public:
|
||||
virtual ~PytorchNodeParserRegistry();
|
||||
|
||||
static PytorchNodeParserRegistry &GetInstance();
|
||||
|
||||
PytorchNodeParser *GetNodeParser(const std::string &name) const;
|
||||
|
||||
void RegNodeParser(const std::string &name, PytorchNodeParser *parser);
|
||||
|
||||
private:
|
||||
PytorchNodeParserRegistry();
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, PytorchNodeParser *> parsers{};
|
||||
};
|
||||
|
||||
class PytorchNodeRegistrar {
|
||||
public:
|
||||
PytorchNodeRegistrar(const std::string &name, PytorchNodeParser *parser) {
|
||||
PytorchNodeParserRegistry::GetInstance().RegNodeParser(name, parser);
|
||||
}
|
||||
~PytorchNodeRegistrar() = default;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/pytorch/pytorch_permute_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/transpose.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
PrimitiveCPtr PytorchPermuteParser::Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) {
|
||||
MS_ASSERT(torch_node != nullptr && input_indices != nullptr);
|
||||
auto prim = std::make_unique<ops::Transpose>();
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
return prim->GetPrim();
|
||||
}
|
||||
|
||||
PytorchNodeRegistrar g_pytorchPermuteParser("permute", new PytorchPermuteParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
|
||||
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PytorchPermuteParser : public PytorchNodeParser {
|
||||
public:
|
||||
PytorchPermuteParser() : PytorchNodeParser("Permute") {}
|
||||
~PytorchPermuteParser() override = default;
|
||||
|
||||
PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H
|
|
@ -121,7 +121,7 @@ STATUS TfliteInputsAdjust::ReplaceInt64ParameterNode(const FuncGraphPtr &func_gr
|
|||
MS_LOG(ERROR) << "default data is not tensor::Tensor.";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
auto param_node_new = opt::BuildParameterNode(func_graph, param_node, tensor_info);
|
||||
auto param_node_new = opt::BuildParameterNode(func_graph, tensor_info, param_node->fullname_with_scope());
|
||||
if (!manager->Replace(param_node, param_node_new)) {
|
||||
MS_LOG(ERROR) << "Replace param node failed.";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -208,7 +208,8 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
|
|||
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
|
||||
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
|
||||
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat},
|
||||
{converter::kFmkTypePytorch, DecideONNXConvWeightSrcFormat}};
|
||||
auto iter = decide_functions.find(fmk_type_);
|
||||
if (iter == decide_functions.end()) {
|
||||
MS_LOG(ERROR) << "current fmk don't support, please check.";
|
||||
|
|
|
@ -26,7 +26,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room;
|
|||
} // namespace
|
||||
|
||||
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
|
||||
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
|
||||
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
|
||||
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
|
||||
return;
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator
|
|||
}
|
||||
|
||||
converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
|
||||
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
|
||||
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
|
||||
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -728,9 +728,9 @@ STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_for
|
|||
return iter->second(tensor, src_format, dst_format);
|
||||
}
|
||||
|
||||
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const tensor::TensorPtr &tensor_info) {
|
||||
if (func_graph == nullptr || node == nullptr || tensor_info == nullptr) {
|
||||
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info,
|
||||
const std::string &node_name) {
|
||||
if (func_graph == nullptr || tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "input parameter is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -746,7 +746,7 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
} else if (tensor_info->data_type() == kNumberTypeFloat64) {
|
||||
data_type = kNumberTypeFloat32;
|
||||
}
|
||||
param_node->set_name(node->fullname_with_scope());
|
||||
param_node->set_name(node_name);
|
||||
auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
|
||||
if (tensor_info_new == nullptr) {
|
||||
MS_LOG(ERROR) << "new tensor::Tensor failed.";
|
||||
|
|
|
@ -98,8 +98,8 @@ AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index);
|
|||
|
||||
STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format);
|
||||
|
||||
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const tensor::TensorPtr &tensor_info);
|
||||
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info,
|
||||
const std::string &node_name);
|
||||
|
||||
ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
|
||||
const std::string &node_name);
|
||||
|
|
Loading…
Reference in New Issue