forked from mindspore-Ecosystem/mindspore

remove unused Prepare calls in operators & add CreateSession(const char *model_buf, size_t size, lite::Context *context) interface

parent e805051c1f
commit e19a3e3926
build.sh (27 lines changed)
@@ -266,12 +266,15 @@ checkopts()
             COMPILE_LITE="on"
             if [[ "$OPTARG" == "arm64" ]]; then
                 ENABLE_CONVERTER="off"
+                RUN_TESTCASES="on"
                 LITE_PLATFORM="arm64"
             elif [[ "$OPTARG" == "arm32" ]]; then
                 ENABLE_CONVERTER="off"
+                RUN_TESTCASES="on"
                 LITE_PLATFORM="arm32"
             elif [[ "$OPTARG" == "x86_64" ]]; then
                 ENABLE_CONVERTER="on"
+                RUN_TESTCASES="on"
                 LITE_PLATFORM="x86_64"
             else
                 echo "-I parameter must be arm64、arm32 or x86_64"
@@ -315,7 +318,7 @@ checkopts()
             elif [[ "$OPTARG" == "object-c" ]]; then
                 LITE_LANGUAGE="object-c"
             else
-                echo "-A parameter must be cpp、java or object-c"
+                echo "-A parameter must be cpp, java or object-c"
                 exit 1
             fi
             ;;
@@ -628,9 +631,9 @@ build_minddata_lite_deps()
 }
 
 get_version() {
-    VERSION_MAJOR=`grep "const int ms_version_major =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]"`
-    VERSION_MINOR=`grep "const int ms_version_minor =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]"`
-    VERSION_REVISION=`grep "const int ms_version_revision =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]"`
+    VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]")
+    VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]")
+    VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]")
     VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
 }
 
@@ -642,7 +645,9 @@ build_lite()
         echo "start build opencl"
         build_opencl
     fi
-    build_gtest
+    if [ "${RUN_TESTCASES}" == "on" ]; then
+        build_gtest
+    fi
 
     if [ "${COMPILE_MINDDATA_LITE}" == "lite" ] || [ "${COMPILE_MINDDATA_LITE}" == "full" ]; then
         build_minddata_lite_deps
@@ -665,7 +670,7 @@ build_lite()
               -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
              -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
               -DPLATFORM_ARM64=on -DENABLE_NEON=on -DENABLE_FP16="off" \
-              -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=on \
+              -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
               -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
               -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
               -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
@@ -676,14 +681,14 @@ build_lite()
               -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \
               -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
               -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
-              -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=on \
+              -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
              -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
               -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
               -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
               "${BASEPATH}/mindspore/lite"
     else
         cmake -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
-              -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=on \
+              -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
              -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
               -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
               -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
@@ -718,8 +723,8 @@ build_lite_java_arm64() {
     cd ${BASEPATH}/output/
     rm -rf mindspore-lite-${VERSION_STR}-runtime-arm64-cpu
     tar -zxvf mindspore-lite-${VERSION_STR}-runtime-arm64-cpu.tar.gz
-    [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/
     mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/
+    [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/*
     cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
     cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite-fp16.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
     cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite-optimize.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
@@ -738,10 +743,10 @@ build_lite_java_arm32() {
     fi
     # copy arm32 so
     cd ${BASEPATH}/output/
-    rm -rf mindspore-lite-${VERSION_STR}runtime-arm32-cpu
+    rm -rf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu
     tar -zxvf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu.tar.gz
-    [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/
     mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+    [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/*
    cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm32-cpu/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
     [ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu
 }

@@ -35,7 +35,16 @@ class MS_API LiteSession {
   /// \param[in] context Define the context of session to be created.
   ///
   /// \return Pointer of MindSpore Lite LiteSession.
-  static LiteSession *CreateSession(lite::Context *context);
+  static LiteSession *CreateSession(const lite::Context *context);
+
+  /// \brief Static method to create a LiteSession pointer which has already compiled a model.
+  ///
+  /// \param[in] model_buf Define the buffer read from a model file.
+  /// \param[in] size Define bytes number of model buffer.
+  /// \param[in] context Define the context of session to be created.
+  ///
+  /// \return Pointer of MindSpore Lite LiteSession.
+  static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);
 
   /// \brief Destructor of MindSpore Lite LiteSession.
   virtual ~LiteSession() = default;
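The two declarations above are the commit's public API surface: CreateSession now takes a const context, and the new three-argument overload wraps import and compilation in a single call. A minimal caller-side sketch of the new overload (the function name, default-constructed context, and error handling are illustrative, not part of this commit):

    #include "include/context.h"
    #include "include/lite_session.h"

    // Hypothetical caller: model_buf holds a serialized model already read by
    // the application; a default lite::Context selects CPU execution.
    mindspore::session::LiteSession *LoadAndCompile(const char *model_buf, size_t size) {
      mindspore::lite::Context context;
      // One call now imports the buffer, creates the session, and compiles the graph.
      return mindspore::session::LiteSession::CreateSession(model_buf, size, &context);
    }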

@@ -27,6 +27,7 @@ set(LITE_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/model_common.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc

@@ -26,6 +26,7 @@
 #include "src/common/utils.h"
 #include "src/common/graph_util.h"
 #include "src/kernel_registry.h"
+#include "src/model_common.h"
 
 namespace mindspore {
 namespace lite {
@@ -284,6 +285,12 @@ int LiteSession::CompileGraph(Model *model) {
     return ret;
   }
   ret = executor->Prepare(this->kernels_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Prepare executor failed: " << ret;
     is_running_.store(false);
+    return ret;
+  }
+  ret = PrepareKernels();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
+    is_running_.store(false);
@@ -293,6 +300,17 @@ int LiteSession::CompileGraph(Model *model) {
   return RET_OK;
 }
 
+int LiteSession::PrepareKernels() {
+  for (auto kernel : this->kernels_) {
+    auto ret = kernel->Prepare();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Prepare kernel " << kernel->name() << " failed: " << ret;
+      return ret;
+    }
+  }
+  return RET_OK;
+}
+
 std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const { return this->input_vec_; }
 
 int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
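PrepareKernels() is what makes the per-kernel Prepare() guards removable: every kernel is prepared exactly once when CompileGraph succeeds, instead of on every inference call. The kernel-side shape of that change, sketched on a hypothetical kernel (the real kernels follow in the hunks further down):

    // Before this commit: each Run() re-checked Prepare() on every call.
    int SomeCPUKernel::Run() {            // hypothetical kernel class
      auto prepare_ret = Prepare();       // now redundant, removed below
      if (prepare_ret != RET_OK) {
        MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
        return prepare_ret;
      }
      return DoCompute();                 // the actual kernel work
    }

    // After this commit: LiteSession::PrepareKernels() has already called
    // Prepare() once at CompileGraph time, so Run() starts with the compute.
    int SomeCPUKernel::Run() { return DoCompute(); }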
@@ -312,7 +330,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
   return ret;
 }
 
-int LiteSession::Init(Context *context) {
+int LiteSession::Init(const Context *context) {
   bool expected = false;
   if (!is_running_.compare_exchange_strong(expected, true)) {
     MS_LOG(ERROR) << "Not support multi-threading";
@@ -508,7 +526,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
 }
 }  // namespace lite
 
-session::LiteSession *session::LiteSession::CreateSession(lite::Context *context) {
+session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
   auto session = new lite::LiteSession();
   auto ret = session->Init(context);
   if (ret != mindspore::lite::RET_OK) {
@@ -518,4 +536,26 @@ session::LiteSession *session::LiteSession::CreateSession(lite::Context *context) {
   }
   return session;
 }
+
+session::LiteSession *session::LiteSession::CreateSession(const char *model_buf, size_t size,
+                                                          const lite::Context *context) {
+  auto *session = LiteSession::CreateSession(context);
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "Create sesssion failed";
+    return nullptr;
+  }
+  auto *model = lite::ImportFromBuffer(model_buf, size, true);
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model failed";
+    return nullptr;
+  }
+  auto ret = session->CompileGraph(model);
+  if (ret != lite::RET_OK) {
+    MS_LOG(ERROR) << "Compile model failed";
+    return nullptr;
+  }
+  model->buf = nullptr;
+  delete (model);
+  return session;
+}
 }  // namespace mindspore
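Two details of the implementation above are easy to miss: ImportFromBuffer is called with take_buf = true, so the graph is compiled against the caller's buffer without a copy, and model->buf is reset to nullptr before delete (model) so the Model never frees memory it does not own. A hedged comparison of the old and new flows, with the ownership notes inferred from this hunk rather than from documented behaviour:

    #include "include/context.h"
    #include "include/lite_session.h"
    #include "include/model.h"

    // Illustrative comparison, not part of the diff; null checks omitted.
    void OldVsNew(const char *model_buf, size_t size, mindspore::lite::Context *context) {
      // Old three-step flow: Model::Import copies the buffer (take_buf = false),
      // so model_buf may be released as soon as Import returns.
      auto *model = mindspore::lite::Model::Import(model_buf, size);
      auto *session = mindspore::session::LiteSession::CreateSession(context);
      session->CompileGraph(model);

      // New one-call flow: compiles against model_buf in place (take_buf = true);
      // nothing copies or frees it, so the caller keeps ownership throughout.
      auto *session2 = mindspore::session::LiteSession::CreateSession(model_buf, size, context);
      (void)session2;
    }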

@@ -42,7 +42,7 @@ class LiteSession : public session::LiteSession {
 
   ~LiteSession() override;
 
-  virtual int Init(Context *context);
+  virtual int Init(const Context *context);
 
   void BindThread(bool if_bind) override;
 
@@ -86,6 +86,8 @@ class LiteSession : public session::LiteSession {
 
   int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims);
 
+  int PrepareKernels();
+
  private:
   void ResetInputsShape(const std::vector<std::vector<int>> &dims);
 
@@ -16,124 +16,10 @@
 #include "src/ops/primitive_c.h"
 #include "include/model.h"
 #include "src/common/log_adapter.h"
 #include "include/errorcode.h"
 #include "src/common/graph_util.h"
-#include "include/version.h"
-#include "src/ops/ops_register.h"
+#include "src/model_common.h"
 
 namespace mindspore::lite {
-
-bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
-  for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) {
-    Model::Node *node = new (std::nothrow) Model::Node();
-    if (node == nullptr) {
-      MS_LOG(ERROR) << "new node fail!";
-      return false;
-    }
-    auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
-    auto src_prim = c_node->primitive();
-#ifdef PRIMITIVE_WRITEABLE
-    node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
-#else
-    auto primitive = const_cast<schema::Primitive *>(src_prim);
-    node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
-#endif
-    if (node->primitive_ == nullptr) {
-      MS_LOG(ERROR) << "unpack primitive == nullptr!";
-      delete node;
-      return false;
-    }
-    node->primitive_->SetQuantType(c_node->quantType());
-    node->name_ = c_node->name()->c_str();
-    node->node_type_ = c_node->nodeType();
-    auto count = c_node->inputIndex()->size();
-    for (uint32_t j = 0; j < count; ++j) {
-      node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
-    }
-    if (c_node->outputIndex() != nullptr) {
-      count = c_node->outputIndex()->size();
-      for (uint32_t j = 0; j < count; ++j) {
-        node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
-      }
-    }
-    model->nodes_.push_back(node);
-  }
-  return true;
-}
-
-bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
-  auto tensor_count = meta_graph->allTensors()->size();
-  for (uint32_t i = 0; i < tensor_count; ++i) {
-    auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
-    if (tensor == nullptr) {
-      MS_LOG(ERROR) << i << "th tensor in model is nullptr";
-      return false;
-    }
-    model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
-  }
-  return true;
-}
-
-Model *Model::Import(const char *model_buf, size_t size) {
-  if (model_buf == nullptr) {
-    MS_LOG(ERROR) << "The model buf is nullptr";
-    return nullptr;
-  }
-  flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
-  if (!schema::VerifyMetaGraphBuffer(verify)) {
-    MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
-    return nullptr;
-  }
-  auto *model = new (std::nothrow) Model();
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "new model fail!";
-    return nullptr;
-  }
-  model->buf = reinterpret_cast<char *>(malloc(size));
-  if (model->buf == nullptr) {
-    MS_LOG(ERROR) << "new inner model buf fail!";
-    delete (model);
-    return nullptr;
-  }
-  memcpy(model->buf, model_buf, size);
-  auto meta_graph = schema::GetMetaGraph(model->buf);
-  if (meta_graph == nullptr) {
-    MS_LOG(ERROR) << "meta_graph is nullptr!";
-    delete (model);
-    return nullptr;
-  }
-
-  if (meta_graph->name() != nullptr) {
-    model->name_ = meta_graph->name()->c_str();
-  }
-  if (meta_graph->version() != nullptr) {
-    model->version_ = meta_graph->version()->c_str();
-  }
-
-  if (model->version_ != Version()) {
-    MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
-  }
-
-  auto in_count = meta_graph->inputIndex()->size();
-  for (uint32_t i = 0; i < in_count; ++i) {
-    model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
-  }
-
-  auto out_count = meta_graph->outputIndex()->size();
-  for (uint32_t i = 0; i < out_count; ++i) {
-    model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
-  }
-  if (!ConvertNodes(meta_graph, model)) {
-    delete model;
-    return nullptr;
-  }
-
-  if (!ConvertTensors(meta_graph, model)) {
-    delete model;
-    return nullptr;
-  }
-  return model;
-}
+Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
 
 void Model::Free() {
   if (this->buf != nullptr) {

@@ -0,0 +1,138 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/model_common.h"
+#include "include/version.h"
+#include "src/ops/ops_register.h"
+
+namespace mindspore::lite {
+bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
+  for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) {
+    Model::Node *node = new (std::nothrow) Model::Node();
+    if (node == nullptr) {
+      MS_LOG(ERROR) << "new node fail!";
+      return false;
+    }
+    auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
+    auto src_prim = c_node->primitive();
+#ifdef PRIMITIVE_WRITEABLE
+    node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
+#else
+    auto primitive = const_cast<schema::Primitive *>(src_prim);
+    node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
+#endif
+    if (node->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "unpack primitive == nullptr!";
+      delete node;
+      return false;
+    }
+    node->primitive_->SetQuantType(c_node->quantType());
+    node->name_ = c_node->name()->c_str();
+    node->node_type_ = c_node->nodeType();
+    auto count = c_node->inputIndex()->size();
+    for (uint32_t j = 0; j < count; ++j) {
+      node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
+    }
+    if (c_node->outputIndex() != nullptr) {
+      count = c_node->outputIndex()->size();
+      for (uint32_t j = 0; j < count; ++j) {
+        node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
+      }
+    }
+    model->nodes_.push_back(node);
+  }
+  return true;
+}
+
+bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
+  auto tensor_count = meta_graph->allTensors()->size();
+  for (uint32_t i = 0; i < tensor_count; ++i) {
+    auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
+    if (tensor == nullptr) {
+      MS_LOG(ERROR) << i << "th tensor in model is nullptr";
+      return false;
+    }
+    model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
+  }
+  return true;
+}
+
+Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "The model buf is nullptr";
+    return nullptr;
+  }
+  flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
+  if (!schema::VerifyMetaGraphBuffer(verify)) {
+    MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
+    return nullptr;
+  }
+  auto *model = new (std::nothrow) Model();
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "new model fail!";
+    return nullptr;
+  }
+  if (take_buf) {
+    model->buf = const_cast<char *>(model_buf);
+  } else {
+    model->buf = reinterpret_cast<char *>(malloc(size));
+    if (model->buf == nullptr) {
+      MS_LOG(ERROR) << "new inner model buf fail!";
+      delete (model);
+      return nullptr;
+    }
+    memcpy(model->buf, model_buf, size);
+  }
+
+  auto meta_graph = schema::GetMetaGraph(model->buf);
+  if (meta_graph == nullptr) {
+    MS_LOG(ERROR) << "meta_graph is nullptr!";
+    delete (model);
+    return nullptr;
+  }
+
+  if (meta_graph->name() != nullptr) {
+    model->name_ = meta_graph->name()->c_str();
+  }
+  if (meta_graph->version() != nullptr) {
+    model->version_ = meta_graph->version()->c_str();
+  }
+
+  if (model->version_ != Version()) {
+    MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
+  }
+
+  auto in_count = meta_graph->inputIndex()->size();
+  for (uint32_t i = 0; i < in_count; ++i) {
+    model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
+  }
+
+  auto out_count = meta_graph->outputIndex()->size();
+  for (uint32_t i = 0; i < out_count; ++i) {
+    model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
+  }
+  if (!ConvertNodes(meta_graph, model)) {
+    delete model;
+    return nullptr;
+  }
+
+  if (!ConvertTensors(meta_graph, model)) {
+    delete model;
+    return nullptr;
+  }
+  return model;
+}
+}  // namespace mindspore::lite

@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_
+#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_
+#include "src/ops/primitive_c.h"
+#include "include/model.h"
+
+namespace mindspore::lite {
+bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model);
+
+bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model);
+
+Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_MODEL_COMMON_H_
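model_common.h is the internal seam introduced by this commit: model.cc keeps the public, copying Model::Import, while lite_session.cc uses the zero-copy path. A small sketch of what the take_buf flag selects (illustrative; file_buf is assumed to be a caller-owned buffer holding a serialized MetaGraph flatbuffer):

    #include "src/model_common.h"

    void ImportModes(const char *file_buf, size_t size) {
      // take_buf = false: the Model malloc's and memcpy's its own copy, exactly
      // what the old Model::Import did; the copy is released by Model::Free().
      auto *copied = mindspore::lite::ImportFromBuffer(file_buf, size, false);

      // take_buf = true: model->buf aliases file_buf directly (no memcpy).
      // A caller on this path must clear model->buf before destroying the
      // model, as CreateSession does, or Free() would free caller memory.
      auto *aliased = mindspore::lite::ImportFromBuffer(file_buf, size, true);
      (void)copied;
      (void)aliased;
    }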

@@ -77,6 +77,7 @@ int CropFp16CPUKernel::Run() {
   auto ret = ParallelLaunch(this->context_->thread_pool_, CropFp16Run, this, thread_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ParallelLaunch failed: " << ret;
+    FreeInputAndOutput();
     return ret;
   }
   if (out_tensors_.at(kOutputIndex)->data_type() == kNumberTypeFloat32) {

@@ -280,12 +280,6 @@ int DeConvWinogradFp16CPUKernel::Init() {
 }
 
 int DeConvWinogradFp16CPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
 
   for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {

@@ -113,12 +113,6 @@ int QuantDTypeCastRun(void *cdata, int task_id) {
 }
 
 int QuantDTypeCastFp16CPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
       out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat16) {
     int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());

@@ -330,11 +330,6 @@ int DeConvolutionWinogradCPUKernel::DeDeconvPost(int task_id) {
 }
 
 int DeConvolutionWinogradCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
   float *src_in = reinterpret_cast<float *>(in_tensors_[0]->data_c());
   float *src_out = reinterpret_cast<float *>(out_tensors_[0]->data_c());
 

@@ -38,12 +38,6 @@ int LshProjectionCPUKernel::Init() {
 int LshProjectionCPUKernel::ReSize() { return RET_OK; }
 
 int LshProjectionCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
-    return ret;
-  }
-
   auto input_tensor0 = in_tensors_.at(0);
   auto input_tensor1 = in_tensors_.at(1);
   auto out_tensor0 = out_tensors_.at(0);
@@ -65,7 +59,7 @@ int LshProjectionCPUKernel::Run() {
 
   elements_num_ = input_tensor0->DimensionSize(0);
   count_unit_ = thread_num_ > 1 ? UP_DIV(elements_num_, thread_num_) : elements_num_;
-  ret = ParallelLaunch(this->context_->thread_pool_, LshProjectionRun, this, thread_num_);
+  auto ret = ParallelLaunch(this->context_->thread_pool_, LshProjectionRun, this, thread_num_);
   return ret;
 }
 

@@ -60,11 +60,6 @@ void ParseSentenceToWords(const StringPack &sentence, std::vector<StringPack> *words) {
 }
 
 int SkipGramCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
-    return ret;
-  }
   skip_gram_parameter_ = reinterpret_cast<SkipGramParameter *>(op_parameter_);
   if (skip_gram_parameter_->ngram_size < 1) {
     MS_LOG(ERROR) << "Skip Gram Parameter Error, NgramSize should be at least 1, get "
@@ -105,7 +100,7 @@ int SkipGramCPUKernel::Run() {
         index--;
       }
     }
-  ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result);
+  auto ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result);
   return ret;
 }
 

@@ -79,12 +79,6 @@ int AdamRun(void *cdata, int task_id) {
 }
 
 int AdamCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "AdamCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]";

@@ -65,12 +65,6 @@ int ApplyMomentumRun(void *cdata, int task_id) {
 }
 
 int ApplyMomentumCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "ApplyMomentumCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]";

@@ -202,11 +202,6 @@ int ArithmeticGradRun(void *cdata, int task_id) {
 }
 
 int ArithmeticGradCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "ArithmeticGradCPUKernel Prepare failed.";
-    return ret;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticGradRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Arithmetic Grad function error error_code[" << error_code << "]";

@@ -52,12 +52,6 @@ int AssignRun(void *cdata, int task_id) {
 }
 
 int AssignCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "AssignCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]";

@@ -76,11 +76,6 @@ int BiasGradRun(void *cdata, int task_id) {
 }
 
 int BiasGradCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "BiasGradCPUKernel Prepare failed.";
-    return RET_ERROR;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, BiasGradRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]";

@@ -88,12 +88,6 @@ int BNGradRun(void *cdata, int task_id) {
 }
 
 int BNGradCPUKernel::Run() {
-  // std::cout << "run succ" << std::endl;
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "BNGradCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";

@@ -115,11 +115,6 @@ int ConvolutionTrainRun(void *cdata, int task_id) {
 }
 
 int ConvolutionTrainCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "ConvolutionTrainCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionTrainRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "conv train function error error_code[" << error_code << "]";

@@ -117,11 +117,6 @@ int ConvolutionGradFilterRun(void *cdata, int task_id) {
 }
 
 int ConvolutionGradFilterCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "ConvolutionGradFilterCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]";

@@ -115,12 +115,6 @@ int ConvolutionGradInputRun(void *cdata, int task_id) {
 }
 
 int ConvolutionGradInputCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "ConvolutionGradInputCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]";

@@ -113,12 +113,6 @@ int DeConvolutionGradFilterRun(void *cdata, int task_id) {
 }
 
 int DeConvolutionGradFilterCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, DeConvolutionGradFilterRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]";

@@ -88,12 +88,6 @@ int PoolingGradImpl(void *cdata, int task_id) {
 }
 
 int PoolingGradCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "PoolingGradCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   // clear output buffer before parallel run
   PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(op_parameter_);
   auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());

@@ -69,11 +69,6 @@ int PowerGradRun(void *cdata, int task_id) {
 }
 
 int PowerGradCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "PowerGradCPUKernel Prepare failed.";
-    return RET_ERROR;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]";

@@ -65,12 +65,6 @@ int SgdRun(void *cdata, int task_id) {
 }
 
 int SgdCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "SgdCPUKernel Prepare fail!ret: " << prepare_ret;
-    return prepare_ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]";

@@ -91,12 +91,6 @@ int SoftmaxCrossEntropyWithLogitsRun(void *cdata, int task_id) {
 }
 
 int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "SoftmaxCrossEntropyWithLogitsCPUKernel Prepare failed.";
-    return ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxCrossEntropyWithLogitsRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "SoftmaxCrossEntropy function error error_code[" << error_code << "]";

@@ -79,12 +79,6 @@ int SoftmaxGradRun(void *cdata, int task_id) {
 }
 
 int SoftmaxGradCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "SoftmaxGradCPUKernel Prepare failed.";
-    return ret;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxGradRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "SoftmaxGradRun function error error_code[" << error_code << "]";

@@ -118,11 +118,6 @@ int SparseSoftmaxCrossEntropyRun(void *cdata, int task_id) {
 }
 
 int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogitsCPUKernel Prepare failed.";
-    return ret;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy function error error_code[" << error_code << "]";

@@ -63,11 +63,6 @@ int TupleRun(void *cdata, int task_id) {
 }
 
 int TupleGetItemCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "TupleGetItemCPUKernel Prepare failed.";
-    return RET_ERROR;
-  }
   int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]";

@@ -150,12 +150,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::InitBuffer() {
 }
 
 int ConvolutionDepthwise3x3Int8CPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare failed.";
-    return RET_ERROR;
-  }
-  ret = InitBuffer();
+  auto ret = InitBuffer();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Depthwise int8 ReSize error!";
     return ret;

@@ -100,11 +100,6 @@ void NormalizeCPUKernel::FreeBuffer() {
 }
 
 int NormalizeCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
-    return ret;
-  }
   auto input_tensor = in_tensors_.at(0);
   int string_num = lite::GetStringCount(input_tensor);
   std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor);

@@ -73,11 +73,6 @@ std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() {
 static bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; }
 
 int PredictCPUKernel::Run() {
-  auto ret = Prepare();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
-    return ret;
-  }
   std::vector<LabelInfo> label_info_vec = GetLabelInfo();
   std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp);
 

@@ -39,6 +39,7 @@ class SubGraphOpenCLKernel : public SubGraphKernel {
       : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx) {
     ocl_runtime_ = ocl_runtime_wrap_.GetInstance();
     subgraph_type_ = kGpuSubGraph;
+    this->name_ = "GpuSubGraph";
     this->executor_ = new lite::opencl::OpenCLExecutor();
   }
   ~SubGraphOpenCLKernel() override;

@@ -35,9 +35,6 @@ using kernel::KERNEL_ARCH::kGPU;
 
 int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors,
                         std::vector<kernel::LiteKernel *> *kernels) {
-  // 1. op ---> kernel
-  // 2. sub graph
-  // 3. kernels (kernels --> subGraph)
   int ret = InferShape(model, tensors);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "op infer shape failed.";

@@ -68,6 +68,7 @@ class CpuFp32SubGraph : public SubGraphKernel {
                  const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
       : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
     subgraph_type_ = kCpuFP32SubGraph;
+    this->name_ = "CpuFP32SubGraph";
     this->executor_ = new mindspore::lite::Executor;
   }
 
@@ -88,6 +89,7 @@ class CpuFp16SubGraph : public SubGraphKernel {
                  const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
       : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
     subgraph_type_ = kCpuFP16SubGraph;
+    this->name_ = "CpuFP16SubGraph";
     this->executor_ = new mindspore::lite::Executor;
   }
 

@@ -120,6 +120,7 @@ set(TEST_LITE_SRC
     ${LITE_DIR}/src/lite_session.cc
     ${LITE_DIR}/src/sub_graph_kernel.cc
     ${LITE_DIR}/src/model.cc
+    ${LITE_DIR}/src/model_common.cc
     ${LITE_DIR}/src/populate_parameter.cc
     ${LITE_DIR}/src/scheduler.cc
     ${LITE_DIR}/src/common/graph_util.cc

@@ -72,6 +72,7 @@ set(LITE_SRC
     ${SRC_DIR}/lite_session.cc
     ${SRC_DIR}/executor.cc
     ${SRC_DIR}/model.cc
+    ${SRC_DIR}/model_common.cc
 )
 if (SUPPORT_TRAIN)
     set(LITE_SRC