!34451 [MS][LITE] remove lite session api

Merge pull request !34451 from sunsuodong/rm_lite_session
This commit is contained in:
i-robot 2022-05-18 08:14:58 +00:00 committed by Gitee
commit f35e313d33
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
40 changed files with 200 additions and 477 deletions

View File

@ -192,6 +192,7 @@ mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
mindspore/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
mindspore/mindspore/lite/src/train/train_loop.cc:mindspore::lite::TrainLoop::Train
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc:mindspore::opt::MatchAdd5Pattern
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512

View File

@ -1,254 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
#ifndef NOT_USE_STL
#include <unordered_map>
#endif // NOT_USE_STL
#include <vector>
#include <string>
#include <map>
#include "include/ms_tensor.h"
#include "include/model.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_types.h"
namespace mindspore {
namespace lite {
class TrainCfg;
}
namespace session {
/// \brief LiteSession defined session in MindSpore Lite for compiling Model and forwarding model.
class MS_API LiteSession {
public:
/// \brief Static method to create a LiteSession pointer.
///
/// \param[in] context Define the context of session to be created.
///
/// \return Pointer of MindSpore Lite LiteSession.
static LiteSession *CreateSession(const lite::Context *context);
/// \brief Static method to create a LiteSession pointer which has already compiled a model.
///
/// \param[in] model_buf Define the buffer read from a model file.
/// \param[in] size Define bytes number of model buffer.
/// \param[in] context Define the context of session to be created.
///
/// \return Pointer of MindSpore Lite LiteSession.
static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);
/// \brief Destructor of MindSpore Lite LiteSession.
virtual ~LiteSession() = default;
/// \brief Attempt to bind or unbind threads in the thread pool to or from the specified cpu core.
///
/// \param[in] if_bind Define whether to bind or unbind threads.
virtual void BindThread(bool if_bind) = 0;
/// \brief Compile MindSpore Lite model.
///
/// \note CompileGraph should be called before RunGraph.
///
/// \param[in] model Define the model to be compiled.
///
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h.
virtual int CompileGraph(lite::Model *model) = 0;
/// \brief Get input MindSpore Lite MSTensors of model.
///
/// \return The vector of MindSpore Lite MSTensor.
virtual Vector<tensor::MSTensor *> GetInputs() const = 0;
/// \brief Get input MindSpore Lite MSTensors of model by tensor name.
///
/// \param[in] node_name Define tensor name.
///
/// \return The vector of MindSpore Lite MSTensor.
virtual mindspore::tensor::MSTensor *GetInputsByTensorName(const String &tensor_name) const = 0;
/// \brief Run session with callback.
///
/// \param[in] before Define a call_back_function to be called before running each node.
/// \param[in] after Define a call_back_function called after running each node.
///
/// \note RunGraph should be called after CompileGraph.
///
/// \return STATUS as an error code of running graph, STATUS is defined in errorcode.h.
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0;
/// \brief Get output MindSpore Lite MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \note Deprecated, replace with GetOutputByTensorName
///
/// \return The vector of MindSpore Lite MSTensor.
virtual Vector<tensor::MSTensor *> GetOutputsByNodeName(const String &node_name) const = 0;
#ifndef NOT_USE_STL
/// \brief Get output MindSpore Lite MSTensors of model mapped by tensor name.
///
/// \return The map of output tensor name and MindSpore Lite MSTensor.
virtual std::unordered_map<String, mindspore::tensor::MSTensor *> GetOutputs() const = 0;
#endif
/// \brief Get name of output tensors of model compiled by this session.
///
/// \return The vector of string as output tensor names in order.
virtual Vector<String> GetOutputTensorNames() const = 0;
/// \brief Get output MindSpore Lite MSTensors of model by tensor name.
///
/// \param[in] tensor_name Define tensor name.
///
/// \return Pointer of MindSpore Lite MSTensor.
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const String &tensor_name) const = 0;
/// \brief Bind GLTexture2D object to cl Memory.
///
/// \param[in] inputGlTexture The input GLTexture id for Model.
/// \param[in] outputGLTexture The output GLTexture id for Model.
///
/// \return Status of operation.
virtual int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGlTexture,
std::map<std::string, unsigned int> *outputGLTexture) = 0;
/// \brief Resize inputs shape.
///
/// \param[in] inputs Define the inputs of the model.
/// \param[in] dims Define the inputs new shape.
///
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
virtual int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) = 0;
/// \brief Set model to train mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Train() { return mindspore::lite::RET_ERROR; }
/// \brief Check mode of model
///
/// \return boolean indication if model is in train mode
virtual bool IsTrain() { return false; }
/// \brief Set model to eval mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Eval() { return mindspore::lite::RET_OK; }
/// \brief Check mode of model
///
/// \return boolean indication if model is in eval mode
virtual bool IsEval() { return true; }
/// \brief Sets the Learning Rate of the training
///
/// \param[in] learning_rate to set
///
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetLearningRate(float learning_rate) { return mindspore::lite::RET_ERROR; }
/// \brief Gets the Learning Rate of the training
///
/// \return learning rate. 0.0 if no optimizer was found
virtual float GetLearningRate() { return 0.0; }
/// \brief Setup training with virtual batches
///
/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable
/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration
/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) {
return mindspore::lite::RET_ERROR;
}
/// \brief Get output MindSpore Lite MSTensors of Training model prediction
///
/// \return a vector of output tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetPredictions() const {
std::vector<tensor::MSTensor *> outputs;
return outputs;
}
/// \brief Save model
/// \param[in] file_name pretrained model file name prefix. '.ms' extenension is added if does not exist
/// \param[in] model_type indication whether to save full model or only the inference part
/// \param[in] quant_type indication whether to quantize exported model
/// \param[in] format of exported file (currently only FT_FLATBUFFERS is supported)
/// \param[in] out_put_tensor_name of exported tensorname
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
std::vector<std::string> out_put_tensor_name = {}) {
return mindspore::lite::RET_ERROR;
}
/// \brief Change the size and or content of weight tensors
///
/// \param[in] new_weights a vector of tensors with new shapes and data to use in the model
/// If data pointer is null, the data of the original tensors will be copied to the new ones
///
/// \return STATUS as an error code of operation, STATUS is defined in errorcode.h.
virtual int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) { return mindspore::lite::RET_ERROR; }
/// \brief Get model featuremap MindSpore Lite MSTensors of Training model prediction
///
/// \return a vector of output tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetFeatureMaps() const {
std::vector<tensor::MSTensor *> features;
return features;
}
/// \brief update model featuremap save to update_ms_file
/// \param[in] features new featuremap
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) { return mindspore::lite::RET_ERROR; }
/// \brief Get model gradient
///
/// \return a vector of gradient tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetGradients() const {
std::vector<tensor::MSTensor *> gradients;
return gradients;
}
/// \brief update model gradient
///
/// \param[in] new gradients
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) { return mindspore::lite::RET_ERROR; }
/// \brief Get model optimizer params
///
/// \return a vector of optimizer parameters (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
return params;
}
/// \brief set model optimizer params
///
/// \param[in] new optimizer params
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) { return mindspore::lite::RET_ERROR; }
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_

View File

@ -24,13 +24,13 @@
namespace mindspore {
namespace lite {
class AccuracyMonitor : public session::TrainLoopCallBack {
class AccuracyMonitor : public TrainLoopCallBack {
public:
explicit AccuracyMonitor(mindspore::dataset::Dataset *dataset, int check_every_n, int max_steps = -1)
: ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}
~AccuracyMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;
void Begin(const TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
private:

View File

@ -25,20 +25,20 @@
namespace mindspore {
namespace lite {
class CkptSaver : public session::TrainLoopCallBack {
class CkptSaver : public TrainLoopCallBack {
public:
CkptSaver(size_t save_every_n, std::string filename_prefix)
: save_every_n_(save_every_n), filename_prefix_(std::move(filename_prefix)) {}
~CkptSaver() override = default;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
int EpochEnd(const TrainLoopCallBackData &cb_data) override {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str());
cb_data.session_->Export(cpkt_fn);
}
return session::RET_CONTINUE;
return RET_CONTINUE;
}
private:

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace lite {
class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
class ClassificationTrainAccuracyMonitor : public TrainLoopCallBack {
public:
explicit ClassificationTrainAccuracyMonitor(int print_every_n = INT_MAX,
int accuracy_metrics = METRICS_CLASSIFICATION,
@ -35,10 +35,10 @@ class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
const std::vector<int> &output_indexes = {0});
virtual ~ClassificationTrainAccuracyMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
void Begin(const TrainLoopCallBackData &cb_data) override;
void EpochBegin(const TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;
void StepEnd(const TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
private:

View File

@ -25,7 +25,7 @@
namespace mindspore {
namespace lite {
class LossMonitor : public session::TrainLoopCallBack {
class LossMonitor : public TrainLoopCallBack {
public:
/// \brief constructor
///
@ -38,10 +38,10 @@ class LossMonitor : public session::TrainLoopCallBack {
/// \return 0 on success or -1 in case of error
explicit LossMonitor(int print_every_n_steps = INT_MAX) : print_every_n_(print_every_n_steps) {}
virtual ~LossMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
void Begin(const TrainLoopCallBackData &cb_data) override;
void EpochBegin(const TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;
void StepEnd(const TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetLossPoints() const { return losses_; }
private:

View File

@ -42,11 +42,11 @@ struct StepLRLambda {
float gamma; // LR decay factor
};
class LRScheduler : public session::TrainLoopCallBack {
class LRScheduler : public TrainLoopCallBack {
public:
explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step_ = 1);
virtual ~LRScheduler() = default;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;
private:
LR_Lambda lambda_func_;

View File

@ -21,7 +21,7 @@
#include <unordered_map>
#include "include/train/train_loop_callback.h"
#include "include/train/metrics.h"
#include "include/lite_session.h"
#include "src/runtime/lite_session.h"
namespace mindspore {
class MSTensor;
@ -42,7 +42,7 @@ class TrainLoop {
/// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API
///
/// \return Pointer of MindSpore Lite TrainLoop
static TrainLoop *CreateTrainLoop(session::LiteSession *train_session);
static lite::TrainLoop *CreateTrainLoop(lite::LiteSession *train_session);
/// \brief Class destructor
virtual ~TrainLoop() = default;
@ -55,7 +55,7 @@ class TrainLoop {
/// \brief Accessor to the LiteSession
///
/// \return pointer of the train_session
const virtual session::LiteSession *train_session() = 0;
const virtual lite::LiteSession *train_session() = 0;
/// \brief Initialize object with metrics
///
@ -85,7 +85,7 @@ class TrainLoop {
/// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model
///
/// \return 0 on success or -1 in case of error
virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs,
virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
LoadDataFunc load_func) = 0;
/// \brief Performs loop over all data in Eval Mode
@ -96,8 +96,8 @@ class TrainLoop {
/// \param[in] max_steps (with default = INT_MAX the method iterates all dataset)
///
/// \return 0 on success or -1 in case of error
virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func,
int max_steps) = 0;
virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
LoadDataFunc load_func, int max_steps) = 0;
};
} // namespace session
} // namespace mindspore

View File

@ -23,8 +23,7 @@
#include "include/api/callback/callback.h"
namespace mindspore {
namespace session {
namespace lite {
class LiteSession;
class TrainLoop;
@ -83,6 +82,6 @@ class TrainLoopCallBack {
virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
};
} // namespace session
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_

View File

@ -1,50 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
#include <string>
#include "include/lite_session.h"
namespace mindspore {
namespace session {
class TrainSession {
public:
/// \brief Static method to create a TransferSession object
///
/// \param[in] filename_backbone Filename to read backbone net flatbuffer from
/// \param[in] filename_head Filename to read head net flatbuffer from
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
const lite::Context *context, bool train_mode = false,
const lite::TrainCfg *cfg = nullptr);
/// \brief Static method to create a TrainSession object
///
/// \param[in] filename name of flatbuffer that holds the flatbuffer
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
/// \param[in] cfg training configuration, set to null for default configuration
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_

View File

@ -79,7 +79,7 @@ set(JNI_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp
${NEW_NATIVE_DIR}/graph.cpp
${NEW_NATIVE_DIR}/model.cpp
@ -110,7 +110,7 @@ endif()
if(SUPPORT_TRAIN)
set(LITE_TRAIN_SO_NAME mindspore-lite-train)
set(JNI_TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
${NEW_NATIVE_DIR}/train_config.cpp
)
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})

View File

@ -17,7 +17,7 @@
#include <jni.h>
#include <fstream>
#include "common/ms_log.h"
#include "include/lite_session.h"
#include "src/runtime/lite_session.h"
#include "include/errorcode.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSessionWithModel(JNIEnv *env, jobject thiz,
jobject model_buffer,
@ -37,7 +37,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSes
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
// create session
auto session = mindspore::session::LiteSession::CreateSession(model_buf, buffer_len, lite_context_ptr);
auto session = mindspore::lite::LiteSession::CreateSession(model_buf, buffer_len, lite_context_ptr);
if (session == nullptr) {
MS_LOGE("CreateSession failed");
return jlong(nullptr);
@ -53,7 +53,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSes
return jlong(nullptr);
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::LiteSession::CreateSession(lite_context_ptr);
auto session = mindspore::lite::LiteSession::CreateSession(lite_context_ptr);
if (session == nullptr) {
MS_LOGE("CreateSession failed");
return jlong(nullptr);
@ -69,7 +69,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compil
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
if (model_pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
@ -88,7 +88,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
lite_session_ptr->BindThread(if_bind);
}
@ -99,7 +99,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGra
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto ret = lite_session_ptr->RunGraph();
return (jboolean)(ret == mindspore::lite::RET_OK);
}
@ -118,7 +118,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetInputs();
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
@ -135,7 +135,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputs
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto input = lite_session_ptr->GetInputsByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
return jlong(input);
}
@ -155,7 +155,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
@ -177,7 +177,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
MS_LOGE("Session pointer from java is nullptr");
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputs();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
@ -203,7 +203,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto output_names = lite_session_ptr->GetOutputTensorNames();
for (const auto &output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
@ -219,7 +219,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutput
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto output = lite_session_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
return jlong(output);
}
@ -231,7 +231,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_free(JNIEn
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
delete (lite_session_ptr);
}
@ -244,7 +244,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_resize
MS_LOGE("Session pointer from java is nullptr");
return false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
@ -300,7 +300,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_train(
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->Train();
return (jboolean)(ret == mindspore::lite::RET_OK);
}
@ -312,7 +312,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_eval(J
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->Eval();
return (jboolean)(ret == mindspore::lite::RET_OK);
}
@ -324,7 +324,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_isTrai
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->IsTrain();
return (jboolean)(ret);
}
@ -336,7 +336,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_isEval
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->IsEval();
return (jboolean)(ret);
}
@ -349,7 +349,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setLea
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->SetLearningRate(learning_rate);
return (jboolean)(ret == mindspore::lite::RET_OK);
}
@ -364,7 +364,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupV
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum);
return (jboolean)(ret == mindspore::lite::RET_OK);
}
@ -384,7 +384,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_update
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer);
newFeatures.emplace_back(ms_tensor_ptr);
}
auto session = reinterpret_cast<mindspore::session::LiteSession *>(session_ptr);
auto session = reinterpret_cast<mindspore::lite::LiteSession *>(session_ptr);
auto ret = session->UpdateFeatureMaps(newFeatures);
return (jboolean)(ret == mindspore::lite::RET_OK);
}
@ -403,7 +403,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getFeat
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *train_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *train_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto inputs = train_session_ptr->GetFeatureMaps();
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));

View File

@ -16,7 +16,7 @@
#include <jni.h>
#include "common/ms_log.h"
#include "include/train/train_session.h"
#include "src/train/train_session.h"
#include "include/train/train_cfg.h"
#include "include/errorcode.h"
@ -32,7 +32,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createTr
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::TrainSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
auto session = mindspore::lite::TrainSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
lite_context_ptr, train_mode, nullptr);
if (session == nullptr) {
MS_LOGE("CreateTrainSession failed");

View File

@ -22,32 +22,32 @@
namespace mindspore {
class TrainLoopCallBackAdapter : public session::TrainLoopCallBack {
class TrainLoopCallBackAdapter : public lite::TrainLoopCallBack {
public:
explicit TrainLoopCallBackAdapter(Model *model, TrainCallBack *call_back) : model_(model), call_back_(call_back) {}
TrainLoopCallBackAdapter() = delete;
void Begin(const session::TrainLoopCallBackData &i_cb_data) override {
void Begin(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->Begin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};
void End(const session::TrainLoopCallBackData &i_cb_data) override {
void End(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->End(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};
void EpochBegin(const session::TrainLoopCallBackData &i_cb_data) override {
void EpochBegin(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->EpochBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};
int EpochEnd(const session::TrainLoopCallBackData &i_cb_data) override {
int EpochEnd(const lite::TrainLoopCallBackData &i_cb_data) override {
return call_back_->EpochEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};
void StepBegin(const session::TrainLoopCallBackData &i_cb_data) override {
void StepBegin(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->StepBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};
void StepEnd(const session::TrainLoopCallBackData &i_cb_data) override {
void StepEnd(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->StepEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};

View File

@ -24,11 +24,11 @@ namespace mindspore {
class CallbackImpl {
public:
CallbackImpl() = delete;
explicit CallbackImpl(session::TrainLoopCallBack *cb) : internal_call_back_(cb) {}
session::TrainLoopCallBack *GetInternalCallback() { return internal_call_back_; }
explicit CallbackImpl(lite::TrainLoopCallBack *cb) : internal_call_back_(cb) {}
lite::TrainLoopCallBack *GetInternalCallback() { return internal_call_back_; }
protected:
session::TrainLoopCallBack *internal_call_back_ = nullptr;
lite::TrainLoopCallBack *internal_call_back_ = nullptr;
};
} // namespace mindspore

View File

@ -50,7 +50,7 @@ const std::vector<GraphPoint> &LossMonitor::GetLossPoints() {
return empty_vector;
}
session::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
lite::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
if (internal_call_back == nullptr) {
MS_LOG(ERROR) << "Internal callback is null.";
return empty_vector;

View File

@ -52,7 +52,7 @@ const std::vector<GraphPoint> &TrainAccuracy::GetAccuracyPoints() {
return empty_vector;
}
session::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
lite::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
if (internal_call_back == nullptr) {
MS_LOG(ERROR) << "Internal callback is null.";
return empty_vector;

View File

@ -30,6 +30,7 @@
#include "src/runtime/cxx_api/graph/graph_data.h"
#include "src/runtime/inner_context.h"
#include "src/runtime/lite_session.h"
#include "include/train/train_loop_callback.h"
template <class T>
void clearVectorOfPointers(std::vector<T> *v) {
@ -105,13 +106,13 @@ class ModelImpl {
return kSuccess;
}
std::vector<Metrics *> GetMetrics() { return metrics_; }
const session::LiteSession *GetSession() const { return session_.get(); }
const lite::LiteSession *GetSession() const { return session_.get(); }
protected:
// Utility methods
Status ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
std::vector<session::TrainLoopCallBack *> *o_cbs,
std::vector<session::TrainLoopCallBack *> *adapter_cbs);
std::vector<lite::TrainLoopCallBack *> *o_cbs,
std::vector<lite::TrainLoopCallBack *> *adapter_cbs);
Status PrepareMetrics(Model *model, std::vector<session::Metrics *> *o_ms,
std::vector<session::Metrics *> *adapter_ms);

View File

@ -21,7 +21,7 @@
#include "src/runtime/cxx_api/model/model_impl.h"
#include "src/runtime/cxx_api/callback/callback_impl.h"
#include "src/common/log_adapter.h"
#include "include/train/train_loop.h"
#include "src/train/train_loop.h"
#include "include/train/train_loop_callback.h"
namespace mindspore {
@ -30,7 +30,7 @@ Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vecto
MS_LOG(ERROR) << "Model implement or dataset is null.";
return kLiteUninitializedObj;
}
auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));
auto loop = std::unique_ptr<lite::TrainLoop>(lite::TrainLoop::CreateTrainLoop((impl_->session_).get()));
if (loop == nullptr) {
MS_LOG(ERROR) << "Error during allocation of train loop";
return kLiteNullptr;
@ -47,8 +47,8 @@ Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vecto
(void)loop->Init(metrics);
// Convert Callbacks to be used by loop
std::vector<session::TrainLoopCallBack *> cbs;
std::vector<session::TrainLoopCallBack *> adapter_cbs;
std::vector<lite::TrainLoopCallBack *> cbs;
std::vector<lite::TrainLoopCallBack *> adapter_cbs;
status = impl_->ConvertCallbacks(this, &i_cbs, &cbs, &adapter_cbs);
if (status != kSuccess) {
MS_LOG(ERROR) << "Error during preparation of callbacks";
@ -70,7 +70,7 @@ Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCa
return kLiteUninitializedObj;
}
auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));
auto loop = std::unique_ptr<lite::TrainLoop>(lite::TrainLoop::CreateTrainLoop((impl_->session_).get()));
if (loop == nullptr) {
MS_LOG(ERROR) << "Error during allocation of train loop";
return kLiteNullptr;
@ -87,8 +87,8 @@ Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCa
(void)loop->Init(metrics);
// Convert Callbacks to be used by loop
std::vector<session::TrainLoopCallBack *> cbs;
std::vector<session::TrainLoopCallBack *> adapter_cbs;
std::vector<lite::TrainLoopCallBack *> cbs;
std::vector<lite::TrainLoopCallBack *> adapter_cbs;
status = impl_->ConvertCallbacks(this, &i_cbs, &cbs, &adapter_cbs);
if (status != kSuccess) {
MS_LOG(ERROR) << "Error during preparation of callbacks";

View File

@ -68,8 +68,8 @@ Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *
}
Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
std::vector<session::TrainLoopCallBack *> *o_cbs,
std::vector<session::TrainLoopCallBack *> *adapter_cbs) {
std::vector<lite::TrainLoopCallBack *> *o_cbs,
std::vector<lite::TrainLoopCallBack *> *adapter_cbs) {
if (i_cbs == nullptr || o_cbs == nullptr || adapter_cbs == nullptr) {
MS_LOG(ERROR) << "Null input callbacks";
return kLiteUninitializedObj;

View File

@ -1510,7 +1510,7 @@ int LiteSession::InitGPURuntime() {
}
} // namespace lite
session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
lite::LiteSession *lite::LiteSession::CreateSession(const lite::Context *context) {
if (context == nullptr) {
return nullptr;
}
@ -1536,9 +1536,8 @@ session::LiteSession *session::LiteSession::CreateSession(const lite::Context *c
return session;
}
session::LiteSession *session::LiteSession::CreateSession(const char *model_buf, size_t size,
const lite::Context *context) {
auto *session = LiteSession::CreateSession(context);
lite::LiteSession *lite::LiteSession::CreateSession(const char *model_buf, size_t size, const lite::Context *context) {
auto *session = lite::LiteSession::CreateSession(context);
if (session == nullptr) {
MS_LOG(ERROR) << "Create session failed";
return nullptr;
@ -1553,8 +1552,8 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
return session;
}
session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) {
auto *session = session::LiteSession::CreateSession(context);
lite::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) {
auto *session = lite::LiteSession::CreateSession(context);
if (session == nullptr) {
MS_LOG(ERROR) << "Create session failed";
return nullptr;

View File

@ -25,7 +25,6 @@
#include <atomic>
#include "src/runtime/kernel_exec.h"
#include "include/ms_tensor.h"
#include "include/lite_session.h"
#include "src/runtime/lite_model.h"
#include "src/runtime/inner_context.h"
#include "src/runtime/runtime_allocator.h"
@ -41,11 +40,13 @@
namespace mindspore {
namespace lite {
class LiteSession : public session::LiteSession {
class LiteSession {
public:
LiteSession();
~LiteSession() override;
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
virtual ~LiteSession();
static LiteSession *CreateSession(const lite::Context *context);
static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);
static LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size);
int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size,
const std::shared_ptr<mindspore::Context> &ms_context);
@ -62,19 +63,19 @@ class LiteSession : public session::LiteSession {
static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
const std::shared_ptr<mindspore::Context> &ms_context);
virtual int Init(InnerContext *context);
void BindThread(bool if_bind) override;
int CompileGraph(Model *model) override;
std::vector<mindspore::tensor::MSTensor *> GetInputs() const override;
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &name) const override;
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
std::vector<mindspore::tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const override;
std::vector<std::string> GetOutputTensorNames() const override;
mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const override;
std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override;
int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
std::map<std::string, unsigned int> *outputGLTexture) override;
int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims) override;
virtual void BindThread(bool if_bind);
virtual int CompileGraph(Model *model);
virtual std::vector<mindspore::tensor::MSTensor *> GetInputs() const;
virtual mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &name) const;
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr);
virtual std::vector<mindspore::tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const;
virtual std::vector<std::string> GetOutputTensorNames() const;
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const;
virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const;
virtual int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
std::map<std::string, unsigned int> *outputGLTexture);
virtual int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims);
void InitExecutionConfig(std::map<std::string, TypeId> *config) { execution_plan_ = config; }
void set_model(Model *model) { this->model_ = model; }
const std::vector<kernel::KernelExec *> &get_kernels() const { return this->kernels_; }
@ -84,6 +85,41 @@ class LiteSession : public session::LiteSession {
}
const std::vector<Tensor *> &GetTensors() const { return this->tensors_; }
virtual int Train() { return mindspore::lite::RET_ERROR; }
virtual bool IsTrain() { return false; }
virtual int Eval() { return mindspore::lite::RET_OK; }
virtual bool IsEval() { return true; }
virtual int SetLearningRate(float learning_rate) { return mindspore::lite::RET_ERROR; }
virtual float GetLearningRate() { return 0.0; }
virtual int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) {
return mindspore::lite::RET_ERROR;
}
virtual std::vector<tensor::MSTensor *> GetPredictions() const {
std::vector<tensor::MSTensor *> outputs;
return outputs;
}
virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
std::vector<std::string> out_put_tensor_name = {}) {
return mindspore::lite::RET_ERROR;
}
virtual int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) { return mindspore::lite::RET_ERROR; }
virtual std::vector<tensor::MSTensor *> GetFeatureMaps() const {
std::vector<tensor::MSTensor *> features;
return features;
}
virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) { return mindspore::lite::RET_ERROR; }
virtual std::vector<tensor::MSTensor *> GetGradients() const {
std::vector<tensor::MSTensor *> gradients;
return gradients;
}
virtual int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) { return mindspore::lite::RET_ERROR; }
virtual std::vector<tensor::MSTensor *> GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
return params;
}
virtual int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) { return mindspore::lite::RET_ERROR; }
protected:
static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);
int CheckTensorValid(lite::Tensor *dst_tensor);

View File

@ -23,17 +23,17 @@
#include <fstream>
#include <memory>
#include "include/errorcode.h"
#include "include/train/train_loop.h"
#include "src/train/train_loop.h"
#include "src/common/utils.h"
#include "src/tensor.h"
namespace mindspore {
namespace lite {
void AccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
void AccuracyMonitor::Begin(const lite::TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) accuracies_.clear();
}
int AccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int AccuracyMonitor::EpochEnd(const lite::TrainLoopCallBackData &cb_data) {
if ((static_cast<int>(cb_data.epoch_) + 1) % check_every_n_ == 0) {
auto ret = cb_data.loop_->Eval(ds_, {}, nullptr, max_steps_);
if (ret != RET_OK) {
@ -42,7 +42,7 @@ int AccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
}
}
accuracies_.push_back(std::make_pair(cb_data.epoch_, 0.0));
return mindspore::session::RET_CONTINUE;
return RET_CONTINUE;
}
} // namespace lite
} // namespace mindspore

View File

@ -30,11 +30,11 @@ ClassificationTrainAccuracyMonitor::ClassificationTrainAccuracyMonitor(int print
print_every_n_ = print_every_n;
}
void ClassificationTrainAccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
void ClassificationTrainAccuracyMonitor::Begin(const TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) accuracies_.clear();
}
void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
void ClassificationTrainAccuracyMonitor::EpochBegin(const TrainLoopCallBackData &cb_data) {
if (accuracies_.size() != cb_data.epoch_) {
MS_LOG(WARNING) << "Accuracies array does not match epoch number";
} else {
@ -42,16 +42,16 @@ void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCall
}
}
int ClassificationTrainAccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int ClassificationTrainAccuracyMonitor::EpochEnd(const TrainLoopCallBackData &cb_data) {
if (cb_data.step_ > 0) accuracies_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
if ((static_cast<int>(cb_data.epoch_) + 1) % print_every_n_ == 0) {
std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tTraining Accuracy is "
<< accuracies_.at(cb_data.epoch_).second << std::endl;
}
return mindspore::session::RET_CONTINUE;
return RET_CONTINUE;
}
void ClassificationTrainAccuracyMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) {
void ClassificationTrainAccuracyMonitor::StepEnd(const TrainLoopCallBackData &cb_data) {
auto inputs = cb_data.session_->GetInputs();
auto outputs = cb_data.session_->GetPredictions();

View File

@ -26,11 +26,11 @@
namespace mindspore {
namespace lite {
void LossMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
void LossMonitor::Begin(const TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) losses_.clear();
}
void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
void LossMonitor::EpochBegin(const TrainLoopCallBackData &cb_data) {
if (losses_.size() != cb_data.epoch_) {
MS_LOG(WARNING) << "losses array does not match epoch number";
} else {
@ -38,15 +38,15 @@ void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
}
}
int LossMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int LossMonitor::EpochEnd(const TrainLoopCallBackData &cb_data) {
if (cb_data.step_ > 0) losses_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
if (print_every_n_ > 0) {
std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl;
}
return mindspore::session::RET_CONTINUE;
return RET_CONTINUE;
}
void LossMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) {
void LossMonitor::StepEnd(const TrainLoopCallBackData &cb_data) {
auto outputs = cb_data.session_->GetOutputs();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == 1) {

View File

@ -55,7 +55,7 @@ int StepLRLambda(float *lr, int epoch, void *lr_cb_data) {
LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step)
: lambda_func_(lambda_func), lr_data_(lr_cb_data), step_(step) {}
int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int LRScheduler::EpochEnd(const TrainLoopCallBackData &cb_data) {
if (((static_cast<int>(cb_data.epoch_) + 1) % step_) == 0) {
float lr = cb_data.session_->GetLearningRate();
int update = lambda_func_(&lr, cb_data.epoch_, lr_data_);
@ -63,11 +63,11 @@ int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int ret = cb_data.session_->SetLearningRate(lr);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Error setting Leraning rate in train session";
return mindspore::session::RET_EXIT;
return RET_EXIT;
}
}
}
return mindspore::session::RET_CONTINUE;
return RET_CONTINUE;
}
} // namespace lite
} // namespace mindspore

View File

@ -29,13 +29,10 @@ namespace lite {
using dataset::Dataset;
using dataset::Iterator;
using dataset::MSTensorVec;
using session::RET_CONTINUE;
using session::RET_EXIT;
using session::RET_STOP_TRAINING;
TrainLoop::~TrainLoop() {}
int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
int TrainLoop::Train(int epochs, Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
MS_CHECK_GE(epochs, 0, RET_ERROR);
auto ret = train_session_->Train();
@ -43,7 +40,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
MS_LOG(ERROR) << "TrainLoop train failed";
return RET_ERROR;
}
session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
if (load_func == nullptr) load_func = TrainLoop::LoadData;
@ -106,14 +103,14 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
return RET_OK;
}
int TrainLoop::Eval(Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
int TrainLoop::Eval(Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
auto ret = train_session_->Eval();
if (ret != RET_OK) {
MS_LOG(ERROR) << "TrainLoop train failed";
return RET_ERROR;
}
session::TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
if (load_func == nullptr) load_func = TrainLoop::LoadData;
@ -184,7 +181,7 @@ int TrainLoop::LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTenso
}
} // namespace lite
session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::LiteSession *train_session) {
lite::TrainLoop *session::TrainLoop::CreateTrainLoop(lite::LiteSession *train_session) {
auto loop = new (std::nothrow) lite::TrainLoop(train_session);
return loop;
}

View File

@ -32,9 +32,9 @@ namespace lite {
class TrainLoop : virtual public session::TrainLoop {
public:
explicit TrainLoop(session::LiteSession *session) : train_session_(session) {}
explicit TrainLoop(lite::LiteSession *session) : train_session_(session) {}
const session::LiteSession *train_session() override { return train_session_; }
const lite::LiteSession *train_session() override { return train_session_; }
int Reset() override {
epoch_ = 0;
@ -54,9 +54,9 @@ class TrainLoop : virtual public session::TrainLoop {
return RET_OK;
}
int Train(int epochs, dataset::Dataset *dataset, std::vector<session::TrainLoopCallBack *> cbs,
int Train(int epochs, dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
LoadDataFunc load_func = nullptr) override;
int Eval(dataset::Dataset *dataset, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func = nullptr,
int Eval(dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs, LoadDataFunc load_func = nullptr,
int max_steps = 0) override;
std::vector<mindspore::session::Metrics *> GetMetrics() override { return metrics_; }
@ -64,7 +64,7 @@ class TrainLoop : virtual public session::TrainLoop {
protected:
static int LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *dataset_vec);
session::LiteSession *train_session_ = nullptr;
lite::LiteSession *train_session_ = nullptr;
unsigned int epoch_ = 0;
KernelCallBack before_cb_ = nullptr;
KernelCallBack after_cb_ = nullptr;

View File

@ -1293,8 +1293,8 @@ size_t TrainSession::GetInplaceTensorOffset(kernel::KernelExec *kernel,
}
} // namespace lite
session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
lite::LiteSession *lite::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
if (context == nullptr) {
MS_LOG(ERROR) << "context cannot be nullptr";
return nullptr;

View File

@ -22,16 +22,11 @@
#include <memory>
#include <map>
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"
#include "src/runtime/lite_session.h"
/*
Inheritance Diagram
+-------------------------------+
| session::LiteSession |
+------------------------------+
|
+--------------+----------------+
| lite::LiteSession |
+------------------------------+
@ -49,6 +44,11 @@ class TrainSession : virtual public lite::LiteSession {
TrainSession();
~TrainSession();
static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
const lite::Context *context, bool train_mode = false,
const lite::TrainCfg *cfg = nullptr);
static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
int CompileGraph(lite::Model *model) override;

View File

@ -42,7 +42,7 @@ TransferSession::TransferSession(const char *model_buf_backbone, size_t size_bac
if (lite_model_ != nullptr) {
std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_);
backbone_session_ =
reinterpret_cast<lite::LiteSession *>(session::LiteSession::CreateSession(lite_model_, size_backbone, context));
reinterpret_cast<LiteSession *>(LiteSession::CreateSession(lite_model_, size_backbone, context));
if (backbone_session_ != nullptr) {
is_valid_ = true;
} else {
@ -312,10 +312,10 @@ lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size
}
} // namespace lite
session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
const std::string &filename_head,
const lite::Context *ctxt, bool train_mode,
const lite::TrainCfg *cfg) {
lite::LiteSession *lite::TrainSession::CreateTransferSession(const std::string &filename_backbone,
const std::string &filename_head,
const lite::Context *ctxt, bool train_mode,
const lite::TrainCfg *cfg) {
size_t size_head = 0;
size_t size_backbone = 0;
std::string filename = filename_head;

View File

@ -65,7 +65,7 @@ class TransferSession : public lite::TrainSession {
std::vector<std::string> out_put_tensor_name = {}) override;
protected:
lite::LiteSession *backbone_session_ = nullptr;
LiteSession *backbone_session_ = nullptr;
char *lite_model_ = nullptr;
std::vector<mindspore::tensor::MSTensor *> combined_inputs_;
std::vector<std::pair<mindspore::tensor::MSTensor *, mindspore::tensor::MSTensor *>> backbone_head_map_;

View File

@ -46,7 +46,7 @@ TEST_F(GraphTest, UserSetGraphOutput1) {
auto context = std::make_shared<lite::Context>();
ASSERT_NE(context, nullptr);
session::LiteSession *session = session::LiteSession::CreateSession(context.get());
lite::LiteSession *session = lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);
int benchmark_ret = session->CompileGraph(model.get());

View File

@ -25,9 +25,8 @@ class MindrtRuntimeTest : public mindspore::CommonTest {
MindrtRuntimeTest() = default;
};
int CheckRuntime(mindspore::session::LiteSession *session) {
mindspore::lite::LiteSession *lite_session = reinterpret_cast<mindspore::lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
int CheckRuntime(mindspore::lite::LiteSession *session) {
auto kernels = session->get_kernels();
int cpu_kernel_count = 0;
int gpu_kernel_count = 0;
@ -69,7 +68,7 @@ TEST_F(MindrtRuntimeTest, Runtime) {
gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = false;
context->device_list_.push_back(gpu_device_ctx);
mindspore::session::LiteSession *session = mindspore::session::LiteSession::CreateSession(context.get());
mindspore::lite::LiteSession *session = mindspore::lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);
int benchmark_ret = session->CompileGraph(model.get());
@ -87,9 +86,8 @@ TEST_F(MindrtRuntimeTest, Runtime) {
delete session;
}
int CheckRuntime2(mindspore::session::LiteSession *session) {
mindspore::lite::LiteSession *lite_session = reinterpret_cast<mindspore::lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
int CheckRuntime2(mindspore::lite::LiteSession *session) {
auto kernels = session->get_kernels();
for (auto kernel : kernels) {
if (kernel->subgraph_type() != mindspore::kernel::kCpuFP16SubGraph) {
@ -119,7 +117,7 @@ TEST_F(MindrtRuntimeTest, RuntimeFp16) {
auto &cpu_device_ctx = context->device_list_[0];
cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = true;
mindspore::session::LiteSession *session = mindspore::session::LiteSession::CreateSession(context.get());
mindspore::lite::LiteSession *session = mindspore::lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);
int benchmark_ret = session->CompileGraph(model.get());

View File

@ -31,7 +31,7 @@ class MindrtParallelTest : public mindspore::CommonTest {
MindrtParallelTest() {}
};
int CheckOffline1(session::LiteSession *session) {
int CheckOffline1(lite::LiteSession *session) {
/* ----------- start check -------------- */
lite::LiteSession *lite_session = reinterpret_cast<lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
@ -86,7 +86,7 @@ int CheckOffline1(session::LiteSession *session) {
return lite::RET_OK;
}
int CheckRuntime1(session::LiteSession *session) {
int CheckRuntime1(lite::LiteSession *session) {
lite::LiteSession *lite_session = reinterpret_cast<lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
if (kernels.size() != 6) {
@ -115,7 +115,7 @@ TEST_F(MindrtParallelTest, offline1) {
ASSERT_NE(context, nullptr);
context->enable_parallel_ = true;
session::LiteSession *session = session::LiteSession::CreateSession(context.get());
lite::LiteSession *session = lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);
int benchmark_ret = session->CompileGraph(model.get());
@ -153,7 +153,7 @@ TEST_F(MindrtParallelTest, runtime1) {
ASSERT_NE(context, nullptr);
context->enable_parallel_ = true;
session::LiteSession *session = session::LiteSession::CreateSession(context.get());
lite::LiteSession *session = lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);
int benchmark_ret = session->CompileGraph(model.get());

View File

@ -326,7 +326,7 @@ TEST_F(SubGraphTest, RecursiveSubGraphTest) {
auto &cpu_device_ctx = context.device_list_[0];
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
context.thread_num_ = 2;
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
auto session = std::shared_ptr<lite::LiteSession>(lite::LiteSession::CreateSession(&context));
ASSERT_NE(session, nullptr);
auto ret = session->CompileGraph(model.get());
ASSERT_EQ(ret, lite::RET_OK);

View File

@ -102,7 +102,7 @@ TEST_F(InferTest, TestConvNode) {
device_list.push_back(device_ctx);
context->thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, context->Init());
auto session = session::LiteSession::CreateSession(context);
auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(nullptr, session);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);
@ -198,7 +198,7 @@ TEST_F(InferTest, TestAddNode) {
device_list.push_back(device_ctx);
context->thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, context->Init());
auto session = session::LiteSession::CreateSession(context);
auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(nullptr, session);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);
@ -234,7 +234,7 @@ TEST_F(InferTest, TestModel) {
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context->thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, context->Init());
auto session = session::LiteSession::CreateSession(context);
auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(nullptr, session);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);

View File

@ -27,7 +27,7 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"
#include "src/train/train_session.h"
#include "src/common/log_adapter.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel_registry.h"
@ -38,12 +38,12 @@ namespace mindspore {
class NetworkTest : public mindspore::CommonTest {
public:
NetworkTest() {}
int32_t runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out,
int32_t runNet(mindspore::lite::LiteSession *session, const std::string &in, const std::string &out,
const char *tensor_name, bool debug = false);
};
int32_t fileIterator(mindspore::session::LiteSession *session, const std::string &path,
std::function<int32_t(mindspore::session::LiteSession *session, const std::string &)> cb) {
int32_t fileIterator(mindspore::lite::LiteSession *session, const std::string &path,
std::function<int32_t(mindspore::lite::LiteSession *session, const std::string &)> cb) {
int32_t res = 0;
if (auto dir = opendir(path.c_str())) {
while (auto f = readdir(dir)) {
@ -58,7 +58,7 @@ int32_t fileIterator(mindspore::session::LiteSession *session, const std::string
}
void replaceExt(const std::string &src, std::string *dst) { *dst = src.substr(0, src.find_last_of('.')) + ".emb"; }
int32_t NetworkTest::runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out,
int32_t NetworkTest::runNet(mindspore::lite::LiteSession *session, const std::string &in, const std::string &out,
const char *tensor_name, bool debug) {
// setup input
auto inputs = session->GetInputs();
@ -101,7 +101,7 @@ TEST_F(NetworkTest, efficient_net) {
context->thread_num_ = 1;
std::string net = "./nets/effnetb0_fwd_nofuse.ms";
auto session = session::TrainSession::CreateTrainSession(net, context, false);
auto session = lite::TrainSession::CreateTrainSession(net, context, false);
ASSERT_NE(session, nullptr);
std::string in = "./nets/effNet_input_x_1_3_224_224.bin";
@ -122,7 +122,7 @@ TEST_F(NetworkTest, mobileface_net) {
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context->thread_num_ = 1;
auto session = session::LiteSession::CreateSession(context);
auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(session, nullptr);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);
@ -147,7 +147,7 @@ TEST_F(NetworkTest, noname) {
lite::TrainCfg cfg;
cfg.loss_name_.clear();
cfg.loss_name_.emplace_back("nhwc");
auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &cfg);
auto session = mindspore::lite::TrainSession::CreateTrainSession(net, &context, true, &cfg);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();
auto tensor_names = session->GetOutputTensorNames();
@ -168,7 +168,7 @@ TEST_F(NetworkTest, setname) {
train_cfg.loss_name_.clear();
train_cfg.loss_name_.emplace_back("nhwc");
auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg);
auto session = mindspore::lite::TrainSession::CreateTrainSession(net, &context, true, &train_cfg);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();

View File

@ -68,7 +68,6 @@ add_executable(benchmark
${CMAKE_CURRENT_SOURCE_DIR}/main.cc
${CMAKE_CURRENT_SOURCE_DIR}/run_benchmark.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_base.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_unified_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_c_api.cc
${COMMON_SRC})

View File

@ -17,7 +17,6 @@
#include "tools/benchmark/run_benchmark.h"
#include <string>
#include <memory>
#include "tools/benchmark/benchmark.h"
#include "tools/benchmark/benchmark_unified_api.h"
#include "tools/benchmark/benchmark_c_api.h"
@ -51,8 +50,6 @@ int RunBenchmark(int argc, const char **argv) {
BenchmarkBase *benchmark = nullptr;
if (api_type == nullptr || std::string(api_type) == "NEW") {
benchmark = new (std::nothrow) BenchmarkUnifiedApi(&flags);
} else if (std::string(api_type) == "OLD") {
benchmark = new (std::nothrow) Benchmark(&flags);
} else if (std::string(api_type) == "C") {
benchmark = new (std::nothrow) tools::BenchmarkCApi(&flags);
} else {