From 87325ad656f52e54ef5418b952e1a3a1c9d4934a Mon Sep 17 00:00:00 2001 From: xuanyue Date: Mon, 21 Dec 2020 21:33:11 +0800 Subject: [PATCH] primitive attr->tensor --- build.sh | 6 +- mindspore/lite/include/lite_utils.h | 7 +- mindspore/lite/include/model.h | 7 +- mindspore/lite/src/CMakeLists.txt | 6 +- mindspore/lite/src/common/version_manager.h | 44 ++++ .../src/{model_common.cc => lite_model.cc} | 239 +++++++++++++----- mindspore/lite/src/lite_model.h | 223 ++++++++++++++++ mindspore/lite/src/lite_session.cc | 4 +- mindspore/lite/src/model.cc | 52 ---- mindspore/lite/src/model_common.h | 192 -------------- mindspore/lite/src/ops/CMakeLists.txt | 4 + .../src/ops/compat/attr_transfer_common.cc | 65 +++++ .../src/ops/compat/attr_transfer_common.h | 35 +++ .../lite/src/ops/compat/compat_register.h | 67 +++++ .../ops/compat/v0/broadcat_to_compat_v0.cc | 48 ++++ .../src/ops/compat/v0/reshape_compat_v0.cc | 47 ++++ .../ops/compat/v0/strided_slice_compat_v0.cc | 67 +++++ mindspore/lite/src/train/train_model.cc | 24 +- mindspore/lite/src/train/train_model.h | 8 +- mindspore/lite/test/CMakeLists.txt | 6 +- mindspore/lite/tools/converter/CMakeLists.txt | 5 +- .../quantizer/post_training_quantizer.cc | 2 + .../lite/tools/lib_cropper/lib_cropper.h | 1 + 23 files changed, 809 insertions(+), 350 deletions(-) create mode 100644 mindspore/lite/src/common/version_manager.h rename mindspore/lite/src/{model_common.cc => lite_model.cc} (50%) create mode 100644 mindspore/lite/src/lite_model.h delete mode 100644 mindspore/lite/src/model.cc delete mode 100644 mindspore/lite/src/model_common.h create mode 100644 mindspore/lite/src/ops/compat/attr_transfer_common.cc create mode 100644 mindspore/lite/src/ops/compat/attr_transfer_common.h create mode 100644 mindspore/lite/src/ops/compat/compat_register.h create mode 100644 mindspore/lite/src/ops/compat/v0/broadcat_to_compat_v0.cc create mode 100644 mindspore/lite/src/ops/compat/v0/reshape_compat_v0.cc create mode 100644 mindspore/lite/src/ops/compat/v0/strided_slice_compat_v0.cc diff --git a/build.sh b/build.sh index d615fb8a391..e149629a99a 100755 --- a/build.sh +++ b/build.sh @@ -594,7 +594,7 @@ build_lite() -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ -DPLATFORM_ARM64=on -DENABLE_NEON=on -DENABLE_FP16="off" \ -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \ - -DSUPPORT_GPU=${LITE_ENABLE_GPU} -DSUPPORT_NPU=${LITE_ENABLE_NPU} \ + -DSUPPORT_GPU=${LITE_ENABLE_GPU} -DSUPPORT_NPU=${LITE_ENABLE_NPU} -DENABLE_V0=on \ -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \ -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ @@ -606,7 +606,7 @@ build_lite() -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \ - -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} \ + -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} -DENABLE_V0=on \ -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \ -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} 
-DENABLE_VERBOSE=${ENABLE_VERBOSE} \ @@ -615,7 +615,7 @@ build_lite() cmake -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} \ - -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ + -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} -DENABLE_V0=on \ -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \ -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ -DENABLE_VERBOSE=${ENABLE_VERBOSE} -DX86_64_SIMD=${X86_64_SIMD} "${BASEPATH}/mindspore/lite" diff --git a/mindspore/lite/include/lite_utils.h b/mindspore/lite/include/lite_utils.h index c059b6e3460..9e09613b63a 100644 --- a/mindspore/lite/include/lite_utils.h +++ b/mindspore/lite/include/lite_utils.h @@ -19,9 +19,12 @@ #include #include #include -#include "schema/model_generated.h" #include "include/ms_tensor.h" +namespace mindspore::schema { +struct Tensor; +} // namespace mindspore::schema + namespace mindspore::lite { /// \brief Allocator defined a memory pool for malloc memory and free memory dynamically. /// @@ -35,7 +38,7 @@ using TensorPtrVector = std::vector; using DeviceContextVector = std::vector; using Uint32Vector = std::vector; using String = std::string; -using NodeType = schema::NodeType; +using NodeType = int; /**< 0 : NodeType_ValueNode, 1 : NodeType_Parameter, 2 : NodeType_CNode. */ using AllocatorPtr = std::shared_ptr; /// \brief Set data of MSTensor from string vector. diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h index ba1f24aafba..cd7a574c2ee 100644 --- a/mindspore/lite/include/model.h +++ b/mindspore/lite/include/model.h @@ -53,13 +53,10 @@ struct MS_API Model { static Model *Import(const char *model_buf, size_t size); /// \brief Free meta graph temporary buffer - virtual void Free(); - - /// \brief Free all temporay buffer.EG: nodes in the model. - void Destroy(); + virtual void Free() = 0; /// \brief Model destruct, free all memory - virtual ~Model(); + virtual ~Model() = default; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 261778cf0df..151e5fe3366 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -1,4 +1,7 @@ add_compile_definitions(USE_ANDROID_LOG) +if (ENABLE_V0) + add_definitions(-DENABLE_V0) +endif() set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) 
include_directories(${LITE_DIR}/nnacl/) include_directories(${LITE_DIR}/nnacl/optimize) @@ -29,13 +32,12 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/tensorlist.cc ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model_common.cc + ${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model.cc ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc ) diff --git a/mindspore/lite/src/common/version_manager.h b/mindspore/lite/src/common/version_manager.h new file mode 100644 index 00000000000..5b336c000d3 --- /dev/null +++ b/mindspore/lite/src/common/version_manager.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_ +#define MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_ + +#include +#include "src/lite_model.h" + +namespace mindspore { +namespace lite { +class VersionManager { + public: + static VersionManager *GetInstance() { + static VersionManager instance; + return &instance; + } + virtual ~VersionManager() = default; + + void SetSchemaVersion(const int schema_version) { schema_version_ = schema_version; } + int GetSchemaVersion() const { return schema_version_; } + + private: + VersionManager() = default; + + private: + int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_ diff --git a/mindspore/lite/src/model_common.cc b/mindspore/lite/src/lite_model.cc similarity index 50% rename from mindspore/lite/src/model_common.cc rename to mindspore/lite/src/lite_model.cc index 5fc7c3967a8..fca67d624d4 100644 --- a/mindspore/lite/src/model_common.cc +++ b/mindspore/lite/src/lite_model.cc @@ -13,15 +13,115 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "src/model_common.h" + +#include "src/lite_model.h" +#include +#include +#include #include "src/ops/while.h" +#ifdef ENABLE_V0 +#include "src/ops/compat/compat_register.h" +#endif namespace mindspore::lite { -int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { - if (model == nullptr) { - MS_LOG(ERROR) << "model is null."; +#ifdef ENABLE_V0 +int LiteModel::ConvertAttrs(Model::Node *node, const schema::v0::Primitive *prim, + std::vector *dst_tensor) { + if (node == nullptr || dst_tensor == nullptr) { + MS_LOG(ERROR) << "node or tensor_vec is nullptr."; return RET_ERROR; } + int primitive_type = prim->value_type(); + auto creator = CompatRegistry::GetInstance()->GetTransferAttrFunc(SCHEMA_VERSION::SCHEMA_V0, primitive_type); + if (creator == nullptr) { + MS_LOG(DEBUG) << "the node don't need to convert attr to tensor."; + return RET_OK; + } + int status = creator(reinterpret_cast(prim), node, dst_tensor, &this->attr_tensor_bufs_); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "translate attr to tensor failed."; + return status; + } + return RET_OK; +} + +int LiteModel::ConvertAttrToTensors(const void *meta_graph) { + MS_ASSERT(meta_graph != nullptr); + int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); + if (schema_version != SCHEMA_VERSION::SCHEMA_V0) { + MS_LOG(DEBUG) << "no need to convert attr to tensor."; + return RET_OK; + } + auto meta_graph_v0 = reinterpret_cast(meta_graph); + std::unordered_map> subgraph_node_indexes; + for (size_t subgraph_index = 0; subgraph_index < this->sub_graphs_.size(); ++subgraph_index) { + for (size_t node_index = 0; node_index < this->sub_graphs_[subgraph_index]->node_indices_.size(); ++node_index) { + subgraph_node_indexes[subgraph_index].insert(this->sub_graphs_[subgraph_index]->node_indices_[node_index]); + } + } + int cur_all_tensors_size = this->all_tensors_.size(); + for (size_t index = 0; index < this->all_nodes_.size(); ++index) { + std::vector dst_tensors; + auto prim = meta_graph_v0->nodes()->GetAs(index)->primitive(); + int status = ConvertAttrs(this->all_nodes_[index], prim, &dst_tensors); + if (status != RET_OK) { + MS_LOG(ERROR) << "fail to convert attr to tensor."; + return RET_ERROR; + } + if (dst_tensors.empty()) { + continue; + } + std::vector subgraphs_with_node; + for (size_t subgraph_index = 0; subgraph_index < this->sub_graphs_.size(); ++subgraph_index) { + if (subgraph_node_indexes[subgraph_index].find(index) == subgraph_node_indexes[subgraph_index].end()) { + continue; + } + subgraphs_with_node.push_back(subgraph_index); + } + for (auto tensor : dst_tensors) { + for (auto subgraph_index : subgraphs_with_node) { + this->sub_graphs_[subgraph_index]->tensor_indices_.push_back(cur_all_tensors_size); + } + this->all_nodes_[index]->input_indices_.push_back(cur_all_tensors_size++); + this->all_tensors_.push_back(tensor); + } + } + return RET_OK; +} +#endif + +void LiteModel::Free() { + if (this->buf != nullptr) { + free(this->buf); + this->buf = nullptr; + } + for (auto &tensor_buf : attr_tensor_bufs_) { + free(tensor_buf); + } + attr_tensor_bufs_.resize(0); +} + +LiteModel::~LiteModel() { + Free(); + auto nodes_size = this->all_nodes_.size(); + for (size_t i = 0; i < nodes_size; ++i) { + auto node = this->all_nodes_[i]; + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive_ != nullptr); + delete node->primitive_; + node->primitive_ = nullptr; + delete node; + } + this->all_nodes_.clear(); + + auto sub_graph_size = this->sub_graphs_.size(); + for (size_t i = 
0; i < sub_graph_size; ++i) { + auto sub_graph = this->sub_graphs_[i]; + delete sub_graph; + } +} + +int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) { if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr || sub_graph.nodeIndices() == nullptr || sub_graph.tensorIndices() == nullptr) { MS_LOG(ERROR) << "sub_graph is invalid."; @@ -51,28 +151,31 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { for (uint32_t i = 0; i < tensor_count; ++i) { subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i)); } - model->sub_graphs_.push_back(subgraph); + this->sub_graphs_.push_back(subgraph); return RET_OK; } -int VersionVerify(flatbuffers::Verifier *verify) { +int LiteModel::VersionVerify(flatbuffers::Verifier *verify) const { if (verify == nullptr) { MS_LOG(ERROR) << "verify is null."; return RET_ERROR; } if (schema::VerifyMetaGraphBuffer(*verify)) { return SCHEMA_VERSION::SCHEMA_CUR; - } else if (schema::v0::VerifyMetaGraphBuffer(*verify)) { + } +#ifdef ENABLE_V0 + if (schema::v0::VerifyMetaGraphBuffer(*verify)) { return SCHEMA_VERSION::SCHEMA_V0; } +#endif return SCHEMA_VERSION::SCHEMA_INVALID; } -int NodeVerify(const Model &model) { - auto tensor_size = model.all_tensors_.size(); - uint32_t subGraph_size = model.sub_graphs_.size(); +int LiteModel::NodeVerify() const { + auto tensor_size = this->all_tensors_.size(); + uint32_t subGraph_size = this->sub_graphs_.size(); - for (auto &node : model.all_nodes_) { + for (auto &node : this->all_nodes_) { if (node == nullptr || node->primitive_ == nullptr) { MS_LOG(ERROR) << "node or its primitive_ is null."; return RET_ERROR; @@ -105,11 +208,11 @@ int NodeVerify(const Model &model) { return RET_OK; } -int SubGraphVerify(const Model &model) { - auto tensor_size = model.all_tensors_.size(); - auto node_size = model.all_nodes_.size(); +int LiteModel::SubGraphVerify() const { + auto tensor_size = this->all_tensors_.size(); + auto node_size = this->all_nodes_.size(); - for (auto &graph : model.sub_graphs_) { + for (auto &graph : this->sub_graphs_) { if (graph == nullptr) { MS_LOG(ERROR) << "graph is null."; return RET_ERROR; @@ -138,49 +241,78 @@ int SubGraphVerify(const Model &model) { return RET_OK; } -bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; } +bool LiteModel::ModelVerify() const { return NodeVerify() == RET_OK && SubGraphVerify() == RET_OK; } -const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { - if (buf == nullptr) { - MS_LOG(ERROR) << "buf is null."; - return nullptr; - } +const void *LiteModel::GetMetaGraphByVerison() { + MS_ASSERT(this->buf != nullptr); + auto schema_version = VersionManager::GetInstance()->GetSchemaVersion(); if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { - return reinterpret_cast(schema::GetMetaGraph(buf)); - } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { + return reinterpret_cast(schema::GetMetaGraph(this->buf)); + } +#ifdef ENABLE_V0 + if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { return reinterpret_cast(schema::v0::GetMetaGraph(buf)); } +#endif return nullptr; } -int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) { - if (meta_graph == nullptr || model == nullptr) { - MS_LOG(ERROR) << "meta_graph or model is null."; - return RET_ERROR; - } +int LiteModel::GenerateModelByVersion(const void *meta_graph) { + MS_ASSERT(meta_graph != nullptr); + auto schema_version = 
VersionManager::GetInstance()->GetSchemaVersion(); int status = RET_ERROR; if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { - status = GenerateModel(*reinterpret_cast(meta_graph), - model, schema_version); - } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { - status = GenerateModel( - *reinterpret_cast(meta_graph), model, schema_version); + status = GenerateModel(*reinterpret_cast(meta_graph)); } +#ifdef ENABLE_V0 + if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { + status = GenerateModel( + *reinterpret_cast(meta_graph)); + } +#endif return status; } +int LiteModel::ConstructModel() { + if (this->buf == nullptr || this->buf_size_ <= 0) { + MS_LOG(ERROR) << "cannot construct model."; + return RET_NULL_PTR; + } + flatbuffers::Verifier verify((const uint8_t *)this->buf, this->buf_size_); + int schema_version = VersionVerify(&verify); + if (schema_version == SCHEMA_INVALID) { + MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; + return RET_ERROR; + } + VersionManager::GetInstance()->SetSchemaVersion(schema_version); + const void *meta_graph = GetMetaGraphByVerison(); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "meta_graph is nullptr!"; + return RET_NULL_PTR; + } + + int status = GenerateModelByVersion(meta_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "fail to generate model"; + return status; + } + + if (this->version_ != Version()) { + MS_LOG(WARNING) << "model version is " << this->version_ << ", inference version is " << Version() << " not equal"; + } + if (this->sub_graphs_.empty()) { + return RET_ERROR; + } + + return ModelVerify() ? RET_OK : RET_ERROR; +} + Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { if (model_buf == nullptr) { MS_LOG(ERROR) << "The model buf is nullptr"; return nullptr; } - flatbuffers::Verifier verify((const uint8_t *)model_buf, size); - int schema_version = VersionVerify(&verify); - if (schema_version == SCHEMA_INVALID) { - MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; - return nullptr; - } - auto *model = new (std::nothrow) Model(); + auto *model = new (std::nothrow) LiteModel(); if (model == nullptr) { MS_LOG(ERROR) << "new model fail!"; return nullptr; @@ -201,28 +333,15 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { } memcpy(model->buf, model_buf, size); } - const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version); - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "meta_graph is nullptr!"; - delete (model); - return nullptr; - } - - int status = GenerateModelByVersion(meta_graph, model, schema_version); + model->buf_size_ = size; + auto status = model->ConstructModel(); if (status != RET_OK) { - delete (model); - MS_LOG(ERROR) << "fail to generate model"; + MS_LOG(ERROR) << "construct model failed."; + delete model; return nullptr; } - - if (model->version_ != Version()) { - MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal"; - } - if (model->sub_graphs_.empty()) { - delete (model); - return nullptr; - } - - return ModelVerify(*model) ? 
model : nullptr; + return model; } + +Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); } } // namespace mindspore::lite diff --git a/mindspore/lite/src/lite_model.h b/mindspore/lite/src/lite_model.h new file mode 100644 index 00000000000..5dd043a9727 --- /dev/null +++ b/mindspore/lite/src/lite_model.h @@ -0,0 +1,223 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_LITE_MODEL_H_ +#define MINDSPORE_LITE_SRC_LITE_MODEL_H_ + +#include +#include +#include "include/model.h" +#include "src/ops/primitive_c.h" +#include "include/version.h" +#include "schema/model_generated.h" +#include "src/common/common.h" +#include "src/common/version_manager.h" +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif +#ifdef ENABLE_V0 +#include "schema/model_v0_generated.h" +#endif + +namespace mindspore { +namespace lite { +class LiteModel : public Model { + public: + int ConstructModel(); + + bool ModelVerify() const; + + void Free() override; + + ~LiteModel() override; + + private: +#ifdef ENABLE_V0 + int ConvertAttrs(Model::Node *node, const schema::v0::Primitive *prim, std::vector *dst_tensor); + + int ConvertAttrToTensors(const void *meta_graph); +#endif + + template + bool ConvertNodes(const T &meta_graph) { + if (meta_graph.nodes() == nullptr) { + MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; + return false; + } + for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) { + auto *node = new (std::nothrow) Model::Node(); + if (node == nullptr) { + MS_LOG(ERROR) << "new node fail!"; + return false; + } + auto c_node = meta_graph.nodes()->template GetAs(i); + auto src_prim = reinterpret_cast(c_node->primitive()); +#ifdef PRIMITIVE_WRITEABLE + node->primitive_ = PrimitiveC::Create(const_cast(src_prim)); +#else + auto primitive = const_cast(src_prim); + auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type()); + if (func_pointer == nullptr) { + MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: " + << schema::EnumNamePrimitiveType(primitive->value_type()); + delete node; + return false; + } + node->primitive_ = func_pointer(primitive); +#endif + if (node->primitive_ == nullptr) { + MS_LOG(ERROR) << "unpack primitive == nullptr!"; + delete node; + return false; + } + node->primitive_->set_quant_type(static_cast(c_node->quantType())); + node->name_ = c_node->name()->c_str(); + node->node_type_ = static_cast(c_node->nodeType()); + auto count = c_node->inputIndex()->size(); + for (uint32_t j = 0; j < count; ++j) { + node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs(j))); + } + if (c_node->outputIndex() != nullptr) { + count = c_node->outputIndex()->size(); + for (uint32_t j = 0; j < count; ++j) { + node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs(j))); + } + } + this->all_nodes_.push_back(node); + } + return 
true; + } + + template + bool ConvertTensors(const T &meta_graph) { + if (meta_graph.allTensors() == nullptr) { + MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; + return false; + } + auto tensor_count = meta_graph.allTensors()->size(); + for (uint32_t i = 0; i < tensor_count; ++i) { + auto *tensor = meta_graph.allTensors()->template GetAs(i); + if (tensor == nullptr) { + MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr"; + return false; + } + this->all_tensors_.push_back(const_cast(tensor)); + } + return true; + } + + template + int MetaGraphMappingSubGraph(const T &meta_graph) { + if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr || + meta_graph.allTensors() == nullptr) { + MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; + return RET_ERROR; + } + auto *subgraph = new (std::nothrow) Model::SubGraph(); + if (subgraph == nullptr) { + MS_LOG(ERROR) << "new subGraph fail!"; + return RET_ERROR; + } + if (meta_graph.name() != nullptr) { + subgraph->name_ = meta_graph.name()->c_str(); + } + auto in_count = meta_graph.inputIndex()->size(); + for (uint32_t i = 0; i < in_count; ++i) { + subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs(i))); + } + auto out_count = meta_graph.outputIndex()->size(); + for (uint32_t i = 0; i < out_count; ++i) { + subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs(i))); + } + auto node_count = meta_graph.nodes()->size(); + for (uint32_t i = 0; i < node_count; ++i) { + subgraph->node_indices_.push_back(i); + } + auto tensor_count = meta_graph.allTensors()->size(); + for (uint32_t i = 0; i < tensor_count; ++i) { + subgraph->tensor_indices_.push_back(i); + } + this->sub_graphs_.push_back(subgraph); + return RET_OK; + } + + template + int GenerateModel(const T &meta_graph) { + if (meta_graph.name() != nullptr) { + this->name_ = meta_graph.name()->c_str(); + } + if (meta_graph.version() != nullptr) { + this->version_ = meta_graph.version()->c_str(); + } + if (!ConvertNodes(meta_graph)) { + MS_LOG(ERROR) << "convert node failed"; + return RET_ERROR; + } + if (!ConvertTensors(meta_graph)) { + MS_LOG(ERROR) << "convert tensor failed"; + return RET_ERROR; + } + if (meta_graph.subGraph() == nullptr) { + int ret = MetaGraphMappingSubGraph(meta_graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "converter old version model wrong."; + return ret; + } + } else { + auto sub_graphs = meta_graph.subGraph(); + auto sub_graph_size = sub_graphs->size(); + for (size_t i = 0; i < sub_graph_size; i++) { + auto sub_graph = sub_graphs->template GetAs(i); + int ret = ConvertSubGraph(*sub_graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "converter subgraph wrong."; + return ret; + } + } + } +#ifdef ENABLE_V0 + if (ConvertAttrToTensors(&meta_graph) != RET_OK) { + MS_LOG(ERROR) << "fail to convert attr to tensor."; + return RET_ERROR; + } +#endif + return RET_OK; + } + + int VersionVerify(flatbuffers::Verifier *verify) const; + + const void *GetMetaGraphByVerison(); + + int GenerateModelByVersion(const void *meta_graph); + + int ConvertSubGraph(const schema::SubGraph &sub_graph); + + int NodeVerify() const; + + int SubGraphVerify() const; + + public: + size_t buf_size_ = 0; + + protected: + std::vector attr_tensor_bufs_; +}; + +Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_LITE_MODEL_H_ diff --git 
a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 96c16b7c92e..cb3506b9f57 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -26,7 +26,7 @@ #include "src/common/utils.h" #include "src/common/graph_util.h" #include "src/kernel_registry.h" -#include "src/model_common.h" +#include "src/lite_model.h" #include "src/runtime/kernel/arm/base/dequant.h" #if SUPPORT_NPU #include "src/runtime/agent/npu/npu_manager.h" @@ -353,7 +353,7 @@ int LiteSession::CompileGraph(Model *model) { is_running_.store(false); return RET_PARAM_INVALID; } - if (!ModelVerify(*model)) { + if (!reinterpret_cast(model)->ModelVerify()) { MS_LOG(ERROR) << "wrong model input, please check"; is_running_.store(false); return RET_ERROR; diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc deleted file mode 100644 index 74dd8662a0e..00000000000 --- a/mindspore/lite/src/model.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/ops/primitive_c.h" -#include "include/model.h" -#include "src/common/log_adapter.h" -#include "src/model_common.h" - -namespace mindspore::lite { -Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); } - -void Model::Free() { - if (this->buf != nullptr) { - free(this->buf); - this->buf = nullptr; - } -} - -void Model::Destroy() { - Free(); - auto nodes_size = this->all_nodes_.size(); - for (size_t i = 0; i < nodes_size; ++i) { - auto node = this->all_nodes_[i]; - MS_ASSERT(node != nullptr); - MS_ASSERT(node->primitive_ != nullptr); - delete node->primitive_; - node->primitive_ = nullptr; - delete node; - } - this->all_nodes_.clear(); - - auto sub_graph_size = this->sub_graphs_.size(); - for (size_t i = 0; i < sub_graph_size; ++i) { - auto sub_graph = this->sub_graphs_[i]; - delete sub_graph; - } -} - -Model::~Model() { Destroy(); } -} // namespace mindspore::lite diff --git a/mindspore/lite/src/model_common.h b/mindspore/lite/src/model_common.h deleted file mode 100644 index 51c9317e6ad..00000000000 --- a/mindspore/lite/src/model_common.h +++ /dev/null @@ -1,192 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_ -#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_ - -#include -#include "src/ops/primitive_c.h" -#include "include/model.h" -#include "include/version.h" -#include "schema/model_generated.h" -#include "schema/model_v0_generated.h" -#include "src/common/common.h" -#ifndef PRIMITIVE_WRITEABLE -#include "src/ops/ops_register.h" -#endif - -namespace mindspore::lite { -int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model); - -template -bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA_CUR) { - if (model == nullptr || meta_graph.nodes() == nullptr) { - MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file."; - return false; - } - for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) { - auto *node = new (std::nothrow) Model::Node(); - if (node == nullptr) { - MS_LOG(ERROR) << "new node fail!"; - return false; - } - auto c_node = meta_graph.nodes()->template GetAs(i); - auto src_prim = reinterpret_cast(c_node->primitive()); -#ifdef PRIMITIVE_WRITEABLE - node->primitive_ = PrimitiveC::Create(const_cast(src_prim)); -#else - auto primitive = const_cast(src_prim); - auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type()); - if (func_pointer == nullptr) { - MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: " - << schema::EnumNamePrimitiveType(primitive->value_type()); - delete node; - return false; - } - node->primitive_ = func_pointer(primitive); -#endif - if (node->primitive_ == nullptr) { - MS_LOG(ERROR) << "unpack primitive == nullptr!"; - delete node; - return false; - } - node->primitive_->set_quant_type(static_cast(c_node->quantType())); - node->name_ = c_node->name()->c_str(); - node->node_type_ = static_cast(c_node->nodeType()); - auto count = c_node->inputIndex()->size(); - for (uint32_t j = 0; j < count; ++j) { - node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs(j))); - } - if (c_node->outputIndex() != nullptr) { - count = c_node->outputIndex()->size(); - for (uint32_t j = 0; j < count; ++j) { - node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs(j))); - } - } - model->all_nodes_.push_back(node); - } - return true; -} - -template -bool ConvertTensors(const T &meta_graph, Model *model) { - if (model == nullptr || meta_graph.allTensors() == nullptr) { - MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file."; - return false; - } - auto tensor_count = meta_graph.allTensors()->size(); - for (uint32_t i = 0; i < tensor_count; ++i) { - auto *tensor = meta_graph.allTensors()->template GetAs(i); - if (tensor == nullptr) { - MS_LOG(ERROR) << i << "th tensor in model is nullptr"; - return false; - } - model->all_tensors_.push_back(const_cast(tensor)); - } - return true; -} - -template -int MetaGraphMappingSubGraph(const T &meta_graph, Model *model) { - if (model == nullptr || meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || - meta_graph.nodes() == nullptr || meta_graph.allTensors() == nullptr) { - MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file."; - return RET_ERROR; - } - auto *subgraph = new (std::nothrow) Model::SubGraph(); - if (subgraph == nullptr) { - MS_LOG(ERROR) << "new subGraph fail!"; - return RET_ERROR; - } - if (meta_graph.name() != nullptr) { - subgraph->name_ = meta_graph.name()->c_str(); - } - auto in_count = meta_graph.inputIndex()->size(); - for (uint32_t i = 0; i < in_count; 
++i) { - subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs(i))); - } - auto out_count = meta_graph.outputIndex()->size(); - for (uint32_t i = 0; i < out_count; ++i) { - subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs(i))); - } - auto node_count = meta_graph.nodes()->size(); - for (uint32_t i = 0; i < node_count; ++i) { - subgraph->node_indices_.push_back(i); - } - auto tensor_count = meta_graph.allTensors()->size(); - for (uint32_t i = 0; i < tensor_count; ++i) { - subgraph->tensor_indices_.push_back(i); - } - model->sub_graphs_.push_back(subgraph); - return RET_OK; -} - -template -int GenerateModel(const T &meta_graph, Model *model, int schema_version = 0) { - if (model == nullptr) { - MS_LOG(ERROR) << "model is nullptr."; - return RET_ERROR; - } - if (meta_graph.name() != nullptr) { - model->name_ = meta_graph.name()->c_str(); - } - if (meta_graph.version() != nullptr) { - model->version_ = meta_graph.version()->c_str(); - } - if (!ConvertNodes(meta_graph, model, schema_version)) { - MS_LOG(ERROR) << "convert node failed"; - return RET_ERROR; - } - if (!ConvertTensors(meta_graph, model)) { - MS_LOG(ERROR) << "convert tensor failed"; - return RET_ERROR; - } - if (meta_graph.subGraph() == nullptr) { - int ret = MetaGraphMappingSubGraph(meta_graph, model); - if (ret != RET_OK) { - MS_LOG(ERROR) << "converter old version model wrong."; - return ret; - } - } else { - auto sub_graphs = meta_graph.subGraph(); - auto sub_graph_size = sub_graphs->size(); - for (size_t i = 0; i < sub_graph_size; i++) { - auto sub_graph = sub_graphs->template GetAs(i); - int ret = ConvertSubGraph(*sub_graph, model); - if (ret != RET_OK) { - MS_LOG(ERROR) << "converter subgraph wrong."; - return ret; - } - } - } - return RET_OK; -} - -int VersionVerify(flatbuffers::Verifier *verify); - -int NodeVerify(const Model &model); - -int SubGraphVerify(const Model &model); - -bool ModelVerify(const Model &model); - -const void *GetMetaGraphByVerison(const char *buf, const int &schema_version); - -int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version); - -Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf); -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_ diff --git a/mindspore/lite/src/ops/CMakeLists.txt b/mindspore/lite/src/ops/CMakeLists.txt index c90b0b22f9b..df4e5280713 100644 --- a/mindspore/lite/src/ops/CMakeLists.txt +++ b/mindspore/lite/src/ops/CMakeLists.txt @@ -4,6 +4,10 @@ file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/populate/*.cc ) +if (ENABLE_V0) + file(GLOB_RECURSE COMPAT_SRC ${CMAKE_CURRENT_SOURCE_DIR}/compat/*.cc) + set(OPS_SRC ${OPS_SRC} ${COMPAT_SRC}) +endif () add_library(cpu_ops_mid OBJECT ${OPS_SRC}) add_dependencies(cpu_ops_mid fbs_src) diff --git a/mindspore/lite/src/ops/compat/attr_transfer_common.cc b/mindspore/lite/src/ops/compat/attr_transfer_common.cc new file mode 100644 index 00000000000..633482ea24c --- /dev/null +++ b/mindspore/lite/src/ops/compat/attr_transfer_common.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/compat/attr_transfer_common.h" +#include +#include "src/common/log_adapter.h" + +namespace mindspore { +namespace lite { +schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id, + std::vector<char *> *tensor_bufs) { + if (data == nullptr || tensor_bufs == nullptr) { + MS_LOG(ERROR) << "the parameter of this function is nullptr."; + return nullptr; + } + auto dst_tensor = + (is_array ? new (std::nothrow) Tensor(type_id, {data_size}, schema::Format_NHWC, Tensor::Category::CONST_TENSOR) + : new (std::nothrow) Tensor(type_id, {}, schema::Format_NHWC, Tensor::Category::CONST_SCALAR)); + auto dst_data = dst_tensor->MutableData(); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Data from tensor is nullptr"; + return nullptr; + } + std::vector<uint8_t> uint8_data; + uint8_data.resize(dst_tensor->Size()); + memcpy(uint8_data.data(), data, dst_tensor->Size()); + auto shape = dst_tensor->shape(); + flatbuffers::FlatBufferBuilder fbb(1024); + auto tensor_offset = schema::CreateTensorDirect(fbb, schema::NodeType_ValueNode, type_id, &shape, schema::Format_NHWC, + 0, 0, &uint8_data); + fbb.Finish(tensor_offset); + delete dst_tensor; + auto buf = fbb.GetBufferPointer(); + if (buf == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer return nullptr"; + fbb.Clear(); + return nullptr; + } + auto tensor_buf = reinterpret_cast<char *>(malloc(fbb.GetSize())); + if (tensor_buf == nullptr) { + MS_LOG(ERROR) << "malloc primitive_buf_ failed"; + fbb.Clear(); + return nullptr; + } + memcpy(tensor_buf, buf, fbb.GetSize()); + auto tensor = flatbuffers::GetRoot<schema::Tensor>(tensor_buf); + tensor_bufs->push_back(tensor_buf); + fbb.Clear(); + return const_cast<schema::Tensor *>(tensor); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/compat/attr_transfer_common.h b/mindspore/lite/src/ops/compat/attr_transfer_common.h new file mode 100644 index 00000000000..6ecf2be2514 --- /dev/null +++ b/mindspore/lite/src/ops/compat/attr_transfer_common.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_ +#define LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_ + +#include +#include "ir/dtype/type_id.h" +#include "src/tensor.h" +#include "include/errorcode.h" +#include "schema/model_v0_generated.h" +#include "src/common/common.h" +#include "src/ops/compat/compat_register.h" + +namespace mindspore { +namespace lite { +schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id, + std::vector<char *> *tensor_bufs); +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_ diff --git a/mindspore/lite/src/ops/compat/compat_register.h b/mindspore/lite/src/ops/compat/compat_register.h new file mode 100644 index 00000000000..8285d1e7f2f --- /dev/null +++ b/mindspore/lite/src/ops/compat/compat_register.h @@ -0,0 +1,67 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_ +#define LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_ + +#include +#include +#include +#include "include/model.h" +#include "schema/model_generated.h" +#include "src/common/log_adapter.h" + +namespace mindspore { +namespace lite { +// compatibility, transfer attr to input tensor. 
+typedef int (*TransferAttrFunc)(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *tensor, + std::vector<char *> *tensor_bufs); +class CompatRegistry { + public: + static CompatRegistry *GetInstance() { + static CompatRegistry registry; + return &registry; + } + + void InsertTransferAttrFuncMap(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) { + int key = primitive_type * 10 + schema_version; + transfer_attr_funcs_[key] = transfer_attr_func; + } + + TransferAttrFunc GetTransferAttrFunc(int schema_version, int primitive_type) { + int key = primitive_type * 10 + schema_version; + if (transfer_attr_funcs_.find(key) != transfer_attr_funcs_.end()) { + return transfer_attr_funcs_[key]; + } else { + MS_LOG(DEBUG) << "Unsupported transformer type in Create : " << key; + return nullptr; + } + } + + protected: + std::unordered_map<int, TransferAttrFunc> transfer_attr_funcs_; +}; + +class Register { + public: + Register(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) { + CompatRegistry::GetInstance()->InsertTransferAttrFuncMap(schema_version, primitive_type, transfer_attr_func); + } + virtual ~Register() = default; +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_ diff --git a/mindspore/lite/src/ops/compat/v0/broadcat_to_compat_v0.cc b/mindspore/lite/src/ops/compat/v0/broadcat_to_compat_v0.cc new file mode 100644 index 00000000000..a217240d1d6 --- /dev/null +++ b/mindspore/lite/src/ops/compat/v0/broadcat_to_compat_v0.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/ops/compat/attr_transfer_common.h" + +namespace mindspore { +namespace lite { +int TransferBroadcastToAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors, + std::vector<char *> *tensor_bufs) { + if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) { + MS_LOG(ERROR) << "the parameter of this function is nullptr."; + return RET_ERROR; + } + if (node->input_indices_.size() != 1) { + MS_LOG(DEBUG) << "broadcast_to don't need to convert attr to tensor."; + return RET_OK; + } + dst_tensors->clear(); + tensor_bufs->clear(); + auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive); + auto dst_shape_attr = prim->value_as_BroadcastTo()->dst_shape(); + std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end()); + auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs); + if (dst_shape_tensor == nullptr) { + MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed."; + return RET_NULL_PTR; + } + dst_tensors->push_back(dst_shape_tensor); + return RET_OK; +} + +Register BroadcastToTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_BroadcastTo, + TransferBroadcastToAttr); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/compat/v0/reshape_compat_v0.cc b/mindspore/lite/src/ops/compat/v0/reshape_compat_v0.cc new file mode 100644 index 00000000000..153d9bed465 --- /dev/null +++ b/mindspore/lite/src/ops/compat/v0/reshape_compat_v0.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/ops/compat/attr_transfer_common.h" + +namespace mindspore { +namespace lite { +int TransferReshapeAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors, + std::vector<char *> *tensor_bufs) { + if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) { + MS_LOG(ERROR) << "the parameter of this function is nullptr."; + return RET_ERROR; + } + if (node->input_indices_.size() != 1) { + MS_LOG(DEBUG) << "reshape don't need to convert attr to tensor."; + return RET_OK; + } + dst_tensors->clear(); + tensor_bufs->clear(); + auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive); + auto dst_shape_attr = prim->value_as_Reshape()->shape(); + std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end()); + auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs); + if (dst_shape_tensor == nullptr) { + MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed."; + return RET_NULL_PTR; + } + dst_tensors->push_back(dst_shape_tensor); + return RET_OK; +} + +Register ReshapeTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_Reshape, TransferReshapeAttr); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/compat/v0/strided_slice_compat_v0.cc b/mindspore/lite/src/ops/compat/v0/strided_slice_compat_v0.cc new file mode 100644 index 00000000000..5c7d19593e6 --- /dev/null +++ b/mindspore/lite/src/ops/compat/v0/strided_slice_compat_v0.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/ops/compat/attr_transfer_common.h" + +namespace mindspore { +namespace lite { +int TransferStridedSliceAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors, + std::vector<char *> *tensor_bufs) { + if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) { + MS_LOG(ERROR) << "the parameter of this function is nullptr."; + return RET_ERROR; + } + dst_tensors->clear(); + tensor_bufs->clear(); + auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive); + int inputs_size = node->input_indices_.size(); + switch (inputs_size) { + case 1: { + auto begins_attr = prim->value_as_StridedSlice()->begin(); + std::vector<int> dst_begins = std::vector<int>(begins_attr->begin(), begins_attr->end()); + auto dst_begins_tensor = AttrToTensor(dst_begins.data(), dst_begins.size(), true, kNumberTypeInt32, tensor_bufs); + dst_tensors->push_back(dst_begins_tensor); + } // fall through: the end and stride attrs still need to be converted + case 2: { + auto ends_attr = prim->value_as_StridedSlice()->end(); + std::vector<int> dst_ends = std::vector<int>(ends_attr->begin(), ends_attr->end()); + auto dst_ends_tensor = AttrToTensor(dst_ends.data(), dst_ends.size(), true, kNumberTypeInt32, tensor_bufs); + dst_tensors->push_back(dst_ends_tensor); + } // fall through: the stride attr still needs to be converted + case 3: { + auto strides_attr = prim->value_as_StridedSlice()->stride(); + std::vector<int> dst_strides = std::vector<int>(strides_attr->begin(), strides_attr->end()); + auto dst_strides_tensor = + AttrToTensor(dst_strides.data(), dst_strides.size(), true, kNumberTypeInt32, tensor_bufs); + dst_tensors->push_back(dst_strides_tensor); + break; + } + default: { + MS_LOG(DEBUG) << "stride_slice don't need to convert attr to tensor."; + return RET_OK; + } + } + if (std::any_of(dst_tensors->begin(), dst_tensors->end(), [](schema::Tensor *tensor) { return tensor == nullptr; })) { + MS_LOG(ERROR) << "convert attr to tensor failed."; + return RET_ERROR; + } + return RET_OK; +} + +Register StridedSliceTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_StridedSlice, + TransferStridedSliceAttr); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc index 64f1739ffbf..c76ba01fc43 100644 --- a/mindspore/lite/src/train/train_model.cc +++ b/mindspore/lite/src/train/train_model.cc @@ -18,7 +18,6 @@ #include "src/common/log_adapter.h" #include "include/errorcode.h" #include "src/common/graph_util.h" -#include "src/model_common.h" namespace mindspore::lite { @@ -27,12 +26,6 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { MS_LOG(ERROR) << "The model buf is nullptr"; return nullptr; } - flatbuffers::Verifier verify((const uint8_t *)model_buf, size); - int schema_version = VersionVerify(&verify); - if (schema_version == -1) { - MS_LOG(ERROR) << "The model buffer is invalid, cannot get schema version"; - return nullptr; - } TrainModel *model = new (std::nothrow) TrainModel(); if (model == nullptr) { MS_LOG(ERROR) << "new model fail!"; @@ -46,19 +39,10 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { } memcpy(model->buf, model_buf, size); model->buf_size_ = size; - const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version); - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "meta_graph is nullptr!"; - free(model->buf); - delete (model); - return nullptr; - } - - int status = GenerateModelByVersion(meta_graph, model, schema_version); + auto status = model->ConstructModel(); if (status != RET_OK) { - free(model->buf); - delete (model); - MS_LOG(ERROR) << "fail to generate 
model"; + MS_LOG(ERROR) << "construct model failed."; + delete model; return nullptr; } return model; @@ -91,6 +75,4 @@ char *TrainModel::ExportBuf(char *buffer, size_t *len) const { *len = buf_size_; return buffer; } - -TrainModel::~TrainModel() { Model::Free(); } } // namespace mindspore::lite diff --git a/mindspore/lite/src/train/train_model.h b/mindspore/lite/src/train/train_model.h index 486e3e03d06..14ddf479225 100644 --- a/mindspore/lite/src/train/train_model.h +++ b/mindspore/lite/src/train/train_model.h @@ -16,13 +16,13 @@ #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ #include -#include "include/model.h" +#include "src/lite_model.h" namespace mindspore { namespace lite { /// \brief TrainModel Defines a class that allows to import and export a mindsport trainable model -struct TrainModel : public lite::Model { +struct TrainModel : public lite::LiteModel { /// \brief Static method to create a TrainModel object /// /// \param[in] model_buf A buffer that was read from a MS model file @@ -35,7 +35,7 @@ struct TrainModel : public lite::Model { void Free() override; /// \brief Class destructor, free all memory - virtual ~TrainModel(); + virtual ~TrainModel() = default; /// \brief Export Model into a buffer /// @@ -44,8 +44,6 @@ struct TrainModel : public lite::Model { /// /// \return Pointer to buffer with exported model char *ExportBuf(char *buf, size_t *len) const; - - size_t buf_size_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 6696a43c2da..4dd031127ad 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -105,7 +105,8 @@ if (PLATFORM_ARM32 OR PLATFORM_ARM64) endif() endif() ### runtime framework -file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc ${LITE_DIR}/src/ops/populate/*.cc) +add_definitions(-DENABLE_V0) +file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc) set(TEST_LITE_SRC ${TEST_LITE_SRC} ${CCSRC_SRC} @@ -123,8 +124,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/lite_kernel.cc ${LITE_DIR}/src/lite_session.cc ${LITE_DIR}/src/sub_graph_kernel.cc - ${LITE_DIR}/src/model.cc - ${LITE_DIR}/src/model_common.cc + ${LITE_DIR}/src/lite_model.cc ${LITE_DIR}/src/scheduler.cc ${LITE_DIR}/src/common/graph_util.cc ${LITE_DIR}/src/common/file_utils.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 42e400b2a25..8756a9eabed 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -9,7 +9,7 @@ set(CCSRC_SRC include(${TOP_DIR}/cmake/external_libs/glog.cmake) -file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc +file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc) file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} @@ -88,8 +88,7 @@ set(LITE_SRC ${SRC_DIR}/sub_graph_kernel.cc ${SRC_DIR}/lite_session.cc ${SRC_DIR}/executor.cc - ${SRC_DIR}/model.cc - ${SRC_DIR}/model_common.cc + ${SRC_DIR}/lite_model.cc ${SRC_DIR}/errorcode.cc ) if (SUPPORT_TRAIN) diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 5b7f6341baf..aa53b6c743a 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -1581,6 +1581,7 @@ 
STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { flatbuffers::FlatBufferBuilder builder(1024); auto offset = schema::MetaGraph::Pack(builder, meta_graph); builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); size_t size = builder.GetSize(); auto *content = reinterpret_cast(builder.GetBufferPointer()); if (content == nullptr) { @@ -1662,6 +1663,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { flatbuffers::FlatBufferBuilder int8_builder(1024); auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph); int8_builder.Finish(int8_offset); + schema::FinishMetaGraphBuffer(int8_builder, int8_offset); size = int8_builder.GetSize(); auto *int8_content = reinterpret_cast(int8_builder.GetBufferPointer()); if (int8_content == nullptr) { diff --git a/mindspore/lite/tools/lib_cropper/lib_cropper.h b/mindspore/lite/tools/lib_cropper/lib_cropper.h index d951910ade2..9d680b56435 100644 --- a/mindspore/lite/tools/lib_cropper/lib_cropper.h +++ b/mindspore/lite/tools/lib_cropper/lib_cropper.h @@ -24,6 +24,7 @@ #include "tools/common/flag_parser.h" #include "src/common/file_utils.h" #include "src/common/utils.h" +#include "schema/model_generated.h" #include "include/lite_session.h" #include "tools/lib_cropper/cropper_flags.h"
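The three v0 compatibility sources added above all follow one pattern: a file-local TransferAttrFunc rebuilds the old primitive attribute as a constant schema::Tensor via AttrToTensor, and a static Register object keys that function on (schema version, primitive type) in CompatRegistry, which LiteModel::ConvertAttrs later looks up. As a rough illustration only (not part of this patch), extending the mechanism to one more v0 operator could look like the sketch below; the operator name Tile, its multiples attribute, and the value_as_Tile() accessor are assumptions for the example and would have to match the real v0 schema.

#include "src/ops/compat/attr_transfer_common.h"

namespace mindspore {
namespace lite {
// Hypothetical example: rebuild the v0 Tile "multiples" attribute as a constant input tensor.
int TransferTileAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
                     std::vector<char *> *tensor_bufs) {
  if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
    MS_LOG(ERROR) << "the parameter of this function is nullptr.";
    return RET_ERROR;
  }
  if (node->input_indices_.size() != 1) {
    // The multiples already arrive as a tensor input, so there is nothing to convert.
    return RET_OK;
  }
  dst_tensors->clear();
  tensor_bufs->clear();
  auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
  auto multiples_attr = prim->value_as_Tile()->multiples();  // assumed v0 accessor, not shown in this patch
  std::vector<int> multiples(multiples_attr->begin(), multiples_attr->end());
  auto multiples_tensor = AttrToTensor(multiples.data(), multiples.size(), true, kNumberTypeInt32, tensor_bufs);
  if (multiples_tensor == nullptr) {
    MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed.";
    return RET_NULL_PTR;
  }
  dst_tensors->push_back(multiples_tensor);
  return RET_OK;
}

// Static registration hooks the function into CompatRegistry at load time, mirroring the registries above.
Register TileTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_Tile, TransferTileAttr);
}  // namespace lite
}  // namespace mindspore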