!31199 [MSLITE] Support to convert pytorch model.

Merge pull request !31199 from wangshaocong/torch_converter
This commit is contained in:
i-robot 2022-04-22 03:21:44 +00:00 committed by Gitee
commit b1023addba
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
24 changed files with 1031 additions and 14 deletions

View File

@ -46,6 +46,7 @@
"mindspore/mindspore/lite/tools/common/flag_parser.cc" "useStlAlgorithm"
"mindspore/mindspore/lite/tools/common/tensor_util.cc" "useStlAlgorithm"
"mindspore/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc" "useStlAlgorithm"
"mindspore/mindspore/lite/tools/converter/parser/pytorch/pytorch_model_parser.cc" "variableScope"
"mindspore/mindspore/lite/tools/converter/quantizer/quantize_util.cc" "useStlAlgorithm"
"mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
"mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"

View File

@ -525,7 +525,16 @@ if(PLATFORM_ARM64)
install(FILES ${opencv_LIBPATH}/libopencv_imgproc.so.4.5.2
DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libopencv_imgproc.so.4.5
COMPONENT ${RUNTIME_COMPONENT_NAME})
if(ENABLE_CONVERT_PYTORCH_MODEL AND NOT WIN32)
install(FILES ${LIB_TORCH_PATH}/lib/libtorch.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${LIB_TORCH_PATH}/lib/libtorch_cpu.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${LIB_TORCH_PATH}/lib/libc10.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${LIB_TORCH_PATH}/lib/libgomp.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(MSLITE_ENABLE_ACL)
set(LITE_ACL_DIR ${TOP_DIR}/mindspore/lite/build/tools/converter/adapter/acl)
install(FILES ${LITE_ACL_DIR}/mindspore_shared_lib/libmindspore_shared_lib.so

View File

@ -55,6 +55,7 @@ option(MSLITE_ENABLE_PARALLEL_INFERENCE "enable parallel inference interface" of
option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off)
option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on)
option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off)
option(MSLITE_ENABLE_CONVERT_PYTORCH_MODEL "enable to convert pytorch model" off)
#Option that can be configured through manually
option(ENABLE_VERBOSE "" off)
@ -175,6 +176,11 @@ if(DEFINED ENV{MSLITE_ENABLE_SERVING})
set(MSLITE_ENABLE_SERVING $ENV{MSLITE_ENABLE_SERVING})
endif()
if(DEFINED ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL} AND DEFINED ENV{LIB_TORCH_PATH})
set(ENABLE_CONVERT_PYTORCH_MODEL $ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL})
set(LIB_TORCH_PATH $ENV{LIB_TORCH_PATH})
endif()
if(MACHINE_LINUX_ARM64)
add_compile_definitions(MACHINE_LINUX_ARM64)
add_compile_definitions(LINUX_RUNTIME)

View File

@ -32,6 +32,7 @@ enum MS_API FmkType : int {
kFmkTypeOnnx = 2,
kFmkTypeMs = 3,
kFmkTypeTflite = 4,
kFmkTypePytorch = 5,
};
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.

View File

@ -79,6 +79,10 @@ add_subdirectory(parser/caffe)
add_subdirectory(parser/tflite)
add_subdirectory(parser/onnx)
add_subdirectory(parser/tf)
if(ENABLE_CONVERT_PYTORCH_MODEL AND NOT WIN32)
add_subdirectory(parser/pytorch)
endif()
add_subdirectory(legacy_optimizer)
add_subdirectory(quantizer)
add_subdirectory(registry)
@ -317,6 +321,10 @@ if(SUPPORT_TRAIN)
target_link_libraries(converter_lite PRIVATE train_cpu_kernel_mid)
endif()
if(ENABLE_CONVERT_PYTORCH_MODEL)
target_link_libraries(converter_lite PRIVATE pytorch_parser_mid)
endif()
if(NOT ENABLE_CLOUD_AND_LITE)
target_link_libraries(converter_lite PRIVATE
ccsrc_debug_common_mid_

View File

@ -69,6 +69,7 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
} else {
model_parser_ = registry::ModelParserRegistry::GetModelParser(flag.fmk);
if (model_parser_ == nullptr) {
MS_LOG(ERROR) << "Unsupported to converter models with fmk: " << flag.fmkIn;
return nullptr;
}
converter::ConverterParameters converter_parameters;

View File

@ -147,6 +147,8 @@ int Flags::InitFmk() {
this->fmk = kFmkTypeOnnx;
} else if (this->fmkIn == "TF") {
this->fmk = kFmkTypeTf;
} else if (this->fmkIn == "PYTORCH") {
this->fmk = kFmkTypePytorch;
} else {
std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
return RET_INPUT_PARAM_INVALID;

View File

@ -105,7 +105,7 @@ STATUS ReplaceTypeParameterNode(const FuncGraphPtr &func_graph, const ParameterP
MS_LOG(ERROR) << "default data is not tensor::Tensor.";
return lite::RET_NULL_PTR;
}
auto param_node_new = opt::BuildParameterNode(func_graph, param_node, tensor_info);
auto param_node_new = opt::BuildParameterNode(func_graph, tensor_info, param_node->fullname_with_scope());
if (param_node_new == nullptr) {
MS_LOG(ERROR) << "BuildParameterNode failed.";
return lite::RET_NULL_PTR;
@ -166,7 +166,7 @@ STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_LOG(ERROR) << "valueptr is not tensor::Tensorptr.";
return lite::RET_ERROR;
}
auto param_node = opt::BuildParameterNode(func_graph, cnode, tensor_info_ptr);
auto param_node = opt::BuildParameterNode(func_graph, tensor_info_ptr, cnode->fullname_with_scope());
if (param_node == nullptr) {
MS_LOG(ERROR) << "convert constant to param node failed.";
return lite::RET_ERROR;
@ -413,7 +413,7 @@ STATUS AdjustRandomNormal(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type);
free(data);
MS_CHECK_TRUE_RET(tensor_info != nullptr, RET_ERROR);
auto parameter = opt::BuildParameterNode(func_graph, cnode, tensor_info);
auto parameter = opt::BuildParameterNode(func_graph, tensor_info, cnode->fullname_with_scope());
if (parameter == nullptr) {
MS_LOG(ERROR) << "BuildParameterNode failed.";
return RET_ERROR;

View File

@ -0,0 +1,22 @@
include(${TOP_DIR}/mindspore/lite/cmake/merge.cmake)
merge_parser(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_BINARY_DIR}/tools/converter/parser/pytorch/pytorch_op_parser.cc)
file(GLOB_RECURSE PYTORCH_SRC_LIST ${CMAKE_BINARY_DIR}/tools/converter/parser/pytorch/pytorch_op_parser.cc)
set_property(SOURCE ${PYTORCH_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(pytorch_parser_mid OBJECT
${PYTORCH_SRC_LIST}
)
add_compile_definitions(C10_USE_GLOG)
if(NOT EXISTS ${LIB_TORCH_PATH})
message(FATAL_ERROR "Path of libtorch is invalid.")
endif()
find_package(Torch REQUIRED PATHS ${LIB_TORCH_PATH})
if(TORCH_FOUND)
target_link_libraries(pytorch_parser_mid PRIVATE ${TORCH_LIBRARIES})
target_include_directories(pytorch_parser_mid PRIVATE ${TORCH_INCLUDE_DIRS})
else()
message(FATAL_ERROR "Torch is not found")
endif()
add_dependencies(pytorch_parser_mid fbs_src)
add_dependencies(pytorch_parser_mid fbs_inner_src)

View File

@ -0,0 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/pytorch/pytorch_conv_parser.h"
#include <algorithm>
#include <memory>
#include <vector>
#include <string>
#include "ops/fusion/conv2d_fusion.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
PrimitiveCPtr PytorchConvParser::Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) {
MS_ASSERT(torch_node != nullptr && input_indices != nullptr);
auto prim = std::make_unique<ops::Conv2DFusion>();
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto bias = torch_node->input(kBiasIndex);
MS_CHECK_TRUE_RET(bias != nullptr, nullptr);
// the bias is noneType
size_t input_size = bias->isCompleteTensor() ? kInputSize2 : kInputSize1;
input_indices->resize(input_size);
std::iota(input_indices->begin(), input_indices->end(), 0);
auto stride = PytorchNodeParser::GetValueFromConstNode<std::vector<int64_t>>(torch_node->input(FOURTH_INPUT));
prim->set_stride(stride);
auto dilation = PytorchNodeParser::GetValueFromConstNode<std::vector<int64_t>>(torch_node->input(SIXTH_INPUT));
prim->set_dilation(dilation);
auto padding = PytorchNodeParser::GetValueFromConstNode<std::vector<int64_t>>(torch_node->input(FIFTH_INPUT));
if (padding.size() == DIMENSION_2D) {
padding.push_back(padding.at(1));
padding.insert(padding.begin(), padding.at(0));
}
prim->set_pad_list(padding);
auto group = PytorchNodeParser::GetValueFromConstNode<int64_t>(torch_node->input(8));
prim->set_group(group);
prim->set_pad({0, 0, 0, 0});
mindspore::PadMode pad_mode = mindspore::PadMode::PAD;
prim->set_pad_mode(pad_mode);
auto prim_c = prim->GetPrim();
MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr);
mindspore::Format format = mindspore::Format::NCHW;
prim_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(format));
bool conv1d = stride.size() == 1;
if (conv1d) {
prim_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCW));
}
// parse activationType
prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);
return prim->GetPrim();
}
PytorchNodeRegistrar g_pytorchConvParser("conv2d", new PytorchConvParser());
PytorchNodeRegistrar g_pytorchConvolutionParser("convolution", new PytorchConvParser());
} // namespace mindspore::lite

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H
#include <vector>
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
namespace mindspore {
namespace lite {
class PytorchConvParser : public PytorchNodeParser {
public:
PytorchConvParser() : PytorchNodeParser("Conv") {}
~PytorchConvParser() override = default;
PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_CONV_PARSER_H

View File

@ -0,0 +1,508 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/pytorch/pytorch_model_parser.h"
#include <algorithm>
#include <memory>
#include <unordered_map>
#include "torch/csrc/jit/passes/freeze_module.h"
#include "torch/csrc/jit/passes/inliner.h"
#include "torch/csrc/jit/passes/normalize_ops.h"
#include "include/registry/node_parser_registry.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/unify_format.h"
#include "tools/converter/parser/lite_model_parser_creator.h"
#include "src/common/file_utils.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
using mindspore::converter::kFmkTypePytorch;
namespace mindspore {
namespace lite {
api::FuncGraphPtr PytorchModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file;
NotSupportOp::GetInstance()->set_fmk_type("PYTORCH");
auto anf_graph = std::make_shared<FuncGraph>();
MS_CHECK_TRUE_MSG(anf_graph != nullptr, nullptr, "create FuncGraph failed");
res_graph_ = api::MakeShared<api::FuncGraph>(anf_graph);
MS_CHECK_TRUE_MSG(res_graph_ != nullptr, nullptr, "create FuncGraph failed");
auto status = InitOriginModel(model_file);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "init origin model failed.";
return nullptr;
}
status = ConvertTorchGraph(anf_graph);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert pytorch graph failed.";
return nullptr;
}
static auto root_func_manager = Manage(anf_graph);
MS_ASSERT(root_func_manager != nullptr);
for (auto &subgraph : all_subgraphs_) {
MS_ASSERT(subgraph != nullptr);
subgraph->set_manager(root_func_manager);
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypePytorch)));
}
anf_graph->set_attr("graph_name", MakeValue("main_graph"));
anf_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypePytorch)));
if ((status = CommonAnfAdjust(anf_graph)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypePytorch, false);
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
if (!unify_format->Run(anf_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
return res_graph_;
}
STATUS PytorchModelParser::InitOriginModel(const std::string &model_file) {
if (ValidateFileStr(model_file, ".pt") != RET_OK && ValidateFileStr(model_file, ".pth") != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pt or *.pth";
return RET_ERROR;
}
std::string model_path = RealPath(model_file.c_str());
if (model_path.empty()) {
MS_LOG(ERROR) << "Binary proto file path " << model_file << " is not valid";
return RET_ERROR;
}
// only linux supports to convert pytorch model.
if (access(model_path.c_str(), F_OK) != 0 || access(model_path.c_str(), R_OK) != 0) {
MS_LOG(ERROR) << "The pytorch model file is not exist or can't be read.";
return RET_ERROR;
}
auto module = torch::jit::load(model_path);
module.eval(); // eval to expand function call
module = torch::jit::freeze_module(module); // freeze module
torch_model_ = module.get_method("forward").graph();
CHECK_NULL_RETURN(torch_model_);
// parse submodules in graph
torch::jit::Inline(*torch_model_);
torch::jit::NormalizeOps(torch_model_);
return RET_OK;
}
STATUS PytorchModelParser::ConvertTorchGraph(const FuncGraphPtr &anf_graph) {
MS_ASSERT(torch_graph != nullptr && anf_graph != nullptr && anf_nodes_map != nullptr &&
extra_subgraph_inputs != nullptr);
STATUS status = ConvertGraphInputs(anf_graph);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert graph inputs failed.";
return RET_OK;
}
status = ConvertNodes(anf_graph);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert nodes failed.";
return RET_ERROR;
}
status = ConvertGraphOutputs(anf_graph);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert graph outputs failed.";
return RET_ERROR;
}
return status;
}
STATUS PytorchModelParser::ConvertGraphInputs(const FuncGraphPtr &anf_graph) {
MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr);
for (auto &input : torch_model_->inputs()) {
auto input_name = input->debugName();
if (anf_nodes_map_.find(input_name) != anf_nodes_map_.end()) {
continue;
}
auto type = input->type();
MS_CHECK_TRUE_RET(type != nullptr, RET_ERROR);
auto tensor_type = type->cast<at::TensorType>();
if (tensor_type == nullptr) {
MS_LOG(DEBUG) << "The input is not a tensor, but a: " << c10::typeKindToString(type->kind());
continue;
}
auto scalar_type = tensor_type->scalarType().value_or(at::ScalarType::Float);
auto data_type = PytorchNodeParser::GetDataTypeFromTorch(scalar_type);
if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support pytorch data type " << scalar_type;
return RET_ERROR;
}
std::vector<int64_t> input_shape = ConverterInnerContext::GetInstance()->GetGraphInputTensorShape(input_name);
if (input_shape.empty()) {
if (tensor_type->sizes().isComplete()) {
input_shape = tensor_type->sizes().concrete_sizes().value();
} else {
MS_LOG(WARNING) << "The input shape is empty.";
}
}
auto abstract_tensor = CreateTensorAbstract(input_shape, data_type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
auto parameter = anf_graph->add_parameter();
MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
parameter->set_abstract(abstract_tensor);
parameter->set_name(input_name);
anf_nodes_map_.emplace(input_name, parameter);
}
return RET_OK;
}
STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
MS_ASSERT(anf_graph != nullptr);
auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new Return failed";
return RET_NULL_PTR;
}
auto return_prim = return_prim_ptr->GetPrim();
MS_CHECK_TRUE_RET(return_prim != nullptr, RET_ERROR);
auto return_cnode = anf_graph->NewCNode(return_prim, return_inputs);
if (return_cnode == nullptr) {
MS_LOG(ERROR) << "new cnode error";
return RET_ERROR;
}
return_cnode->set_fullname_with_scope("Return");
anf_graph->set_return(return_cnode);
return RET_OK;
}
STATUS PytorchModelParser::ConvertGraphOutputs(const FuncGraphPtr &anf_graph) {
MS_ASSERT(anf_graph != nullptr);
std::vector<AnfNodePtr> return_inputs;
if (torch_model_->outputs().size() == 0) {
MS_LOG(ERROR) << "pytorch graph has no output";
return RET_ERROR;
}
if (torch_model_->outputs().size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs;
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed";
return RET_NULL_PTR;
}
for (const auto &output : torch_model_->outputs()) {
auto output_name = output->debugName();
if (anf_nodes_map_.find(output_name) == anf_nodes_map_.end()) {
MS_LOG(ERROR) << "graph output get failed.";
return RET_ERROR;
}
auto cnode = anf_nodes_map_.at(output_name);
if (cnode == nullptr) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
make_tuple_inputs.emplace_back(cnode);
}
auto make_tuple_prim = make_tuple_prim_ptr->GetPrim();
MS_CHECK_TRUE_RET(make_tuple_prim != nullptr, RET_ERROR);
auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim, make_tuple_inputs);
if (make_tuple_cnode == nullptr) {
MS_LOG(ERROR) << "new cnode error";
return RET_ERROR;
}
make_tuple_cnode->set_fullname_with_scope("return tuple");
return_inputs.emplace_back(make_tuple_cnode);
} else {
const auto &output = torch_model_->outputs().front();
if (anf_nodes_map_.find(output->debugName()) == anf_nodes_map_.end()) {
MS_LOG(ERROR) << "graph output get failed.";
return RET_ERROR;
}
auto cnode = anf_nodes_map_.at(output->debugName());
if (cnode == nullptr) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
return_inputs.emplace_back(cnode);
}
if (BuildReturnNode(anf_graph, return_inputs) != RET_OK) {
MS_LOG(ERROR) << "build return node failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS CopyDataFromTorchTensor(char *dst_data, const at::Tensor &torch_tensor, TypeId data_type) {
auto ele_size = abstract::TypeIdSize(data_type);
MS_CHECK_TRUE_RET(ele_size > 0, RET_ERROR);
auto data_shape = torch_tensor.sizes().vec();
auto stride = torch_tensor.strides().vec();
if (data_shape.empty()) {
auto data_size = torch_tensor.numel() * ele_size;
data_shape.push_back(data_size);
stride.push_back(1);
}
char *data_ptr = reinterpret_cast<char *>(torch_tensor.data_ptr());
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "The tensor data is nullptr.";
return RET_ERROR;
}
size_t idx = 0;
std::function<void(size_t, size_t)> copy_data = [&](size_t dim, size_t offset) {
if (dim == data_shape.size() - 1) {
for (int i = 0; i < data_shape[dim]; i++) {
auto src_ptr = data_ptr + offset + i * stride[dim] * ele_size;
auto dst_ptr = dst_data + (idx++) * ele_size;
memcpy(dst_ptr, src_ptr, ele_size);
}
} else {
for (int i = 0; i < data_shape[dim]; i++) {
copy_data(dim + 1, offset + i * stride[dim] * ele_size);
}
}
};
copy_data(0, 0);
return RET_OK;
}
STATUS ConvertConstNode(const torch::jit::Node *torch_node, const FuncGraphPtr &anf_graph,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
ParameterPtr parameter = nullptr;
auto output = torch_node->output();
auto type_kind = output->type()->kind();
auto value = torch::jit::toIValue(output);
MS_CHECK_TRUE_RET(value.has_value(), RET_ERROR);
switch (type_kind) {
case c10::TypeKind::BoolType: {
auto data = static_cast<int>(value.value().toBool());
parameter = opt::BuildIntValueParameterNode(anf_graph, data, output->debugName());
} break;
case c10::TypeKind::IntType: {
auto data = static_cast<int>(value.value().toInt());
parameter = opt::BuildIntValueParameterNode(anf_graph, data, output->debugName());
} break;
case c10::TypeKind::FloatType: {
auto data = static_cast<float>(value.value().toDouble());
parameter = opt::BuildFloatValueParameterNode(anf_graph, data, output->debugName());
} break;
case c10::TypeKind::ListType: {
auto element_type = value->toList().elementType()->kind();
switch (element_type) {
case c10::TypeKind::IntType: {
auto ori_data = value.value().toIntVector();
std::vector<int> data;
std::transform(ori_data.begin(), ori_data.end(), std::back_inserter(data),
[](int64_t ele) { return static_cast<int>(ele); });
parameter = opt::BuildIntVecParameterNode(anf_graph, data, output->debugName());
} break;
case c10::TypeKind::FloatType: {
auto ori_data = value.value().toDoubleVector();
std::vector<float> data;
std::transform(ori_data.begin(), ori_data.end(), std::back_inserter(data),
[](double ele) { return static_cast<float>(ele); });
parameter = opt::BuildFloatVecParameterNode(anf_graph, data, output->debugName());
} break;
default:
MS_LOG(ERROR) << "Unsupported data type: " << c10::typeKindToString(element_type);
return RET_ERROR;
}
} break;
case c10::TypeKind::TensorType: {
auto torch_tensor = value.value().toTensor();
auto data_type = PytorchNodeParser::GetDataTypeFromTorch(torch_tensor.scalar_type());
auto data_size = torch_tensor.numel() * abstract::TypeIdSize(data_type);
char *data_ptr = reinterpret_cast<char *>(malloc(data_size));
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "malloc data failed.";
return RET_ERROR;
}
if (CopyDataFromTorchTensor(data_ptr, torch_tensor, data_type) != RET_OK) {
MS_LOG(ERROR) << "Copy data from torch tensor failed.";
free(data_ptr);
return RET_ERROR;
}
auto data_shape = torch_tensor.sizes().vec();
auto tensor_info = CreateTensorInfo(data_ptr, data_size, data_shape, data_type);
free(data_ptr);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensorInfo failed.";
return RET_ERROR;
}
parameter = opt::BuildParameterNode(anf_graph, tensor_info, output->debugName());
} break;
case c10::TypeKind::NoneType:
MS_LOG(DEBUG) << "The const node is none.";
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported data type: " << c10::typeKindToString(type_kind);
return RET_ERROR;
}
if (parameter == nullptr) {
MS_LOG(ERROR) << "The parameter is nullptr.";
return RET_ERROR;
}
anf_nodes_map->emplace(output->debugName(), parameter);
return RET_OK;
}
STATUS BuildOpInputs(const torch::jit::Node *torch_node, std::vector<AnfNodePtr> *op_inputs,
const std::vector<size_t> &input_indices,
const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map) {
MS_ASSERT(torch_node != nullptr && op_inputs != nullptr);
for (size_t idx : input_indices) {
auto input = torch_node->input(idx);
MS_CHECK_TRUE_RET(input != nullptr, RET_ERROR);
auto input_name = input->debugName();
if (input_name.empty()) {
continue;
}
if (anf_nodes_map.find(input_name) != anf_nodes_map.end()) {
op_inputs->push_back(anf_nodes_map.at(input_name));
} else {
MS_LOG(ERROR) << "could not find input node: " << input_name;
return RET_ERROR;
}
}
return RET_OK;
}
STATUS BuildOpOutputs(const torch::jit::Node *torch_node, const FuncGraphPtr &anf_graph,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode) {
MS_ASSERT(torch_node != nullptr && anf_graph != nullptr && cnode != nullptr && anf_nodes_map != nullptr);
if (torch_node->outputs().size() == 1) {
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
cnode->set_abstract(abstract_tensor);
anf_nodes_map->emplace(torch_node->output()->debugName(), cnode);
} else {
AbstractBasePtrList abstract_list;
int op_idx = 0;
for (const auto &output : torch_node->outputs()) {
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed";
return RET_NULL_PTR;
}
auto tuple_get_item_prim = tuple_get_item_prim_ptr->GetPrim();
MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, RET_NULL_PTR, "get prim return nullptr");
auto tuple_get_item = NewValueNode(tuple_get_item_prim);
MS_CHECK_TRUE_MSG(tuple_get_item != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
MS_CHECK_TRUE_MSG(get_item_value != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
std::vector<AnfNodePtr> inputs{tuple_get_item, cnode, get_item_value};
CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
if (get_item_cnode == nullptr) {
MS_LOG(ERROR) << "new cnode error";
return RET_ERROR;
}
get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
if (get_item_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
get_item_cnode->set_abstract(get_item_abstract);
anf_nodes_map->emplace(output->debugName(), get_item_cnode);
op_idx++;
}
auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
CHECK_NULL_RETURN(new_abstract_list);
cnode->set_abstract(new_abstract_list);
}
anf_nodes_map->emplace(torch_node->kind().toUnqualString(), cnode);
return RET_OK;
}
STATUS PytorchModelParser::ConvertNodes(const FuncGraphPtr &anf_graph) {
MS_ASSERT(anf_graph != nullptr);
STATUS status = RET_OK;
for (const auto &torch_node : torch_model_->nodes()) {
ops::PrimitiveCPtr primitive_c = nullptr;
auto node_type = PytorchNodeParser::GetTorchNodeType(torch_node);
MS_CHECK_TRUE_RET(!node_type.empty(), RET_ERROR);
// convert constant node.
if (node_type == "Constant") {
if (ConvertConstNode(torch_node, anf_graph, &anf_nodes_map_) != RET_OK) {
MS_LOG(ERROR) << "Convert constant node failed.";
return RET_ERROR;
}
continue;
}
std::vector<AnfNodePtr> op_inputs;
std::vector<size_t> input_indices;
auto node_parser_builtin = PytorchNodeParserRegistry::GetInstance().GetNodeParser(node_type);
if (node_parser_builtin == nullptr) {
NotSupportOp::GetInstance()->InsertOp(node_type);
status = status == RET_OK ? RET_NOT_FIND_OP : status;
MS_LOG(ERROR) << "not support pytorch op type " << node_type;
continue;
}
MS_LOG(INFO) << "parse op:" << node_type;
primitive_c = node_parser_builtin->Parse(torch_node, &input_indices);
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "parse node " << node_type << " failed.";
status = RET_ERROR;
continue;
}
// set default format and input indices.
if (primitive_c->GetAttr(ops::kOriginalFormat) == nullptr) {
primitive_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCHW));
}
if (input_indices.empty()) {
input_indices.resize(torch_node->inputs().size());
std::iota(input_indices.begin(), input_indices.end(), 0);
}
if (BuildOpInputs(torch_node, &op_inputs, input_indices, anf_nodes_map_) != RET_OK) {
MS_LOG(ERROR) << "BuildOpInputs failed.";
return RET_ERROR;
}
auto new_cnode = anf_graph->NewCNode(primitive_c, op_inputs);
if (new_cnode == nullptr) {
MS_LOG(ERROR) << "new cnode error";
return RET_ERROR;
}
new_cnode->set_fullname_with_scope(std::string(torch_node->kind().toUnqualString()) + "_" +
torch_node->output(0)->debugName());
if (BuildOpOutputs(torch_node, anf_graph, &anf_nodes_map_, new_cnode) != RET_OK) {
MS_LOG(ERROR) << "BuildOpOutputs failed.";
return RET_ERROR;
}
}
return status;
}
REG_MODEL_PARSER(kFmkTypePytorch, LiteModelParserCreator<PytorchModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "torch/script.h"
#include "securec/include/securec.h"
#include "include/registry/model_parser.h"
#include "include/registry/model_parser_registry.h"
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
class PytorchModelParser : public converter::ModelParser {
public:
PytorchModelParser() = default;
~PytorchModelParser() override = default;
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
private:
STATUS InitOriginModel(const std::string &model_file);
STATUS ConvertTorchGraph(const FuncGraphPtr &anf_graph);
STATUS ConvertGraphInputs(const FuncGraphPtr &anf_graph);
STATUS ConvertGraphOutputs(const FuncGraphPtr &anf_graph);
STATUS ConvertNodes(const FuncGraphPtr &anf_graph);
std::shared_ptr<torch::jit::Graph> torch_model_;
std::vector<FuncGraphPtr> all_subgraphs_{};
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_{};
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_MODEL_PARSER_H

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
#include <unordered_map>
namespace mindspore {
namespace lite {
namespace {
static std::unordered_map<at::ScalarType, TypeId> kTorchDataTypeTransferMap = {
{at::ScalarType::Bool, kNumberTypeBool}, {at::ScalarType::Byte, kNumberTypeUInt8},
{at::ScalarType::Char, kNumberTypeInt8}, {at::ScalarType::Int, kNumberTypeInt},
{at::ScalarType::Long, kNumberTypeInt}, {at::ScalarType::Half, kNumberTypeFloat16},
{at::ScalarType::Float, kNumberTypeFloat32}, {at::ScalarType::Double, kNumberTypeFloat32}};
} // namespace
std::string PytorchNodeParser::GetTorchNodeType(const torch::jit::Node *torch_node) {
const auto &kind = torch_node->kind();
std::string node_type = kind.toUnqualString();
if (node_type.empty()) {
return node_type;
}
node_type = node_type.at(0) == '_' ? node_type.substr(1) : node_type;
node_type = node_type.at(node_type.size() - 1) == '_' ? node_type.substr(0, node_type.size() - 1) : node_type;
return node_type;
}
TypeId PytorchNodeParser::GetDataTypeFromTorch(const at::ScalarType torch_data_type) {
auto iter = kTorchDataTypeTransferMap.find(torch_data_type);
if (iter == kTorchDataTypeTransferMap.end()) {
MS_LOG(ERROR) << "Unsupported torch data type: " << torch_data_type;
return kTypeUnknown;
}
return iter->second;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,66 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H
#include <string>
#include <vector>
#include <utility>
#include <algorithm>
#include "torch/script.h"
#include "include/errorcode.h"
#include "ir/dtype/type_id.h"
#include "ops/primitive_c.h"
#include "ops/op_name.h"
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/parser_utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
class PytorchNodeParser {
public:
explicit PytorchNodeParser(std::string node_name) : name_(std::move(node_name)) {}
virtual ~PytorchNodeParser() = default;
virtual PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) {
return nullptr;
}
static std::string GetTorchNodeType(const torch::jit::Node *torch_node);
static TypeId GetDataTypeFromTorch(const at::ScalarType torch_data_type);
template <typename T>
static T GetValueFromConstNode(const torch::jit::Value *value_node) {
T data{};
auto ivalue = torch::jit::toIValue(value_node);
MS_CHECK_TRUE_RET(ivalue.has_value(), data);
auto value = ivalue.value();
auto optional_value = value.toOptional<T>();
MS_CHECK_TRUE_RET(optional_value, data);
return optional_value.value();
}
protected:
const std::string name_{};
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_PARSER_H

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
#include <string>
namespace mindspore {
namespace lite {
PytorchNodeParserRegistry::PytorchNodeParserRegistry() = default;
PytorchNodeParserRegistry::~PytorchNodeParserRegistry() {
for (auto ite : parsers) {
if (ite.second != nullptr) {
delete ite.second;
ite.second = nullptr;
}
}
}
PytorchNodeParserRegistry &PytorchNodeParserRegistry::GetInstance() {
static PytorchNodeParserRegistry instance;
return instance;
}
PytorchNodeParser *PytorchNodeParserRegistry::GetNodeParser(const std::string &name) const {
auto it = parsers.find(name);
if (it != parsers.end()) {
return it->second;
}
return nullptr;
}
void PytorchNodeParserRegistry::RegNodeParser(const std::string &name, PytorchNodeParser *parser) {
if (parser == nullptr) {
MS_LOG(WARNING) << "Input PytorchNodeParser is nullptr";
return;
}
if (this->parsers.find(name) != this->parsers.end()) {
MS_LOG(WARNING) << "PytorchNodeParser " << name << " is already exist";
return;
}
this->parsers[name] = parser;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H
#include <string>
#include <unordered_map>
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
namespace mindspore {
namespace lite {
class PytorchNodeParserRegistry {
public:
virtual ~PytorchNodeParserRegistry();
static PytorchNodeParserRegistry &GetInstance();
PytorchNodeParser *GetNodeParser(const std::string &name) const;
void RegNodeParser(const std::string &name, PytorchNodeParser *parser);
private:
PytorchNodeParserRegistry();
private:
std::unordered_map<std::string, PytorchNodeParser *> parsers{};
};
class PytorchNodeRegistrar {
public:
PytorchNodeRegistrar(const std::string &name, PytorchNodeParser *parser) {
PytorchNodeParserRegistry::GetInstance().RegNodeParser(name, parser);
}
~PytorchNodeRegistrar() = default;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_NODE_REGISTRY_H

View File

@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/pytorch/pytorch_permute_parser.h"
#include <memory>
#include <vector>
#include "ops/transpose.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
PrimitiveCPtr PytorchPermuteParser::Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) {
MS_ASSERT(torch_node != nullptr && input_indices != nullptr);
auto prim = std::make_unique<ops::Transpose>();
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
return prim->GetPrim();
}
PytorchNodeRegistrar g_pytorchPermuteParser("permute", new PytorchPermuteParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H
#include <vector>
#include "tools/converter/parser/pytorch/pytorch_node_parser.h"
#include "tools/converter/parser/pytorch/pytorch_node_parser_registry.h"
namespace mindspore {
namespace lite {
class PytorchPermuteParser : public PytorchNodeParser {
public:
PytorchPermuteParser() : PytorchNodeParser("Permute") {}
~PytorchPermuteParser() override = default;
PrimitiveCPtr Parse(const torch::jit::Node *torch_node, std::vector<size_t> *input_indices) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PYTORCH_PERMUTE_PARSER_H

View File

@ -121,7 +121,7 @@ STATUS TfliteInputsAdjust::ReplaceInt64ParameterNode(const FuncGraphPtr &func_gr
MS_LOG(ERROR) << "default data is not tensor::Tensor.";
return lite::RET_NULL_PTR;
}
auto param_node_new = opt::BuildParameterNode(func_graph, param_node, tensor_info);
auto param_node_new = opt::BuildParameterNode(func_graph, tensor_info, param_node->fullname_with_scope());
if (!manager->Replace(param_node, param_node_new)) {
MS_LOG(ERROR) << "Replace param node failed.";
return RET_ERROR;

View File

@ -208,7 +208,8 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat},
{converter::kFmkTypePytorch, DecideONNXConvWeightSrcFormat}};
auto iter = decide_functions.find(fmk_type_);
if (iter == decide_functions.end()) {
MS_LOG(ERROR) << "current fmk don't support, please check.";

View File

@ -26,7 +26,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room;
} // namespace
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
return;
}
@ -34,7 +34,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator
}
converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
return nullptr;
}

View File

@ -728,9 +728,9 @@ STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_for
return iter->second(tensor, src_format, dst_format);
}
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const tensor::TensorPtr &tensor_info) {
if (func_graph == nullptr || node == nullptr || tensor_info == nullptr) {
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info,
const std::string &node_name) {
if (func_graph == nullptr || tensor_info == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr.";
return nullptr;
}
@ -746,7 +746,7 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
} else if (tensor_info->data_type() == kNumberTypeFloat64) {
data_type = kNumberTypeFloat32;
}
param_node->set_name(node->fullname_with_scope());
param_node->set_name(node_name);
auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
if (tensor_info_new == nullptr) {
MS_LOG(ERROR) << "new tensor::Tensor failed.";

View File

@ -98,8 +98,8 @@ AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index);
STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format);
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const tensor::TensorPtr &tensor_info);
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info,
const std::string &node_name);
ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
const std::string &node_name);