remove unused Prepare calling in operator & add CreateSession(const char

*model_buf, size_t size, lite::Context *context) interface
This commit is contained in:
hangq 2020-10-21 14:23:20 +08:00
parent e805051c1f
commit e19a3e3926
39 changed files with 250 additions and 275 deletions

View File

@ -266,12 +266,15 @@ checkopts()
COMPILE_LITE="on" COMPILE_LITE="on"
if [[ "$OPTARG" == "arm64" ]]; then if [[ "$OPTARG" == "arm64" ]]; then
ENABLE_CONVERTER="off" ENABLE_CONVERTER="off"
RUN_TESTCASES="on"
LITE_PLATFORM="arm64" LITE_PLATFORM="arm64"
elif [[ "$OPTARG" == "arm32" ]]; then elif [[ "$OPTARG" == "arm32" ]]; then
ENABLE_CONVERTER="off" ENABLE_CONVERTER="off"
RUN_TESTCASES="on"
LITE_PLATFORM="arm32" LITE_PLATFORM="arm32"
elif [[ "$OPTARG" == "x86_64" ]]; then elif [[ "$OPTARG" == "x86_64" ]]; then
ENABLE_CONVERTER="on" ENABLE_CONVERTER="on"
RUN_TESTCASES="on"
LITE_PLATFORM="x86_64" LITE_PLATFORM="x86_64"
else else
echo "-I parameter must be arm64、arm32 or x86_64" echo "-I parameter must be arm64、arm32 or x86_64"
@ -315,7 +318,7 @@ checkopts()
elif [[ "$OPTARG" == "object-c" ]]; then elif [[ "$OPTARG" == "object-c" ]]; then
LITE_LANGUAGE="object-c" LITE_LANGUAGE="object-c"
else else
echo "-A parameter must be cppjava or object-c" echo "-A parameter must be cpp, java or object-c"
exit 1 exit 1
fi fi
;; ;;
@ -628,9 +631,9 @@ build_minddata_lite_deps()
} }
get_version() { get_version() {
VERSION_MAJOR=`grep "const int ms_version_major =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]"` VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]")
VERSION_MINOR=`grep "const int ms_version_minor =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]"` VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]")
VERSION_REVISION=`grep "const int ms_version_revision =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]"` VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/mindspore/lite/include/version.h | tr -dc "[0-9]")
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
} }
@ -642,7 +645,9 @@ build_lite()
echo "start build opencl" echo "start build opencl"
build_opencl build_opencl
fi fi
if [ "${RUN_TESTCASES}" == "on" ]; then
build_gtest build_gtest
fi
if [ "${COMPILE_MINDDATA_LITE}" == "lite" ] || [ "${COMPILE_MINDDATA_LITE}" == "full" ]; then if [ "${COMPILE_MINDDATA_LITE}" == "lite" ] || [ "${COMPILE_MINDDATA_LITE}" == "full" ]; then
build_minddata_lite_deps build_minddata_lite_deps
@ -665,7 +670,7 @@ build_lite()
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
-DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
-DPLATFORM_ARM64=on -DENABLE_NEON=on -DENABLE_FP16="off" \ -DPLATFORM_ARM64=on -DENABLE_NEON=on -DENABLE_FP16="off" \
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=on \ -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
-DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \ -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
-DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
@ -676,14 +681,14 @@ build_lite()
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \
-DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=on \ -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
-DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \ -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
-DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
"${BASEPATH}/mindspore/lite" "${BASEPATH}/mindspore/lite"
else else
cmake -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ cmake -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=on \ -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \ -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
@ -718,8 +723,8 @@ build_lite_java_arm64() {
cd ${BASEPATH}/output/ cd ${BASEPATH}/output/
rm -rf mindspore-lite-${VERSION_STR}-runtime-arm64-cpu rm -rf mindspore-lite-${VERSION_STR}-runtime-arm64-cpu
tar -zxvf mindspore-lite-${VERSION_STR}-runtime-arm64-cpu.tar.gz tar -zxvf mindspore-lite-${VERSION_STR}-runtime-arm64-cpu.tar.gz
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/
mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/ mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/*
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite-fp16.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite-fp16.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite-optimize.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm64-cpu/lib/libmindspore-lite-optimize.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
@ -738,10 +743,10 @@ build_lite_java_arm32() {
fi fi
# copy arm32 so # copy arm32 so
cd ${BASEPATH}/output/ cd ${BASEPATH}/output/
rm -rf mindspore-lite-${VERSION_STR}runtime-arm32-cpu rm -rf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu
tar -zxvf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu.tar.gz tar -zxvf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu.tar.gz
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/
mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/ mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/*
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm32-cpu/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-runtime-arm32-cpu/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
[ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu [ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-runtime-arm32-cpu
} }

View File

@ -35,7 +35,16 @@ class MS_API LiteSession {
/// \param[in] context Define the context of session to be created. /// \param[in] context Define the context of session to be created.
/// ///
/// \return Pointer of MindSpore Lite LiteSession. /// \return Pointer of MindSpore Lite LiteSession.
static LiteSession *CreateSession(lite::Context *context); static LiteSession *CreateSession(const lite::Context *context);
/// \brief Static method to create a LiteSession pointer which has already compiled a model.
///
/// \param[in] model_buf Define the buffer read from a model file.
/// \param[in] size Define bytes number of model buffer.
/// \param[in] context Define the context of session to be created.
///
/// \return Pointer of MindSpore Lite LiteSession.
static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);
/// \brief Destructor of MindSpore Lite LiteSession. /// \brief Destructor of MindSpore Lite LiteSession.
virtual ~LiteSession() = default; virtual ~LiteSession() = default;

View File

@ -27,6 +27,7 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/tensor.cc ${CMAKE_CURRENT_SOURCE_DIR}/tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_common.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc

View File

@ -26,6 +26,7 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/common/graph_util.h" #include "src/common/graph_util.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/model_common.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -284,6 +285,12 @@ int LiteSession::CompileGraph(Model *model) {
return ret; return ret;
} }
ret = executor->Prepare(this->kernels_); ret = executor->Prepare(this->kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare executor failed: " << ret;
is_running_.store(false);
return ret;
}
ret = PrepareKernels();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare kernels failed: " << ret; MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
is_running_.store(false); is_running_.store(false);
@ -293,6 +300,17 @@ int LiteSession::CompileGraph(Model *model) {
return RET_OK; return RET_OK;
} }
int LiteSession::PrepareKernels() {
for (auto kernel : this->kernels_) {
auto ret = kernel->Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare kernel " << kernel->name() << " failed: " << ret;
return ret;
}
}
return RET_OK;
}
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const { return this->input_vec_; } std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const { return this->input_vec_; }
int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) { int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
@ -312,7 +330,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af
return ret; return ret;
} }
int LiteSession::Init(Context *context) { int LiteSession::Init(const Context *context) {
bool expected = false; bool expected = false;
if (!is_running_.compare_exchange_strong(expected, true)) { if (!is_running_.compare_exchange_strong(expected, true)) {
MS_LOG(ERROR) << "Not support multi-threading"; MS_LOG(ERROR) << "Not support multi-threading";
@ -508,7 +526,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
} }
} // namespace lite } // namespace lite
session::LiteSession *session::LiteSession::CreateSession(lite::Context *context) { session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
auto session = new lite::LiteSession(); auto session = new lite::LiteSession();
auto ret = session->Init(context); auto ret = session->Init(context);
if (ret != mindspore::lite::RET_OK) { if (ret != mindspore::lite::RET_OK) {
@ -518,4 +536,26 @@ session::LiteSession *session::LiteSession::CreateSession(lite::Context *context
} }
return session; return session;
} }
session::LiteSession *session::LiteSession::CreateSession(const char *model_buf, size_t size,
const lite::Context *context) {
auto *session = LiteSession::CreateSession(context);
if (session == nullptr) {
MS_LOG(ERROR) << "Create sesssion failed";
return nullptr;
}
auto *model = lite::ImportFromBuffer(model_buf, size, true);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
return nullptr;
}
auto ret = session->CompileGraph(model);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Compile model failed";
return nullptr;
}
model->buf = nullptr;
delete (model);
return session;
}
} // namespace mindspore } // namespace mindspore

View File

@ -42,7 +42,7 @@ class LiteSession : public session::LiteSession {
~LiteSession() override; ~LiteSession() override;
virtual int Init(Context *context); virtual int Init(const Context *context);
void BindThread(bool if_bind) override; void BindThread(bool if_bind) override;
@ -86,6 +86,8 @@ class LiteSession : public session::LiteSession {
int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims); int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims);
int PrepareKernels();
private: private:
void ResetInputsShape(const std::vector<std::vector<int>> &dims); void ResetInputsShape(const std::vector<std::vector<int>> &dims);

View File

@ -16,124 +16,10 @@
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
#include "include/model.h" #include "include/model.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "include/errorcode.h" #include "src/model_common.h"
#include "src/common/graph_util.h"
#include "include/version.h"
#include "src/ops/ops_register.h"
namespace mindspore::lite { namespace mindspore::lite {
Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) {
Model::Node *node = new (std::nothrow) Model::Node();
if (node == nullptr) {
MS_LOG(ERROR) << "new node fail!";
return false;
}
auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
auto src_prim = c_node->primitive();
#ifdef PRIMITIVE_WRITEABLE
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
#else
auto primitive = const_cast<schema::Primitive *>(src_prim);
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
#endif
if (node->primitive_ == nullptr) {
MS_LOG(ERROR) << "unpack primitive == nullptr!";
delete node;
return false;
}
node->primitive_->SetQuantType(c_node->quantType());
node->name_ = c_node->name()->c_str();
node->node_type_ = c_node->nodeType();
auto count = c_node->inputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
}
if (c_node->outputIndex() != nullptr) {
count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
}
}
model->nodes_.push_back(node);
}
return true;
}
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
auto tensor_count = meta_graph->allTensors()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
if (tensor == nullptr) {
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
return false;
}
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
}
return true;
}
Model *Model::Import(const char *model_buf, size_t size) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr;
}
auto *model = new (std::nothrow) Model();
if (model == nullptr) {
MS_LOG(ERROR) << "new model fail!";
return nullptr;
}
model->buf = reinterpret_cast<char *>(malloc(size));
if (model->buf == nullptr) {
MS_LOG(ERROR) << "new inner model buf fail!";
delete (model);
return nullptr;
}
memcpy(model->buf, model_buf, size);
auto meta_graph = schema::GetMetaGraph(model->buf);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "meta_graph is nullptr!";
delete (model);
return nullptr;
}
if (meta_graph->name() != nullptr) {
model->name_ = meta_graph->name()->c_str();
}
if (meta_graph->version() != nullptr) {
model->version_ = meta_graph->version()->c_str();
}
if (model->version_ != Version()) {
MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
}
auto in_count = meta_graph->inputIndex()->size();
for (uint32_t i = 0; i < in_count; ++i) {
model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
}
auto out_count = meta_graph->outputIndex()->size();
for (uint32_t i = 0; i < out_count; ++i) {
model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
}
if (!ConvertNodes(meta_graph, model)) {
delete model;
return nullptr;
}
if (!ConvertTensors(meta_graph, model)) {
delete model;
return nullptr;
}
return model;
}
void Model::Free() { void Model::Free() {
if (this->buf != nullptr) { if (this->buf != nullptr) {

View File

@ -0,0 +1,138 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/model_common.h"
#include "include/version.h"
#include "src/ops/ops_register.h"
namespace mindspore::lite {
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) {
Model::Node *node = new (std::nothrow) Model::Node();
if (node == nullptr) {
MS_LOG(ERROR) << "new node fail!";
return false;
}
auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
auto src_prim = c_node->primitive();
#ifdef PRIMITIVE_WRITEABLE
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
#else
auto primitive = const_cast<schema::Primitive *>(src_prim);
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
#endif
if (node->primitive_ == nullptr) {
MS_LOG(ERROR) << "unpack primitive == nullptr!";
delete node;
return false;
}
node->primitive_->SetQuantType(c_node->quantType());
node->name_ = c_node->name()->c_str();
node->node_type_ = c_node->nodeType();
auto count = c_node->inputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
}
if (c_node->outputIndex() != nullptr) {
count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
}
}
model->nodes_.push_back(node);
}
return true;
}
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
auto tensor_count = meta_graph->allTensors()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
if (tensor == nullptr) {
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
return false;
}
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
}
return true;
}
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr;
}
auto *model = new (std::nothrow) Model();
if (model == nullptr) {
MS_LOG(ERROR) << "new model fail!";
return nullptr;
}
if (take_buf) {
model->buf = const_cast<char *>(model_buf);
} else {
model->buf = reinterpret_cast<char *>(malloc(size));
if (model->buf == nullptr) {
MS_LOG(ERROR) << "new inner model buf fail!";
delete (model);
return nullptr;
}
memcpy(model->buf, model_buf, size);
}
auto meta_graph = schema::GetMetaGraph(model->buf);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "meta_graph is nullptr!";
delete (model);
return nullptr;
}
if (meta_graph->name() != nullptr) {
model->name_ = meta_graph->name()->c_str();
}
if (meta_graph->version() != nullptr) {
model->version_ = meta_graph->version()->c_str();
}
if (model->version_ != Version()) {
MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
}
auto in_count = meta_graph->inputIndex()->size();
for (uint32_t i = 0; i < in_count; ++i) {
model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
}
auto out_count = meta_graph->outputIndex()->size();
for (uint32_t i = 0; i < out_count; ++i) {
model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
}
if (!ConvertNodes(meta_graph, model)) {
delete model;
return nullptr;
}
if (!ConvertTensors(meta_graph, model)) {
delete model;
return nullptr;
}
return model;
}
} // namespace mindspore::lite

View File

@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_
#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_
#include "src/ops/primitive_c.h"
#include "include/model.h"
namespace mindspore::lite {
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model);
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model);
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_

View File

@ -77,6 +77,7 @@ int CropFp16CPUKernel::Run() {
auto ret = ParallelLaunch(this->context_->thread_pool_, CropFp16Run, this, thread_count_); auto ret = ParallelLaunch(this->context_->thread_pool_, CropFp16Run, this, thread_count_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch failed: " << ret; MS_LOG(ERROR) << "ParallelLaunch failed: " << ret;
FreeInputAndOutput();
return ret; return ret;
} }
if (out_tensors_.at(kOutputIndex)->data_type() == kNumberTypeFloat32) { if (out_tensors_.at(kOutputIndex)->data_type() == kNumberTypeFloat32) {

View File

@ -280,12 +280,6 @@ int DeConvWinogradFp16CPUKernel::Init() {
} }
int DeConvWinogradFp16CPUKernel::Run() { int DeConvWinogradFp16CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {

View File

@ -113,12 +113,6 @@ int QuantDTypeCastRun(void *cdata, int task_id) {
} }
int QuantDTypeCastFp16CPUKernel::Run() { int QuantDTypeCastFp16CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 && if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat16) { out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat16) {
int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c()); int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());

View File

@ -330,11 +330,6 @@ int DeConvolutionWinogradCPUKernel::DeDeconvPost(int task_id) {
} }
int DeConvolutionWinogradCPUKernel::Run() { int DeConvolutionWinogradCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
float *src_in = reinterpret_cast<float *>(in_tensors_[0]->data_c()); float *src_in = reinterpret_cast<float *>(in_tensors_[0]->data_c());
float *src_out = reinterpret_cast<float *>(out_tensors_[0]->data_c()); float *src_out = reinterpret_cast<float *>(out_tensors_[0]->data_c());

View File

@ -38,12 +38,6 @@ int LshProjectionCPUKernel::Init() {
int LshProjectionCPUKernel::ReSize() { return RET_OK; } int LshProjectionCPUKernel::ReSize() { return RET_OK; }
int LshProjectionCPUKernel::Run() { int LshProjectionCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
auto input_tensor0 = in_tensors_.at(0); auto input_tensor0 = in_tensors_.at(0);
auto input_tensor1 = in_tensors_.at(1); auto input_tensor1 = in_tensors_.at(1);
auto out_tensor0 = out_tensors_.at(0); auto out_tensor0 = out_tensors_.at(0);
@ -65,7 +59,7 @@ int LshProjectionCPUKernel::Run() {
elements_num_ = input_tensor0->DimensionSize(0); elements_num_ = input_tensor0->DimensionSize(0);
count_unit_ = thread_num_ > 1 ? UP_DIV(elements_num_, thread_num_) : elements_num_; count_unit_ = thread_num_ > 1 ? UP_DIV(elements_num_, thread_num_) : elements_num_;
ret = ParallelLaunch(this->context_->thread_pool_, LshProjectionRun, this, thread_num_); auto ret = ParallelLaunch(this->context_->thread_pool_, LshProjectionRun, this, thread_num_);
return ret; return ret;
} }

View File

@ -60,11 +60,6 @@ void ParseSentenceToWords(const StringPack &sentence, std::vector<StringPack> *w
} }
int SkipGramCPUKernel::Run() { int SkipGramCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
skip_gram_parameter_ = reinterpret_cast<SkipGramParameter *>(op_parameter_); skip_gram_parameter_ = reinterpret_cast<SkipGramParameter *>(op_parameter_);
if (skip_gram_parameter_->ngram_size < 1) { if (skip_gram_parameter_->ngram_size < 1) {
MS_LOG(ERROR) << "Skip Gram Parameter Error, NgramSize should be at least 1, get " MS_LOG(ERROR) << "Skip Gram Parameter Error, NgramSize should be at least 1, get "
@ -105,7 +100,7 @@ int SkipGramCPUKernel::Run() {
index--; index--;
} }
} }
ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result); auto ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result);
return ret; return ret;
} }

View File

@ -79,12 +79,6 @@ int AdamRun(void *cdata, int task_id) {
} }
int AdamCPUKernel::Run() { int AdamCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "AdamCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]";

View File

@ -65,12 +65,6 @@ int ApplyMomentumRun(void *cdata, int task_id) {
} }
int ApplyMomentumCPUKernel::Run() { int ApplyMomentumCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "ApplyMomentumCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]";

View File

@ -202,11 +202,6 @@ int ArithmeticGradRun(void *cdata, int task_id) {
} }
int ArithmeticGradCPUKernel::Run() { int ArithmeticGradCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticGradCPUKernel Prepare failed.";
return ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticGradRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticGradRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "Arithmetic Grad function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "Arithmetic Grad function error error_code[" << error_code << "]";

View File

@ -52,12 +52,6 @@ int AssignRun(void *cdata, int task_id) {
} }
int AssignCPUKernel::Run() { int AssignCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "AssignCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]";

View File

@ -76,11 +76,6 @@ int BiasGradRun(void *cdata, int task_id) {
} }
int BiasGradCPUKernel::Run() { int BiasGradCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "BiasGradCPUKernel Prepare failed.";
return RET_ERROR;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, BiasGradRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, BiasGradRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]";

View File

@ -88,12 +88,6 @@ int BNGradRun(void *cdata, int task_id) {
} }
int BNGradCPUKernel::Run() { int BNGradCPUKernel::Run() {
// std::cout << "run succ" << std::endl;
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "BNGradCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";

View File

@ -115,11 +115,6 @@ int ConvolutionTrainRun(void *cdata, int task_id) {
} }
int ConvolutionTrainCPUKernel::Run() { int ConvolutionTrainCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionTrainCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionTrainRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionTrainRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv train function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "conv train function error error_code[" << error_code << "]";

View File

@ -117,11 +117,6 @@ int ConvolutionGradFilterRun(void *cdata, int task_id) {
} }
int ConvolutionGradFilterCPUKernel::Run() { int ConvolutionGradFilterCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionGradFilterCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]";

View File

@ -115,12 +115,6 @@ int ConvolutionGradInputRun(void *cdata, int task_id) {
} }
int ConvolutionGradInputCPUKernel::Run() { int ConvolutionGradInputCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionGradInputCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]";

View File

@ -113,12 +113,6 @@ int DeConvolutionGradFilterRun(void *cdata, int task_id) {
} }
int DeConvolutionGradFilterCPUKernel::Run() { int DeConvolutionGradFilterCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, DeConvolutionGradFilterRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, DeConvolutionGradFilterRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]";

View File

@ -88,12 +88,6 @@ int PoolingGradImpl(void *cdata, int task_id) {
} }
int PoolingGradCPUKernel::Run() { int PoolingGradCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "PoolingGradCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
// clear output buffer before parallel run // clear output buffer before parallel run
PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(op_parameter_); PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(op_parameter_);
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());

View File

@ -69,11 +69,6 @@ int PowerGradRun(void *cdata, int task_id) {
} }
int PowerGradCPUKernel::Run() { int PowerGradCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "PowerGradCPUKernel Prepare failed.";
return RET_ERROR;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]";

View File

@ -65,12 +65,6 @@ int SgdRun(void *cdata, int task_id) {
} }
int SgdCPUKernel::Run() { int SgdCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "SgdCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]";

View File

@ -91,12 +91,6 @@ int SoftmaxCrossEntropyWithLogitsRun(void *cdata, int task_id) {
} }
int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "SoftmaxCrossEntropyWithLogitsCPUKernel Prepare failed.";
return ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxCrossEntropyWithLogitsRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxCrossEntropyWithLogitsRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "SoftmaxCrossEntropy function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "SoftmaxCrossEntropy function error error_code[" << error_code << "]";

View File

@ -79,12 +79,6 @@ int SoftmaxGradRun(void *cdata, int task_id) {
} }
int SoftmaxGradCPUKernel::Run() { int SoftmaxGradCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "SoftmaxGradCPUKernel Prepare failed.";
return ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxGradRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxGradRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "SoftmaxGradRun function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "SoftmaxGradRun function error error_code[" << error_code << "]";

View File

@ -118,11 +118,6 @@ int SparseSoftmaxCrossEntropyRun(void *cdata, int task_id) {
} }
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogitsCPUKernel Prepare failed.";
return ret;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy function error error_code[" << error_code << "]";

View File

@ -63,11 +63,6 @@ int TupleRun(void *cdata, int task_id) {
} }
int TupleGetItemCPUKernel::Run() { int TupleGetItemCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "TupleGetItemCPUKernel Prepare failed.";
return RET_ERROR;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, 1); int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, 1);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]"; MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]";

View File

@ -150,12 +150,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::InitBuffer() {
} }
int ConvolutionDepthwise3x3Int8CPUKernel::Run() { int ConvolutionDepthwise3x3Int8CPUKernel::Run() {
auto ret = Prepare(); auto ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
ret = InitBuffer();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Depthwise int8 ReSize error!"; MS_LOG(ERROR) << "Depthwise int8 ReSize error!";
return ret; return ret;

View File

@ -100,11 +100,6 @@ void NormalizeCPUKernel::FreeBuffer() {
} }
int NormalizeCPUKernel::Run() { int NormalizeCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
return ret;
}
auto input_tensor = in_tensors_.at(0); auto input_tensor = in_tensors_.at(0);
int string_num = lite::GetStringCount(input_tensor); int string_num = lite::GetStringCount(input_tensor);
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor); std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor);

View File

@ -73,11 +73,6 @@ std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() {
static bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } static bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; }
int PredictCPUKernel::Run() { int PredictCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
return ret;
}
std::vector<LabelInfo> label_info_vec = GetLabelInfo(); std::vector<LabelInfo> label_info_vec = GetLabelInfo();
std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp); std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp);

View File

@ -39,6 +39,7 @@ class SubGraphOpenCLKernel : public SubGraphKernel {
: SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx) { : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx) {
ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); ocl_runtime_ = ocl_runtime_wrap_.GetInstance();
subgraph_type_ = kGpuSubGraph; subgraph_type_ = kGpuSubGraph;
this->name_ = "GpuSubGraph";
this->executor_ = new lite::opencl::OpenCLExecutor(); this->executor_ = new lite::opencl::OpenCLExecutor();
} }
~SubGraphOpenCLKernel() override; ~SubGraphOpenCLKernel() override;

View File

@ -35,9 +35,6 @@ using kernel::KERNEL_ARCH::kGPU;
int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors, int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels) { std::vector<kernel::LiteKernel *> *kernels) {
// 1. op ---> kernel
// 2. sub graph
// 3. kernels (kernels --> subGraph)
int ret = InferShape(model, tensors); int ret = InferShape(model, tensors);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "op infer shape failed."; MS_LOG(ERROR) << "op infer shape failed.";

View File

@ -68,6 +68,7 @@ class CpuFp32SubGraph : public SubGraphKernel {
const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx) const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
: SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
subgraph_type_ = kCpuFP32SubGraph; subgraph_type_ = kCpuFP32SubGraph;
this->name_ = "CpuFP32SubGraph";
this->executor_ = new mindspore::lite::Executor; this->executor_ = new mindspore::lite::Executor;
} }
@ -88,6 +89,7 @@ class CpuFp16SubGraph : public SubGraphKernel {
const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx) const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
: SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
subgraph_type_ = kCpuFP16SubGraph; subgraph_type_ = kCpuFP16SubGraph;
this->name_ = "CpuFP16SubGraph";
this->executor_ = new mindspore::lite::Executor; this->executor_ = new mindspore::lite::Executor;
} }

View File

@ -120,6 +120,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/lite_session.cc ${LITE_DIR}/src/lite_session.cc
${LITE_DIR}/src/sub_graph_kernel.cc ${LITE_DIR}/src/sub_graph_kernel.cc
${LITE_DIR}/src/model.cc ${LITE_DIR}/src/model.cc
${LITE_DIR}/src/model_common.cc
${LITE_DIR}/src/populate_parameter.cc ${LITE_DIR}/src/populate_parameter.cc
${LITE_DIR}/src/scheduler.cc ${LITE_DIR}/src/scheduler.cc
${LITE_DIR}/src/common/graph_util.cc ${LITE_DIR}/src/common/graph_util.cc

View File

@ -72,6 +72,7 @@ set(LITE_SRC
${SRC_DIR}/lite_session.cc ${SRC_DIR}/lite_session.cc
${SRC_DIR}/executor.cc ${SRC_DIR}/executor.cc
${SRC_DIR}/model.cc ${SRC_DIR}/model.cc
${SRC_DIR}/model_common.cc
) )
if (SUPPORT_TRAIN) if (SUPPORT_TRAIN)
set(LITE_SRC set(LITE_SRC