forked from mindspore-Ecosystem/mindspore
!8932 [lite] convert attr to tensor
From: @xu_anyue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
3e9d95dca1
6
build.sh
6
build.sh
|
@ -536,7 +536,7 @@ build_lite()
|
||||||
-DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
|
-DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
|
||||||
-DPLATFORM_ARM64=on -DENABLE_NEON=on -DENABLE_FP16="off" \
|
-DPLATFORM_ARM64=on -DENABLE_NEON=on -DENABLE_FP16="off" \
|
||||||
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
|
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
|
||||||
-DSUPPORT_GPU=${LITE_ENABLE_GPU} -DSUPPORT_NPU=${LITE_ENABLE_NPU} \
|
-DSUPPORT_GPU=${LITE_ENABLE_GPU} -DSUPPORT_NPU=${LITE_ENABLE_NPU} -DENABLE_V0=on \
|
||||||
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
|
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
|
||||||
-DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
|
-DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
|
||||||
-DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
|
-DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
|
||||||
|
@ -548,7 +548,7 @@ build_lite()
|
||||||
-DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
|
-DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
|
||||||
-DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
|
-DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
|
||||||
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
|
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
|
||||||
-DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} \
|
-DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} -DENABLE_V0=on \
|
||||||
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
|
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
|
||||||
-DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
|
-DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
|
||||||
-DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
|
-DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
|
||||||
|
@ -557,7 +557,7 @@ build_lite()
|
||||||
cmake -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
|
cmake -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
|
||||||
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
|
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
|
||||||
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} \
|
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} \
|
||||||
-DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
|
-DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} -DENABLE_V0=on \
|
||||||
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
|
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
|
||||||
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
|
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
|
||||||
-DENABLE_VERBOSE=${ENABLE_VERBOSE} -DX86_64_SIMD=${X86_64_SIMD} "${BASEPATH}/mindspore/lite"
|
-DENABLE_VERBOSE=${ENABLE_VERBOSE} -DX86_64_SIMD=${X86_64_SIMD} "${BASEPATH}/mindspore/lite"
|
||||||
|
|
|
@ -19,9 +19,12 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "schema/model_generated.h"
|
|
||||||
#include "include/ms_tensor.h"
|
#include "include/ms_tensor.h"
|
||||||
|
|
||||||
|
namespace mindspore::schema {
|
||||||
|
struct Tensor;
|
||||||
|
} // namespace mindspore::schema
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
/// \brief Allocator defined a memory pool for malloc memory and free memory dynamically.
|
/// \brief Allocator defined a memory pool for malloc memory and free memory dynamically.
|
||||||
///
|
///
|
||||||
|
@ -35,7 +38,7 @@ using TensorPtrVector = std::vector<mindspore::schema::Tensor *>;
|
||||||
using DeviceContextVector = std::vector<DeviceContext>;
|
using DeviceContextVector = std::vector<DeviceContext>;
|
||||||
using Uint32Vector = std::vector<uint32_t>;
|
using Uint32Vector = std::vector<uint32_t>;
|
||||||
using String = std::string;
|
using String = std::string;
|
||||||
using NodeType = schema::NodeType;
|
using NodeType = int; /**< 0 : NodeType_ValueNode, 1 : NodeType_Parameter, 2 : NodeType_CNode. */
|
||||||
using AllocatorPtr = std::shared_ptr<Allocator>;
|
using AllocatorPtr = std::shared_ptr<Allocator>;
|
||||||
|
|
||||||
/// \brief Set data of MSTensor from string vector.
|
/// \brief Set data of MSTensor from string vector.
|
||||||
|
|
|
@ -53,13 +53,10 @@ struct MS_API Model {
|
||||||
static Model *Import(const char *model_buf, size_t size);
|
static Model *Import(const char *model_buf, size_t size);
|
||||||
|
|
||||||
/// \brief Free meta graph temporary buffer
|
/// \brief Free meta graph temporary buffer
|
||||||
virtual void Free();
|
virtual void Free() = 0;
|
||||||
|
|
||||||
/// \brief Free all temporay buffer.EG: nodes in the model.
|
|
||||||
void Destroy();
|
|
||||||
|
|
||||||
/// \brief Model destruct, free all memory
|
/// \brief Model destruct, free all memory
|
||||||
virtual ~Model();
|
virtual ~Model() = default;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
add_compile_definitions(USE_ANDROID_LOG)
|
add_compile_definitions(USE_ANDROID_LOG)
|
||||||
|
if (ENABLE_V0)
|
||||||
|
add_definitions(-DENABLE_V0)
|
||||||
|
endif()
|
||||||
set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||||
include_directories(${LITE_DIR}/nnacl/)
|
include_directories(${LITE_DIR}/nnacl/)
|
||||||
include_directories(${LITE_DIR}/nnacl/optimize)
|
include_directories(${LITE_DIR}/nnacl/optimize)
|
||||||
|
@ -29,13 +32,12 @@ set(LITE_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/tensorlist.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/tensorlist.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/model_common.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "src/lite_model.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class VersionManager {
|
||||||
|
public:
|
||||||
|
static VersionManager *GetInstance() {
|
||||||
|
static VersionManager instance;
|
||||||
|
return &instance;
|
||||||
|
}
|
||||||
|
virtual ~VersionManager() = default;
|
||||||
|
|
||||||
|
void SetSchemaVersion(const int schema_version) { schema_version_ = schema_version; }
|
||||||
|
int GetSchemaVersion() const { return schema_version_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
VersionManager() = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
|
|
@ -13,15 +13,115 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "src/model_common.h"
|
|
||||||
|
#include "src/lite_model.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <unordered_map>
|
||||||
#include "src/ops/while.h"
|
#include "src/ops/while.h"
|
||||||
|
#ifdef ENABLE_V0
|
||||||
|
#include "src/ops/compat/compat_register.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
|
#ifdef ENABLE_V0
|
||||||
if (model == nullptr) {
|
int LiteModel::ConvertAttrs(Model::Node *node, const schema::v0::Primitive *prim,
|
||||||
MS_LOG(ERROR) << "model is null.";
|
std::vector<schema::Tensor *> *dst_tensor) {
|
||||||
|
if (node == nullptr || dst_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "node or tensor_vec is nullptr.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
int primitive_type = prim->value_type();
|
||||||
|
auto creator = CompatRegistry::GetInstance()->GetTransferAttrFunc(SCHEMA_VERSION::SCHEMA_V0, primitive_type);
|
||||||
|
if (creator == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "the node don't need to convert attr to tensor.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
int status = creator(reinterpret_cast<const void *>(prim), node, dst_tensor, &this->attr_tensor_bufs_);
|
||||||
|
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||||
|
MS_LOG(ERROR) << "translate attr to tensor failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int LiteModel::ConvertAttrToTensors(const void *meta_graph) {
|
||||||
|
MS_ASSERT(meta_graph != nullptr);
|
||||||
|
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||||
|
if (schema_version != SCHEMA_VERSION::SCHEMA_V0) {
|
||||||
|
MS_LOG(DEBUG) << "no need to convert attr to tensor.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
auto meta_graph_v0 = reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph);
|
||||||
|
std::unordered_map<int, std::set<int>> subgraph_node_indexes;
|
||||||
|
for (size_t subgraph_index = 0; subgraph_index < this->sub_graphs_.size(); ++subgraph_index) {
|
||||||
|
for (size_t node_index = 0; node_index < this->sub_graphs_[subgraph_index]->node_indices_.size(); ++node_index) {
|
||||||
|
subgraph_node_indexes[subgraph_index].insert(this->sub_graphs_[subgraph_index]->node_indices_[node_index]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int cur_all_tensors_size = this->all_tensors_.size();
|
||||||
|
for (size_t index = 0; index < this->all_nodes_.size(); ++index) {
|
||||||
|
std::vector<schema::Tensor *> dst_tensors;
|
||||||
|
auto prim = meta_graph_v0->nodes()->GetAs<schema::v0::CNode>(index)->primitive();
|
||||||
|
int status = ConvertAttrs(this->all_nodes_[index], prim, &dst_tensors);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "fail to convert attr to tensor.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (dst_tensors.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::vector<int> subgraphs_with_node;
|
||||||
|
for (size_t subgraph_index = 0; subgraph_index < this->sub_graphs_.size(); ++subgraph_index) {
|
||||||
|
if (subgraph_node_indexes[subgraph_index].find(index) == subgraph_node_indexes[subgraph_index].end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
subgraphs_with_node.push_back(subgraph_index);
|
||||||
|
}
|
||||||
|
for (auto tensor : dst_tensors) {
|
||||||
|
for (auto subgraph_index : subgraphs_with_node) {
|
||||||
|
this->sub_graphs_[subgraph_index]->tensor_indices_.push_back(cur_all_tensors_size);
|
||||||
|
}
|
||||||
|
this->all_nodes_[index]->input_indices_.push_back(cur_all_tensors_size++);
|
||||||
|
this->all_tensors_.push_back(tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void LiteModel::Free() {
|
||||||
|
if (this->buf != nullptr) {
|
||||||
|
free(this->buf);
|
||||||
|
this->buf = nullptr;
|
||||||
|
}
|
||||||
|
for (auto &tensor_buf : attr_tensor_bufs_) {
|
||||||
|
free(tensor_buf);
|
||||||
|
}
|
||||||
|
attr_tensor_bufs_.resize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
LiteModel::~LiteModel() {
|
||||||
|
Free();
|
||||||
|
auto nodes_size = this->all_nodes_.size();
|
||||||
|
for (size_t i = 0; i < nodes_size; ++i) {
|
||||||
|
auto node = this->all_nodes_[i];
|
||||||
|
MS_ASSERT(node != nullptr);
|
||||||
|
MS_ASSERT(node->primitive_ != nullptr);
|
||||||
|
delete node->primitive_;
|
||||||
|
node->primitive_ = nullptr;
|
||||||
|
delete node;
|
||||||
|
}
|
||||||
|
this->all_nodes_.clear();
|
||||||
|
|
||||||
|
auto sub_graph_size = this->sub_graphs_.size();
|
||||||
|
for (size_t i = 0; i < sub_graph_size; ++i) {
|
||||||
|
auto sub_graph = this->sub_graphs_[i];
|
||||||
|
delete sub_graph;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
|
||||||
if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
|
if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
|
||||||
sub_graph.nodeIndices() == nullptr || sub_graph.tensorIndices() == nullptr) {
|
sub_graph.nodeIndices() == nullptr || sub_graph.tensorIndices() == nullptr) {
|
||||||
MS_LOG(ERROR) << "sub_graph is invalid.";
|
MS_LOG(ERROR) << "sub_graph is invalid.";
|
||||||
|
@ -51,28 +151,31 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
|
||||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||||
subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
|
subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
|
||||||
}
|
}
|
||||||
model->sub_graphs_.push_back(subgraph);
|
this->sub_graphs_.push_back(subgraph);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int VersionVerify(flatbuffers::Verifier *verify) {
|
int LiteModel::VersionVerify(flatbuffers::Verifier *verify) const {
|
||||||
if (verify == nullptr) {
|
if (verify == nullptr) {
|
||||||
MS_LOG(ERROR) << "verify is null.";
|
MS_LOG(ERROR) << "verify is null.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (schema::VerifyMetaGraphBuffer(*verify)) {
|
if (schema::VerifyMetaGraphBuffer(*verify)) {
|
||||||
return SCHEMA_VERSION::SCHEMA_CUR;
|
return SCHEMA_VERSION::SCHEMA_CUR;
|
||||||
} else if (schema::v0::VerifyMetaGraphBuffer(*verify)) {
|
}
|
||||||
|
#ifdef ENABLE_V0
|
||||||
|
if (schema::v0::VerifyMetaGraphBuffer(*verify)) {
|
||||||
return SCHEMA_VERSION::SCHEMA_V0;
|
return SCHEMA_VERSION::SCHEMA_V0;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
return SCHEMA_VERSION::SCHEMA_INVALID;
|
return SCHEMA_VERSION::SCHEMA_INVALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
int NodeVerify(const Model &model) {
|
int LiteModel::NodeVerify() const {
|
||||||
auto tensor_size = model.all_tensors_.size();
|
auto tensor_size = this->all_tensors_.size();
|
||||||
uint32_t subGraph_size = model.sub_graphs_.size();
|
uint32_t subGraph_size = this->sub_graphs_.size();
|
||||||
|
|
||||||
for (auto &node : model.all_nodes_) {
|
for (auto &node : this->all_nodes_) {
|
||||||
if (node == nullptr || node->primitive_ == nullptr) {
|
if (node == nullptr || node->primitive_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "node or its primitive_ is null.";
|
MS_LOG(ERROR) << "node or its primitive_ is null.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -105,11 +208,11 @@ int NodeVerify(const Model &model) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int SubGraphVerify(const Model &model) {
|
int LiteModel::SubGraphVerify() const {
|
||||||
auto tensor_size = model.all_tensors_.size();
|
auto tensor_size = this->all_tensors_.size();
|
||||||
auto node_size = model.all_nodes_.size();
|
auto node_size = this->all_nodes_.size();
|
||||||
|
|
||||||
for (auto &graph : model.sub_graphs_) {
|
for (auto &graph : this->sub_graphs_) {
|
||||||
if (graph == nullptr) {
|
if (graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "graph is null.";
|
MS_LOG(ERROR) << "graph is null.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -138,49 +241,78 @@ int SubGraphVerify(const Model &model) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; }
|
bool LiteModel::ModelVerify() const { return NodeVerify() == RET_OK && SubGraphVerify() == RET_OK; }
|
||||||
|
|
||||||
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
|
const void *LiteModel::GetMetaGraphByVerison() {
|
||||||
if (buf == nullptr) {
|
MS_ASSERT(this->buf != nullptr);
|
||||||
MS_LOG(ERROR) << "buf is null.";
|
auto schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
||||||
return reinterpret_cast<const void *>(schema::GetMetaGraph(buf));
|
return reinterpret_cast<const void *>(schema::GetMetaGraph(this->buf));
|
||||||
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
|
}
|
||||||
|
#ifdef ENABLE_V0
|
||||||
|
if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
|
||||||
return reinterpret_cast<const void *>(schema::v0::GetMetaGraph(buf));
|
return reinterpret_cast<const void *>(schema::v0::GetMetaGraph(buf));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) {
|
int LiteModel::GenerateModelByVersion(const void *meta_graph) {
|
||||||
if (meta_graph == nullptr || model == nullptr) {
|
MS_ASSERT(meta_graph != nullptr);
|
||||||
MS_LOG(ERROR) << "meta_graph or model is null.";
|
auto schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
int status = RET_ERROR;
|
int status = RET_ERROR;
|
||||||
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
||||||
status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph),
|
status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph));
|
||||||
model, schema_version);
|
|
||||||
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
|
|
||||||
status = GenerateModel<schema::v0::MetaGraph, schema::v0::CNode>(
|
|
||||||
*reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph), model, schema_version);
|
|
||||||
}
|
}
|
||||||
|
#ifdef ENABLE_V0
|
||||||
|
if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
|
||||||
|
status = GenerateModel<schema::v0::MetaGraph, schema::v0::CNode>(
|
||||||
|
*reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int LiteModel::ConstructModel() {
|
||||||
|
if (this->buf == nullptr || this->buf_size_ <= 0) {
|
||||||
|
MS_LOG(ERROR) << "cannot construct model.";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
flatbuffers::Verifier verify((const uint8_t *)this->buf, this->buf_size_);
|
||||||
|
int schema_version = VersionVerify(&verify);
|
||||||
|
if (schema_version == SCHEMA_INVALID) {
|
||||||
|
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
VersionManager::GetInstance()->SetSchemaVersion(schema_version);
|
||||||
|
const void *meta_graph = GetMetaGraphByVerison();
|
||||||
|
if (meta_graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "meta_graph is nullptr!";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
int status = GenerateModelByVersion(meta_graph);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "fail to generate model";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this->version_ != Version()) {
|
||||||
|
MS_LOG(WARNING) << "model version is " << this->version_ << ", inference version is " << Version() << " not equal";
|
||||||
|
}
|
||||||
|
if (this->sub_graphs_.empty()) {
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ModelVerify() ? RET_OK : RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
||||||
if (model_buf == nullptr) {
|
if (model_buf == nullptr) {
|
||||||
MS_LOG(ERROR) << "The model buf is nullptr";
|
MS_LOG(ERROR) << "The model buf is nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
auto *model = new (std::nothrow) LiteModel();
|
||||||
int schema_version = VersionVerify(&verify);
|
|
||||||
if (schema_version == SCHEMA_INVALID) {
|
|
||||||
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto *model = new (std::nothrow) Model();
|
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "new model fail!";
|
MS_LOG(ERROR) << "new model fail!";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -201,28 +333,15 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
||||||
}
|
}
|
||||||
memcpy(model->buf, model_buf, size);
|
memcpy(model->buf, model_buf, size);
|
||||||
}
|
}
|
||||||
const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version);
|
model->buf_size_ = size;
|
||||||
if (meta_graph == nullptr) {
|
auto status = model->ConstructModel();
|
||||||
MS_LOG(ERROR) << "meta_graph is nullptr!";
|
|
||||||
delete (model);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
int status = GenerateModelByVersion(meta_graph, model, schema_version);
|
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
delete (model);
|
MS_LOG(ERROR) << "construct model failed.";
|
||||||
MS_LOG(ERROR) << "fail to generate model";
|
delete model;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
return model;
|
||||||
if (model->version_ != Version()) {
|
|
||||||
MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
|
|
||||||
}
|
|
||||||
if (model->sub_graphs_.empty()) {
|
|
||||||
delete (model);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return ModelVerify(*model) ? model : nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
|
@ -0,0 +1,223 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_LITE_MODEL_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_LITE_MODEL_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "include/model.h"
|
||||||
|
#include "src/ops/primitive_c.h"
|
||||||
|
#include "include/version.h"
|
||||||
|
#include "schema/model_generated.h"
|
||||||
|
#include "src/common/common.h"
|
||||||
|
#include "src/common/version_manager.h"
|
||||||
|
#ifndef PRIMITIVE_WRITEABLE
|
||||||
|
#include "src/ops/ops_register.h"
|
||||||
|
#endif
|
||||||
|
#ifdef ENABLE_V0
|
||||||
|
#include "schema/model_v0_generated.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class LiteModel : public Model {
|
||||||
|
public:
|
||||||
|
int ConstructModel();
|
||||||
|
|
||||||
|
bool ModelVerify() const;
|
||||||
|
|
||||||
|
void Free() override;
|
||||||
|
|
||||||
|
~LiteModel() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
#ifdef ENABLE_V0
|
||||||
|
int ConvertAttrs(Model::Node *node, const schema::v0::Primitive *prim, std::vector<schema::Tensor *> *dst_tensor);
|
||||||
|
|
||||||
|
int ConvertAttrToTensors(const void *meta_graph);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||||
|
bool ConvertNodes(const T &meta_graph) {
|
||||||
|
if (meta_graph.nodes() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
|
||||||
|
auto *node = new (std::nothrow) Model::Node();
|
||||||
|
if (node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new node fail!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
|
||||||
|
auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
|
||||||
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
|
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
||||||
|
#else
|
||||||
|
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
||||||
|
auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
|
||||||
|
if (func_pointer == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
|
||||||
|
<< schema::EnumNamePrimitiveType(primitive->value_type());
|
||||||
|
delete node;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
node->primitive_ = func_pointer(primitive);
|
||||||
|
#endif
|
||||||
|
if (node->primitive_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
||||||
|
delete node;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
node->primitive_->set_quant_type(static_cast<schema::QuantType>(c_node->quantType()));
|
||||||
|
node->name_ = c_node->name()->c_str();
|
||||||
|
node->node_type_ = static_cast<NodeType>(c_node->nodeType());
|
||||||
|
auto count = c_node->inputIndex()->size();
|
||||||
|
for (uint32_t j = 0; j < count; ++j) {
|
||||||
|
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
|
||||||
|
}
|
||||||
|
if (c_node->outputIndex() != nullptr) {
|
||||||
|
count = c_node->outputIndex()->size();
|
||||||
|
for (uint32_t j = 0; j < count; ++j) {
|
||||||
|
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this->all_nodes_.push_back(node);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T = schema::MetaGraph>
|
||||||
|
bool ConvertTensors(const T &meta_graph) {
|
||||||
|
if (meta_graph.allTensors() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto tensor_count = meta_graph.allTensors()->size();
|
||||||
|
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||||
|
auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
this->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T = schema::MetaGraph>
|
||||||
|
int MetaGraphMappingSubGraph(const T &meta_graph) {
|
||||||
|
if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr ||
|
||||||
|
meta_graph.allTensors() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto *subgraph = new (std::nothrow) Model::SubGraph();
|
||||||
|
if (subgraph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new subGraph fail!";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (meta_graph.name() != nullptr) {
|
||||||
|
subgraph->name_ = meta_graph.name()->c_str();
|
||||||
|
}
|
||||||
|
auto in_count = meta_graph.inputIndex()->size();
|
||||||
|
for (uint32_t i = 0; i < in_count; ++i) {
|
||||||
|
subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
|
||||||
|
}
|
||||||
|
auto out_count = meta_graph.outputIndex()->size();
|
||||||
|
for (uint32_t i = 0; i < out_count; ++i) {
|
||||||
|
subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
|
||||||
|
}
|
||||||
|
auto node_count = meta_graph.nodes()->size();
|
||||||
|
for (uint32_t i = 0; i < node_count; ++i) {
|
||||||
|
subgraph->node_indices_.push_back(i);
|
||||||
|
}
|
||||||
|
auto tensor_count = meta_graph.allTensors()->size();
|
||||||
|
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||||
|
subgraph->tensor_indices_.push_back(i);
|
||||||
|
}
|
||||||
|
this->sub_graphs_.push_back(subgraph);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||||
|
int GenerateModel(const T &meta_graph) {
|
||||||
|
if (meta_graph.name() != nullptr) {
|
||||||
|
this->name_ = meta_graph.name()->c_str();
|
||||||
|
}
|
||||||
|
if (meta_graph.version() != nullptr) {
|
||||||
|
this->version_ = meta_graph.version()->c_str();
|
||||||
|
}
|
||||||
|
if (!ConvertNodes<T, U>(meta_graph)) {
|
||||||
|
MS_LOG(ERROR) << "convert node failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (!ConvertTensors<T>(meta_graph)) {
|
||||||
|
MS_LOG(ERROR) << "convert tensor failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (meta_graph.subGraph() == nullptr) {
|
||||||
|
int ret = MetaGraphMappingSubGraph<T>(meta_graph);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "converter old version model wrong.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto sub_graphs = meta_graph.subGraph();
|
||||||
|
auto sub_graph_size = sub_graphs->size();
|
||||||
|
for (size_t i = 0; i < sub_graph_size; i++) {
|
||||||
|
auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
|
||||||
|
int ret = ConvertSubGraph(*sub_graph);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "converter subgraph wrong.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#ifdef ENABLE_V0
|
||||||
|
if (ConvertAttrToTensors(&meta_graph) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "fail to convert attr to tensor.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int VersionVerify(flatbuffers::Verifier *verify) const;
|
||||||
|
|
||||||
|
const void *GetMetaGraphByVerison();
|
||||||
|
|
||||||
|
int GenerateModelByVersion(const void *meta_graph);
|
||||||
|
|
||||||
|
int ConvertSubGraph(const schema::SubGraph &sub_graph);
|
||||||
|
|
||||||
|
int NodeVerify() const;
|
||||||
|
|
||||||
|
int SubGraphVerify() const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
size_t buf_size_ = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<char *> attr_tensor_bufs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_LITE_MODEL_H_
|
|
@ -26,7 +26,7 @@
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
#include "src/common/graph_util.h"
|
#include "src/common/graph_util.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/model_common.h"
|
#include "src/lite_model.h"
|
||||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||||
#if SUPPORT_NPU
|
#if SUPPORT_NPU
|
||||||
#include "src/runtime/agent/npu/npu_manager.h"
|
#include "src/runtime/agent/npu/npu_manager.h"
|
||||||
|
@ -363,7 +363,7 @@ int LiteSession::CompileGraph(Model *model) {
|
||||||
is_running_.store(false);
|
is_running_.store(false);
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (!ModelVerify(*model)) {
|
if (!reinterpret_cast<LiteModel *>(model)->ModelVerify()) {
|
||||||
MS_LOG(ERROR) << "wrong model input, please check";
|
MS_LOG(ERROR) << "wrong model input, please check";
|
||||||
is_running_.store(false);
|
is_running_.store(false);
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "src/ops/primitive_c.h"
|
|
||||||
#include "include/model.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "src/model_common.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
|
|
||||||
|
|
||||||
void Model::Free() {
|
|
||||||
if (this->buf != nullptr) {
|
|
||||||
free(this->buf);
|
|
||||||
this->buf = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Model::Destroy() {
|
|
||||||
Free();
|
|
||||||
auto nodes_size = this->all_nodes_.size();
|
|
||||||
for (size_t i = 0; i < nodes_size; ++i) {
|
|
||||||
auto node = this->all_nodes_[i];
|
|
||||||
MS_ASSERT(node != nullptr);
|
|
||||||
MS_ASSERT(node->primitive_ != nullptr);
|
|
||||||
delete node->primitive_;
|
|
||||||
node->primitive_ = nullptr;
|
|
||||||
delete node;
|
|
||||||
}
|
|
||||||
this->all_nodes_.clear();
|
|
||||||
|
|
||||||
auto sub_graph_size = this->sub_graphs_.size();
|
|
||||||
for (size_t i = 0; i < sub_graph_size; ++i) {
|
|
||||||
auto sub_graph = this->sub_graphs_[i];
|
|
||||||
delete sub_graph;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Model::~Model() { Destroy(); }
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,192 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
|
||||||
#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include "src/ops/primitive_c.h"
|
|
||||||
#include "include/model.h"
|
|
||||||
#include "include/version.h"
|
|
||||||
#include "schema/model_generated.h"
|
|
||||||
#include "schema/model_v0_generated.h"
|
|
||||||
#include "src/common/common.h"
|
|
||||||
#ifndef PRIMITIVE_WRITEABLE
|
|
||||||
#include "src/ops/ops_register.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model);
|
|
||||||
|
|
||||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
|
||||||
bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA_CUR) {
|
|
||||||
if (model == nullptr || meta_graph.nodes() == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
|
|
||||||
auto *node = new (std::nothrow) Model::Node();
|
|
||||||
if (node == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new node fail!";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
|
|
||||||
auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
|
|
||||||
#ifdef PRIMITIVE_WRITEABLE
|
|
||||||
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
|
||||||
#else
|
|
||||||
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
|
||||||
auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
|
|
||||||
if (func_pointer == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
|
|
||||||
<< schema::EnumNamePrimitiveType(primitive->value_type());
|
|
||||||
delete node;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
node->primitive_ = func_pointer(primitive);
|
|
||||||
#endif
|
|
||||||
if (node->primitive_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
|
||||||
delete node;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
node->primitive_->set_quant_type(static_cast<schema::QuantType>(c_node->quantType()));
|
|
||||||
node->name_ = c_node->name()->c_str();
|
|
||||||
node->node_type_ = static_cast<NodeType>(c_node->nodeType());
|
|
||||||
auto count = c_node->inputIndex()->size();
|
|
||||||
for (uint32_t j = 0; j < count; ++j) {
|
|
||||||
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
|
|
||||||
}
|
|
||||||
if (c_node->outputIndex() != nullptr) {
|
|
||||||
count = c_node->outputIndex()->size();
|
|
||||||
for (uint32_t j = 0; j < count; ++j) {
|
|
||||||
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
model->all_nodes_.push_back(node);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T = schema::MetaGraph>
|
|
||||||
bool ConvertTensors(const T &meta_graph, Model *model) {
|
|
||||||
if (model == nullptr || meta_graph.allTensors() == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto tensor_count = meta_graph.allTensors()->size();
|
|
||||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
|
||||||
auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
|
|
||||||
if (tensor == nullptr) {
|
|
||||||
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T = schema::MetaGraph>
|
|
||||||
int MetaGraphMappingSubGraph(const T &meta_graph, Model *model) {
|
|
||||||
if (model == nullptr || meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
|
|
||||||
meta_graph.nodes() == nullptr || meta_graph.allTensors() == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
auto *subgraph = new (std::nothrow) Model::SubGraph();
|
|
||||||
if (subgraph == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new subGraph fail!";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (meta_graph.name() != nullptr) {
|
|
||||||
subgraph->name_ = meta_graph.name()->c_str();
|
|
||||||
}
|
|
||||||
auto in_count = meta_graph.inputIndex()->size();
|
|
||||||
for (uint32_t i = 0; i < in_count; ++i) {
|
|
||||||
subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
|
|
||||||
}
|
|
||||||
auto out_count = meta_graph.outputIndex()->size();
|
|
||||||
for (uint32_t i = 0; i < out_count; ++i) {
|
|
||||||
subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
|
|
||||||
}
|
|
||||||
auto node_count = meta_graph.nodes()->size();
|
|
||||||
for (uint32_t i = 0; i < node_count; ++i) {
|
|
||||||
subgraph->node_indices_.push_back(i);
|
|
||||||
}
|
|
||||||
auto tensor_count = meta_graph.allTensors()->size();
|
|
||||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
|
||||||
subgraph->tensor_indices_.push_back(i);
|
|
||||||
}
|
|
||||||
model->sub_graphs_.push_back(subgraph);
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
|
||||||
int GenerateModel(const T &meta_graph, Model *model, int schema_version = 0) {
|
|
||||||
if (model == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "model is nullptr.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (meta_graph.name() != nullptr) {
|
|
||||||
model->name_ = meta_graph.name()->c_str();
|
|
||||||
}
|
|
||||||
if (meta_graph.version() != nullptr) {
|
|
||||||
model->version_ = meta_graph.version()->c_str();
|
|
||||||
}
|
|
||||||
if (!ConvertNodes<T, U>(meta_graph, model, schema_version)) {
|
|
||||||
MS_LOG(ERROR) << "convert node failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (!ConvertTensors<T>(meta_graph, model)) {
|
|
||||||
MS_LOG(ERROR) << "convert tensor failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (meta_graph.subGraph() == nullptr) {
|
|
||||||
int ret = MetaGraphMappingSubGraph<T>(meta_graph, model);
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "converter old version model wrong.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto sub_graphs = meta_graph.subGraph();
|
|
||||||
auto sub_graph_size = sub_graphs->size();
|
|
||||||
for (size_t i = 0; i < sub_graph_size; i++) {
|
|
||||||
auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
|
|
||||||
int ret = ConvertSubGraph(*sub_graph, model);
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "converter subgraph wrong.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
int VersionVerify(flatbuffers::Verifier *verify);
|
|
||||||
|
|
||||||
int NodeVerify(const Model &model);
|
|
||||||
|
|
||||||
int SubGraphVerify(const Model &model);
|
|
||||||
|
|
||||||
bool ModelVerify(const Model &model);
|
|
||||||
|
|
||||||
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version);
|
|
||||||
|
|
||||||
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version);
|
|
||||||
|
|
||||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
#endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
|
|
@ -4,6 +4,10 @@ file(GLOB OPS_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/populate/*.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/populate/*.cc
|
||||||
)
|
)
|
||||||
|
if (ENABLE_V0)
|
||||||
|
file(GLOB_RECURSE COMPAT_SRC ${CMAKE_CURRENT_SOURCE_DIR}/compat/*.cc)
|
||||||
|
set(OPS_SRC ${OPS_SRC} ${COMPAT_SRC})
|
||||||
|
endif ()
|
||||||
|
|
||||||
add_library(cpu_ops_mid OBJECT ${OPS_SRC})
|
add_library(cpu_ops_mid OBJECT ${OPS_SRC})
|
||||||
add_dependencies(cpu_ops_mid fbs_src)
|
add_dependencies(cpu_ops_mid fbs_src)
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/ops/compat/attr_transfer_common.h"
|
||||||
|
#include <vector>
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id,
|
||||||
|
std::vector<char *> *tensor_bufs) {
|
||||||
|
if (data == nullptr || tensor_bufs == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto dst_tensor =
|
||||||
|
(is_array ? new (std::nothrow) Tensor(type_id, {data_size}, schema::Format_NHWC, Tensor::Category::CONST_TENSOR)
|
||||||
|
: new (std::nothrow) Tensor(type_id, {}, schema::Format_NHWC, Tensor::Category::CONST_SCALAR));
|
||||||
|
auto dst_data = dst_tensor->MutableData();
|
||||||
|
if (dst_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Data from tensor is nullptr";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::vector<uint8_t> uint8_data;
|
||||||
|
uint8_data.resize(dst_tensor->Size());
|
||||||
|
memcpy(uint8_data.data(), data, dst_tensor->Size());
|
||||||
|
auto shape = dst_tensor->shape();
|
||||||
|
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||||
|
auto tensor_offset = schema::CreateTensorDirect(fbb, schema::NodeType_ValueNode, type_id, &shape, schema::Format_NHWC,
|
||||||
|
0, 0, &uint8_data);
|
||||||
|
fbb.Finish(tensor_offset);
|
||||||
|
delete dst_tensor;
|
||||||
|
auto buf = fbb.GetBufferPointer();
|
||||||
|
if (buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "GetBufferPointer return nullptr";
|
||||||
|
fbb.Clear();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto tensor_buf = reinterpret_cast<char *>(malloc(fbb.GetSize()));
|
||||||
|
if (tensor_buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc primitive_buf_ failed";
|
||||||
|
fbb.Clear();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
memcpy(tensor_buf, buf, fbb.GetSize());
|
||||||
|
auto tensor = flatbuffers::GetRoot<schema::Tensor>(tensor_buf);
|
||||||
|
tensor_bufs->push_back(tensor_buf);
|
||||||
|
fbb.Clear();
|
||||||
|
return const_cast<schema::Tensor *>(tensor);
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,35 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
|
||||||
|
#define LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "ir/dtype/type_id.h"
|
||||||
|
#include "src/tensor.h"
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
#include "schema/model_v0_generated.h"
|
||||||
|
#include "src/common/common.h"
|
||||||
|
#include "src/ops/compat/compat_register.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id,
|
||||||
|
std::vector<char *> *tensor_bufs);
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
|
|
@ -0,0 +1,67 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
|
||||||
|
#define LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "include/model.h"
|
||||||
|
#include "schema/model_generated.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
// compatibility, transfer attr to input tensor.
|
||||||
|
typedef int (*TransferAttrFunc)(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *tensor,
|
||||||
|
std::vector<char *> *tensor_bufs);
|
||||||
|
class CompatRegistry {
|
||||||
|
public:
|
||||||
|
static CompatRegistry *GetInstance() {
|
||||||
|
static CompatRegistry registry;
|
||||||
|
return ®istry;
|
||||||
|
}
|
||||||
|
|
||||||
|
void InsertTransferAttrFuncMap(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) {
|
||||||
|
int key = primitive_type * 10 + schema_version;
|
||||||
|
transfer_attr_funcs_[key] = transfer_attr_func;
|
||||||
|
}
|
||||||
|
|
||||||
|
TransferAttrFunc GetTransferAttrFunc(int schema_version, int primitive_type) {
|
||||||
|
int key = primitive_type * 10 + schema_version;
|
||||||
|
if (transfer_attr_funcs_.find(key) != transfer_attr_funcs_.end()) {
|
||||||
|
return transfer_attr_funcs_[key];
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "Unsupported transformer type in Create : " << key;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::unordered_map<int, TransferAttrFunc> transfer_attr_funcs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Register {
|
||||||
|
public:
|
||||||
|
Register(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) {
|
||||||
|
CompatRegistry::GetInstance()->InsertTransferAttrFuncMap(schema_version, primitive_type, transfer_attr_func);
|
||||||
|
}
|
||||||
|
virtual ~Register() = default;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
|
|
@ -0,0 +1,48 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/ops/compat/attr_transfer_common.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
int TransferBroadcastToAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
|
||||||
|
std::vector<char *> *tensor_bufs) {
|
||||||
|
if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (node->input_indices_.size() != 1) {
|
||||||
|
MS_LOG(DEBUG) << "broadcast_to don't need to convert attr to tensor.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
dst_tensors->clear();
|
||||||
|
tensor_bufs->clear();
|
||||||
|
auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
|
||||||
|
auto dst_shape_attr = prim->value_as_BroadcastTo()->dst_shape();
|
||||||
|
std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end());
|
||||||
|
auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||||
|
if (dst_shape_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed.";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
dst_tensors->push_back(dst_shape_tensor);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
Register BroadcastToTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_BroadcastTo,
|
||||||
|
TransferBroadcastToAttr);
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/ops/compat/attr_transfer_common.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
int TransferReshapeAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
|
||||||
|
std::vector<char *> *tensor_bufs) {
|
||||||
|
if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (node->input_indices_.size() != 1) {
|
||||||
|
MS_LOG(DEBUG) << "reshape need to convert attr to tensor.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
dst_tensors->clear();
|
||||||
|
tensor_bufs->clear();
|
||||||
|
auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
|
||||||
|
auto dst_shape_attr = prim->value_as_Reshape()->shape();
|
||||||
|
std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end());
|
||||||
|
auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||||
|
if (dst_shape_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed.";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
dst_tensors->push_back(dst_shape_tensor);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
Register ReshapeTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_Reshape, TransferReshapeAttr);
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,67 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/ops/compat/attr_transfer_common.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
int TransferStridedSliceAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
|
||||||
|
std::vector<char *> *tensor_bufs) {
|
||||||
|
if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
dst_tensors->clear();
|
||||||
|
tensor_bufs->clear();
|
||||||
|
auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
|
||||||
|
int inputs_size = node->input_indices_.size();
|
||||||
|
switch (inputs_size) {
|
||||||
|
case 1: {
|
||||||
|
auto begins_attr = prim->value_as_StridedSlice()->begin();
|
||||||
|
std::vector<int> dst_begins = std::vector<int>(begins_attr->begin(), begins_attr->end());
|
||||||
|
auto dst_begins_tensor = AttrToTensor(dst_begins.data(), dst_begins.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||||
|
dst_tensors->push_back(dst_begins_tensor);
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
auto ends_attr = prim->value_as_StridedSlice()->end();
|
||||||
|
std::vector<int> dst_ends = std::vector<int>(ends_attr->begin(), ends_attr->end());
|
||||||
|
auto dst_ends_tensor = AttrToTensor(dst_ends.data(), dst_ends.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||||
|
dst_tensors->push_back(dst_ends_tensor);
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
auto strides_attr = prim->value_as_StridedSlice()->stride();
|
||||||
|
std::vector<int> dst_strides = std::vector<int>(strides_attr->begin(), strides_attr->end());
|
||||||
|
auto dst_strides_tensor =
|
||||||
|
AttrToTensor(dst_strides.data(), dst_strides.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||||
|
dst_tensors->push_back(dst_strides_tensor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
MS_LOG(DEBUG) << "stride_slice don't need to convert attr to tensor.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (std::any_of(dst_tensors->begin(), dst_tensors->end(), [](schema::Tensor *tensor) { return tensor == nullptr; })) {
|
||||||
|
MS_LOG(ERROR) << "convert attr to tensor failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
Register StridedSliceTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_StridedSlice,
|
||||||
|
TransferStridedSliceAttr);
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -18,7 +18,6 @@
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "src/common/graph_util.h"
|
#include "src/common/graph_util.h"
|
||||||
#include "src/model_common.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
|
||||||
|
@ -27,12 +26,6 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
|
||||||
MS_LOG(ERROR) << "The model buf is nullptr";
|
MS_LOG(ERROR) << "The model buf is nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
|
||||||
int schema_version = VersionVerify(&verify);
|
|
||||||
if (schema_version == -1) {
|
|
||||||
MS_LOG(ERROR) << "The model buffer is invalid, cannot get schema version";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
TrainModel *model = new (std::nothrow) TrainModel();
|
TrainModel *model = new (std::nothrow) TrainModel();
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "new model fail!";
|
MS_LOG(ERROR) << "new model fail!";
|
||||||
|
@ -46,19 +39,10 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
|
||||||
}
|
}
|
||||||
memcpy(model->buf, model_buf, size);
|
memcpy(model->buf, model_buf, size);
|
||||||
model->buf_size_ = size;
|
model->buf_size_ = size;
|
||||||
const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version);
|
auto status = model->ConstructModel();
|
||||||
if (meta_graph == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "meta_graph is nullptr!";
|
|
||||||
free(model->buf);
|
|
||||||
delete (model);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
int status = GenerateModelByVersion(meta_graph, model, schema_version);
|
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
free(model->buf);
|
MS_LOG(ERROR) << "construct model failed.";
|
||||||
delete (model);
|
delete model;
|
||||||
MS_LOG(ERROR) << "fail to generate model";
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return model;
|
return model;
|
||||||
|
@ -91,6 +75,4 @@ char *TrainModel::ExportBuf(char *buffer, size_t *len) const {
|
||||||
*len = buf_size_;
|
*len = buf_size_;
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
TrainModel::~TrainModel() { Model::Free(); }
|
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -16,13 +16,13 @@
|
||||||
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_
|
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_
|
||||||
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_
|
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "include/model.h"
|
#include "src/lite_model.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
|
||||||
/// \brief TrainModel Defines a class that allows to import and export a mindsport trainable model
|
/// \brief TrainModel Defines a class that allows to import and export a mindsport trainable model
|
||||||
struct TrainModel : public lite::Model {
|
struct TrainModel : public lite::LiteModel {
|
||||||
/// \brief Static method to create a TrainModel object
|
/// \brief Static method to create a TrainModel object
|
||||||
///
|
///
|
||||||
/// \param[in] model_buf A buffer that was read from a MS model file
|
/// \param[in] model_buf A buffer that was read from a MS model file
|
||||||
|
@ -35,7 +35,7 @@ struct TrainModel : public lite::Model {
|
||||||
void Free() override;
|
void Free() override;
|
||||||
|
|
||||||
/// \brief Class destructor, free all memory
|
/// \brief Class destructor, free all memory
|
||||||
virtual ~TrainModel();
|
virtual ~TrainModel() = default;
|
||||||
|
|
||||||
/// \brief Export Model into a buffer
|
/// \brief Export Model into a buffer
|
||||||
///
|
///
|
||||||
|
@ -44,8 +44,6 @@ struct TrainModel : public lite::Model {
|
||||||
///
|
///
|
||||||
/// \return Pointer to buffer with exported model
|
/// \return Pointer to buffer with exported model
|
||||||
char *ExportBuf(char *buf, size_t *len) const;
|
char *ExportBuf(char *buf, size_t *len) const;
|
||||||
|
|
||||||
size_t buf_size_;
|
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -105,7 +105,8 @@ if (PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
### runtime framework
|
### runtime framework
|
||||||
file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc ${LITE_DIR}/src/ops/populate/*.cc)
|
add_definitions(-DENABLE_V0)
|
||||||
|
file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc)
|
||||||
set(TEST_LITE_SRC
|
set(TEST_LITE_SRC
|
||||||
${TEST_LITE_SRC}
|
${TEST_LITE_SRC}
|
||||||
${CCSRC_SRC}
|
${CCSRC_SRC}
|
||||||
|
@ -123,8 +124,7 @@ set(TEST_LITE_SRC
|
||||||
${LITE_DIR}/src/lite_kernel.cc
|
${LITE_DIR}/src/lite_kernel.cc
|
||||||
${LITE_DIR}/src/lite_session.cc
|
${LITE_DIR}/src/lite_session.cc
|
||||||
${LITE_DIR}/src/sub_graph_kernel.cc
|
${LITE_DIR}/src/sub_graph_kernel.cc
|
||||||
${LITE_DIR}/src/model.cc
|
${LITE_DIR}/src/lite_model.cc
|
||||||
${LITE_DIR}/src/model_common.cc
|
|
||||||
${LITE_DIR}/src/scheduler.cc
|
${LITE_DIR}/src/scheduler.cc
|
||||||
${LITE_DIR}/src/common/graph_util.cc
|
${LITE_DIR}/src/common/graph_util.cc
|
||||||
${LITE_DIR}/src/common/file_utils.cc
|
${LITE_DIR}/src/common/file_utils.cc
|
||||||
|
|
|
@ -9,7 +9,7 @@ set(CCSRC_SRC
|
||||||
|
|
||||||
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
|
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
|
||||||
|
|
||||||
file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc
|
file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc)
|
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc)
|
||||||
|
|
||||||
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
@ -88,8 +88,7 @@ set(LITE_SRC
|
||||||
${SRC_DIR}/sub_graph_kernel.cc
|
${SRC_DIR}/sub_graph_kernel.cc
|
||||||
${SRC_DIR}/lite_session.cc
|
${SRC_DIR}/lite_session.cc
|
||||||
${SRC_DIR}/executor.cc
|
${SRC_DIR}/executor.cc
|
||||||
${SRC_DIR}/model.cc
|
${SRC_DIR}/lite_model.cc
|
||||||
${SRC_DIR}/model_common.cc
|
|
||||||
${SRC_DIR}/errorcode.cc
|
${SRC_DIR}/errorcode.cc
|
||||||
)
|
)
|
||||||
if (SUPPORT_TRAIN)
|
if (SUPPORT_TRAIN)
|
||||||
|
|
|
@ -1581,6 +1581,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||||
flatbuffers::FlatBufferBuilder builder(1024);
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
|
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
|
||||||
builder.Finish(offset);
|
builder.Finish(offset);
|
||||||
|
schema::FinishMetaGraphBuffer(builder, offset);
|
||||||
size_t size = builder.GetSize();
|
size_t size = builder.GetSize();
|
||||||
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|
||||||
if (content == nullptr) {
|
if (content == nullptr) {
|
||||||
|
@ -1662,6 +1663,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||||
flatbuffers::FlatBufferBuilder int8_builder(1024);
|
flatbuffers::FlatBufferBuilder int8_builder(1024);
|
||||||
auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph);
|
auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph);
|
||||||
int8_builder.Finish(int8_offset);
|
int8_builder.Finish(int8_offset);
|
||||||
|
schema::FinishMetaGraphBuffer(int8_builder, int8_offset);
|
||||||
size = int8_builder.GetSize();
|
size = int8_builder.GetSize();
|
||||||
auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer());
|
auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer());
|
||||||
if (int8_content == nullptr) {
|
if (int8_content == nullptr) {
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "tools/common/flag_parser.h"
|
#include "tools/common/flag_parser.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
|
#include "schema/model_generated.h"
|
||||||
#include "include/lite_session.h"
|
#include "include/lite_session.h"
|
||||||
#include "tools/lib_cropper/cropper_flags.h"
|
#include "tools/lib_cropper/cropper_flags.h"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue