forked from mindspore-Ecosystem/mindspore

commit f35e313d33
!34451 [MS][LITE] remove lite session api
Merge pull request !34451 from sunsuodong/rm_lite_session

@@ -192,6 +192,7 @@ mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
mindspore/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
mindspore/mindspore/lite/src/train/train_loop.cc:mindspore::lite::TrainLoop::Train
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc:mindspore::opt::MatchAdd5Pattern
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512
@@ -1,254 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_

#ifndef NOT_USE_STL
#include <unordered_map>
#endif // NOT_USE_STL
#include <vector>
#include <string>
#include <map>
#include "include/ms_tensor.h"
#include "include/model.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_types.h"

namespace mindspore {
namespace lite {
class TrainCfg;
}

namespace session {
/// \brief LiteSession defined session in MindSpore Lite for compiling Model and forwarding model.
class MS_API LiteSession {
public:
/// \brief Static method to create a LiteSession pointer.
///
/// \param[in] context Define the context of session to be created.
///
/// \return Pointer of MindSpore Lite LiteSession.
static LiteSession *CreateSession(const lite::Context *context);

/// \brief Static method to create a LiteSession pointer which has already compiled a model.
///
/// \param[in] model_buf Define the buffer read from a model file.
/// \param[in] size Define bytes number of model buffer.
/// \param[in] context Define the context of session to be created.
///
/// \return Pointer of MindSpore Lite LiteSession.
static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);

/// \brief Destructor of MindSpore Lite LiteSession.
virtual ~LiteSession() = default;

/// \brief Attempt to bind or unbind threads in the thread pool to or from the specified cpu core.
///
/// \param[in] if_bind Define whether to bind or unbind threads.
virtual void BindThread(bool if_bind) = 0;

/// \brief Compile MindSpore Lite model.
///
/// \note CompileGraph should be called before RunGraph.
///
/// \param[in] model Define the model to be compiled.
///
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h.
virtual int CompileGraph(lite::Model *model) = 0;

/// \brief Get input MindSpore Lite MSTensors of model.
///
/// \return The vector of MindSpore Lite MSTensor.
virtual Vector<tensor::MSTensor *> GetInputs() const = 0;

/// \brief Get input MindSpore Lite MSTensors of model by tensor name.
///
/// \param[in] node_name Define tensor name.
///
/// \return The vector of MindSpore Lite MSTensor.
virtual mindspore::tensor::MSTensor *GetInputsByTensorName(const String &tensor_name) const = 0;

/// \brief Run session with callback.
///
/// \param[in] before Define a call_back_function to be called before running each node.
/// \param[in] after Define a call_back_function called after running each node.
///
/// \note RunGraph should be called after CompileGraph.
///
/// \return STATUS as an error code of running graph, STATUS is defined in errorcode.h.
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0;

/// \brief Get output MindSpore Lite MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \note Deprecated, replace with GetOutputByTensorName
///
/// \return The vector of MindSpore Lite MSTensor.
virtual Vector<tensor::MSTensor *> GetOutputsByNodeName(const String &node_name) const = 0;

#ifndef NOT_USE_STL
/// \brief Get output MindSpore Lite MSTensors of model mapped by tensor name.
///
/// \return The map of output tensor name and MindSpore Lite MSTensor.
virtual std::unordered_map<String, mindspore::tensor::MSTensor *> GetOutputs() const = 0;
#endif

/// \brief Get name of output tensors of model compiled by this session.
///
/// \return The vector of string as output tensor names in order.
virtual Vector<String> GetOutputTensorNames() const = 0;

/// \brief Get output MindSpore Lite MSTensors of model by tensor name.
///
/// \param[in] tensor_name Define tensor name.
///
/// \return Pointer of MindSpore Lite MSTensor.
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const String &tensor_name) const = 0;

/// \brief Bind GLTexture2D object to cl Memory.
///
/// \param[in] inputGlTexture The input GLTexture id for Model.
/// \param[in] outputGLTexture The output GLTexture id for Model.
///
/// \return Status of operation.
virtual int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGlTexture,
std::map<std::string, unsigned int> *outputGLTexture) = 0;

/// \brief Resize inputs shape.
///
/// \param[in] inputs Define the inputs of the model.
/// \param[in] dims Define the inputs new shape.
///
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
virtual int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) = 0;

/// \brief Set model to train mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Train() { return mindspore::lite::RET_ERROR; }

/// \brief Check mode of model
///
/// \return boolean indication if model is in train mode
virtual bool IsTrain() { return false; }

/// \brief Set model to eval mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Eval() { return mindspore::lite::RET_OK; }

/// \brief Check mode of model
///
/// \return boolean indication if model is in eval mode
virtual bool IsEval() { return true; }

/// \brief Sets the Learning Rate of the training
///
/// \param[in] learning_rate to set
///
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetLearningRate(float learning_rate) { return mindspore::lite::RET_ERROR; }

/// \brief Gets the Learning Rate of the training
///
/// \return learning rate. 0.0 if no optimizer was found
virtual float GetLearningRate() { return 0.0; }

/// \brief Setup training with virtual batches
///
/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable
/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration
/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) {
return mindspore::lite::RET_ERROR;
}

/// \brief Get output MindSpore Lite MSTensors of Training model prediction
///
/// \return a vector of output tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetPredictions() const {
std::vector<tensor::MSTensor *> outputs;
return outputs;
}

/// \brief Save model
/// \param[in] file_name pretrained model file name prefix. '.ms' extenension is added if does not exist
/// \param[in] model_type indication whether to save full model or only the inference part
/// \param[in] quant_type indication whether to quantize exported model
/// \param[in] format of exported file (currently only FT_FLATBUFFERS is supported)
/// \param[in] out_put_tensor_name of exported tensorname
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
std::vector<std::string> out_put_tensor_name = {}) {
return mindspore::lite::RET_ERROR;
}

/// \brief Change the size and or content of weight tensors
///
/// \param[in] new_weights a vector of tensors with new shapes and data to use in the model
/// If data pointer is null, the data of the original tensors will be copied to the new ones
///
/// \return STATUS as an error code of operation, STATUS is defined in errorcode.h.
virtual int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) { return mindspore::lite::RET_ERROR; }

/// \brief Get model featuremap MindSpore Lite MSTensors of Training model prediction
///
/// \return a vector of output tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetFeatureMaps() const {
std::vector<tensor::MSTensor *> features;
return features;
}

/// \brief update model featuremap save to update_ms_file
/// \param[in] features new featuremap
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) { return mindspore::lite::RET_ERROR; }

/// \brief Get model gradient
///
/// \return a vector of gradient tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetGradients() const {
std::vector<tensor::MSTensor *> gradients;
return gradients;
}

/// \brief update model gradient
///
/// \param[in] new gradients
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) { return mindspore::lite::RET_ERROR; }

/// \brief Get model optimizer params
///
/// \return a vector of optimizer parameters (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
return params;
}

/// \brief set model optimizer params
///
/// \param[in] new optimizer params
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) { return mindspore::lite::RET_ERROR; }
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
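For context, a minimal sketch of how client code called the session interface that the hunk above deletes. This is not part of the commit: the function name RunOnce, the default-constructed lite::Context, and the omitted input-filling step are illustrative assumptions; the CreateSession overload, RunGraph, and RET_OK/RET_ERROR come from the removed header itself.

// Illustrative only (not from this commit): calling the removed
// mindspore::session::LiteSession API declared in the header above.
// Assumes model_buf/size hold the bytes of a .ms model; loading is omitted.
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_session.h"

int RunOnce(const char *model_buf, size_t size) {
  mindspore::lite::Context context;  // default CPU context (assumed defaults)
  // This overload compiles the model internally, per the doc comment above.
  auto *session = mindspore::session::LiteSession::CreateSession(model_buf, size, &context);
  if (session == nullptr) {
    return mindspore::lite::RET_ERROR;
  }
  // ... copy input data into session->GetInputs() here ...
  int ret = session->RunGraph();  // forward pass; RET_OK on success
  delete session;
  return ret;
}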
@@ -24,13 +24,13 @@
namespace mindspore {
namespace lite {

class AccuracyMonitor : public session::TrainLoopCallBack {
class AccuracyMonitor : public TrainLoopCallBack {
public:
explicit AccuracyMonitor(mindspore::dataset::Dataset *dataset, int check_every_n, int max_steps = -1)
: ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}
~AccuracyMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;
void Begin(const TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }

private:

@@ -25,20 +25,20 @@
namespace mindspore {
namespace lite {

class CkptSaver : public session::TrainLoopCallBack {
class CkptSaver : public TrainLoopCallBack {
public:
CkptSaver(size_t save_every_n, std::string filename_prefix)
: save_every_n_(save_every_n), filename_prefix_(std::move(filename_prefix)) {}

~CkptSaver() override = default;

int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
int EpochEnd(const TrainLoopCallBackData &cb_data) override {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str());
cb_data.session_->Export(cpkt_fn);
}
return session::RET_CONTINUE;
return RET_CONTINUE;
}

private:

@@ -27,7 +27,7 @@
namespace mindspore {
namespace lite {

class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
class ClassificationTrainAccuracyMonitor : public TrainLoopCallBack {
public:
explicit ClassificationTrainAccuracyMonitor(int print_every_n = INT_MAX,
int accuracy_metrics = METRICS_CLASSIFICATION,

@@ -35,10 +35,10 @@ class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
const std::vector<int> &output_indexes = {0});
virtual ~ClassificationTrainAccuracyMonitor() = default;

void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
void Begin(const TrainLoopCallBackData &cb_data) override;
void EpochBegin(const TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;
void StepEnd(const TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }

private:

@@ -25,7 +25,7 @@
namespace mindspore {
namespace lite {

class LossMonitor : public session::TrainLoopCallBack {
class LossMonitor : public TrainLoopCallBack {
public:
/// \brief constructor
///

@@ -38,10 +38,10 @@ class LossMonitor : public session::TrainLoopCallBack {
/// \return 0 on success or -1 in case of error
explicit LossMonitor(int print_every_n_steps = INT_MAX) : print_every_n_(print_every_n_steps) {}
virtual ~LossMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
void Begin(const TrainLoopCallBackData &cb_data) override;
void EpochBegin(const TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;
void StepEnd(const TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetLossPoints() const { return losses_; }

private:

@@ -42,11 +42,11 @@ struct StepLRLambda {
float gamma; // LR decay factor
};

class LRScheduler : public session::TrainLoopCallBack {
class LRScheduler : public TrainLoopCallBack {
public:
explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step_ = 1);
virtual ~LRScheduler() = default;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const TrainLoopCallBackData &cb_data) override;

private:
LR_Lambda lambda_func_;

@@ -21,7 +21,7 @@
#include <unordered_map>
#include "include/train/train_loop_callback.h"
#include "include/train/metrics.h"
#include "include/lite_session.h"
#include "src/runtime/lite_session.h"

namespace mindspore {
class MSTensor;

@@ -42,7 +42,7 @@ class TrainLoop {
/// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API
///
/// \return Pointer of MindSpore Lite TrainLoop
static TrainLoop *CreateTrainLoop(session::LiteSession *train_session);
static lite::TrainLoop *CreateTrainLoop(lite::LiteSession *train_session);

/// \brief Class destructor
virtual ~TrainLoop() = default;

@@ -55,7 +55,7 @@ class TrainLoop {
/// \brief Accessor to the LiteSession
///
/// \return pointer of the train_session
const virtual session::LiteSession *train_session() = 0;
const virtual lite::LiteSession *train_session() = 0;

/// \brief Initialize object with metrics
///

@@ -85,7 +85,7 @@ class TrainLoop {
/// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model
///
/// \return 0 on success or -1 in case of error
virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs,
virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
LoadDataFunc load_func) = 0;

/// \brief Performs loop over all data in Eval Mode

@@ -96,8 +96,8 @@ class TrainLoop {
/// \param[in] max_steps (with default = INT_MAX the method iterates all dataset)
///
/// \return 0 on success or -1 in case of error
virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func,
int max_steps) = 0;
virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
LoadDataFunc load_func, int max_steps) = 0;
};
} // namespace session
} // namespace mindspore

@@ -23,8 +23,7 @@
#include "include/api/callback/callback.h"

namespace mindspore {
namespace session {

namespace lite {
class LiteSession;
class TrainLoop;

@@ -83,6 +82,6 @@ class TrainLoopCallBack {
virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
};

} // namespace session
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
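The hunks above move the training callback types from the session namespace into lite. A minimal sketch of a user-defined callback written against the relocated interface, following the CkptSaver/LRScheduler pattern shown in this diff: the class name EpochLogger is hypothetical, while TrainLoopCallBack, TrainLoopCallBackData, the epoch_ and session_ members, GetLearningRate, and RET_CONTINUE are taken from the hunks above.

// Illustrative only: a custom callback under the new lite:: namespace.
#include <iostream>
#include "include/train/train_loop_callback.h"

namespace mindspore {
namespace lite {
class EpochLogger : public TrainLoopCallBack {
 public:
  int EpochEnd(const TrainLoopCallBackData &cb_data) override {
    std::cout << "finished epoch " << (cb_data.epoch_ + 1)
              << ", learning rate " << cb_data.session_->GetLearningRate() << std::endl;
    return RET_CONTINUE;  // keep training
  }
};
}  // namespace lite
}  // namespace mindspore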
@@ -1,50 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
#include <string>
#include "include/lite_session.h"

namespace mindspore {
namespace session {
class TrainSession {
public:
/// \brief Static method to create a TransferSession object
///
/// \param[in] filename_backbone Filename to read backbone net flatbuffer from
/// \param[in] filename_head Filename to read head net flatbuffer from
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
const lite::Context *context, bool train_mode = false,
const lite::TrainCfg *cfg = nullptr);

/// \brief Static method to create a TrainSession object
///
/// \param[in] filename name of flatbuffer that holds the flatbuffer
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
/// \param[in] cfg training configuration, set to null for default configuration
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
@@ -79,7 +79,7 @@ set(JNI_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp
${NEW_NATIVE_DIR}/graph.cpp
${NEW_NATIVE_DIR}/model.cpp

@@ -110,7 +110,7 @@ endif()
if(SUPPORT_TRAIN)
set(LITE_TRAIN_SO_NAME mindspore-lite-train)
set(JNI_TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
${NEW_NATIVE_DIR}/train_config.cpp
)
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
@@ -17,7 +17,7 @@
#include <jni.h>
#include <fstream>
#include "common/ms_log.h"
#include "include/lite_session.h"
#include "src/runtime/lite_session.h"
#include "include/errorcode.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSessionWithModel(JNIEnv *env, jobject thiz,
jobject model_buffer,

@@ -37,7 +37,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSes
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
// create session
auto session = mindspore::session::LiteSession::CreateSession(model_buf, buffer_len, lite_context_ptr);
auto session = mindspore::lite::LiteSession::CreateSession(model_buf, buffer_len, lite_context_ptr);
if (session == nullptr) {
MS_LOGE("CreateSession failed");
return jlong(nullptr);

@@ -53,7 +53,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSes
return jlong(nullptr);
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::LiteSession::CreateSession(lite_context_ptr);
auto session = mindspore::lite::LiteSession::CreateSession(lite_context_ptr);
if (session == nullptr) {
MS_LOGE("CreateSession failed");
return jlong(nullptr);

@@ -69,7 +69,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compil
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
if (model_pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");

@@ -88,7 +88,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
lite_session_ptr->BindThread(if_bind);
}

@@ -99,7 +99,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGra
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto ret = lite_session_ptr->RunGraph();
return (jboolean)(ret == mindspore::lite::RET_OK);
}

@@ -118,7 +118,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetInputs();
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));

@@ -135,7 +135,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputs
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto input = lite_session_ptr->GetInputsByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
return jlong(input);
}

@@ -155,7 +155,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));

@@ -177,7 +177,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
MS_LOGE("Session pointer from java is nullptr");
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputs();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");

@@ -203,7 +203,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto output_names = lite_session_ptr->GetOutputTensorNames();
for (const auto &output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));

@@ -219,7 +219,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutput
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto output = lite_session_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
return jlong(output);
}

@@ -231,7 +231,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_free(JNIEn
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
delete (lite_session_ptr);
}

@@ -244,7 +244,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_resize
MS_LOGE("Session pointer from java is nullptr");
return false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);

auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);

@@ -300,7 +300,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_train(
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->Train();
return (jboolean)(ret == mindspore::lite::RET_OK);
}

@@ -312,7 +312,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_eval(J
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->Eval();
return (jboolean)(ret == mindspore::lite::RET_OK);
}

@@ -324,7 +324,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_isTrai
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->IsTrain();
return (jboolean)(ret);
}

@@ -336,7 +336,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_isEval
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->IsEval();
return (jboolean)(ret);
}

@@ -349,7 +349,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setLea
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->SetLearningRate(learning_rate);
return (jboolean)(ret == mindspore::lite::RET_OK);
}

@@ -364,7 +364,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupV
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *lite_session_ptr = static_cast<mindspore::lite::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum);
return (jboolean)(ret == mindspore::lite::RET_OK);
}

@@ -384,7 +384,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_update
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer);
newFeatures.emplace_back(ms_tensor_ptr);
}
auto session = reinterpret_cast<mindspore::session::LiteSession *>(session_ptr);
auto session = reinterpret_cast<mindspore::lite::LiteSession *>(session_ptr);
auto ret = session->UpdateFeatureMaps(newFeatures);
return (jboolean)(ret == mindspore::lite::RET_OK);
}

@@ -403,7 +403,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getFeat
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *train_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto *train_session_ptr = static_cast<mindspore::lite::LiteSession *>(pointer);
auto inputs = train_session_ptr->GetFeatureMaps();
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
@@ -16,7 +16,7 @@

#include <jni.h>
#include "common/ms_log.h"
#include "include/train/train_session.h"
#include "src/train/train_session.h"
#include "include/train/train_cfg.h"
#include "include/errorcode.h"

@@ -32,7 +32,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createTr
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);

auto session = mindspore::session::TrainSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
auto session = mindspore::lite::TrainSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
lite_context_ptr, train_mode, nullptr);
if (session == nullptr) {
MS_LOGE("CreateTrainSession failed");
@@ -22,32 +22,32 @@

namespace mindspore {

class TrainLoopCallBackAdapter : public session::TrainLoopCallBack {
class TrainLoopCallBackAdapter : public lite::TrainLoopCallBack {
public:
explicit TrainLoopCallBackAdapter(Model *model, TrainCallBack *call_back) : model_(model), call_back_(call_back) {}
TrainLoopCallBackAdapter() = delete;

void Begin(const session::TrainLoopCallBackData &i_cb_data) override {
void Begin(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->Begin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};

void End(const session::TrainLoopCallBackData &i_cb_data) override {
void End(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->End(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};

void EpochBegin(const session::TrainLoopCallBackData &i_cb_data) override {
void EpochBegin(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->EpochBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};

int EpochEnd(const session::TrainLoopCallBackData &i_cb_data) override {
int EpochEnd(const lite::TrainLoopCallBackData &i_cb_data) override {
return call_back_->EpochEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};

void StepBegin(const session::TrainLoopCallBackData &i_cb_data) override {
void StepBegin(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->StepBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};

void StepEnd(const session::TrainLoopCallBackData &i_cb_data) override {
void StepEnd(const lite::TrainLoopCallBackData &i_cb_data) override {
call_back_->StepEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
};

@@ -24,11 +24,11 @@ namespace mindspore {
class CallbackImpl {
public:
CallbackImpl() = delete;
explicit CallbackImpl(session::TrainLoopCallBack *cb) : internal_call_back_(cb) {}
session::TrainLoopCallBack *GetInternalCallback() { return internal_call_back_; }
explicit CallbackImpl(lite::TrainLoopCallBack *cb) : internal_call_back_(cb) {}
lite::TrainLoopCallBack *GetInternalCallback() { return internal_call_back_; }

protected:
session::TrainLoopCallBack *internal_call_back_ = nullptr;
lite::TrainLoopCallBack *internal_call_back_ = nullptr;
};
} // namespace mindspore

@@ -50,7 +50,7 @@ const std::vector<GraphPoint> &LossMonitor::GetLossPoints() {
return empty_vector;
}

session::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
lite::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
if (internal_call_back == nullptr) {
MS_LOG(ERROR) << "Internal callback is null.";
return empty_vector;

@@ -52,7 +52,7 @@ const std::vector<GraphPoint> &TrainAccuracy::GetAccuracyPoints() {
return empty_vector;
}

session::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
lite::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
if (internal_call_back == nullptr) {
MS_LOG(ERROR) << "Internal callback is null.";
return empty_vector;
@@ -30,6 +30,7 @@
#include "src/runtime/cxx_api/graph/graph_data.h"
#include "src/runtime/inner_context.h"
#include "src/runtime/lite_session.h"
#include "include/train/train_loop_callback.h"

template <class T>
void clearVectorOfPointers(std::vector<T> *v) {

@@ -105,13 +106,13 @@ class ModelImpl {
return kSuccess;
}
std::vector<Metrics *> GetMetrics() { return metrics_; }
const session::LiteSession *GetSession() const { return session_.get(); }
const lite::LiteSession *GetSession() const { return session_.get(); }

protected:
// Utility methods
Status ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
std::vector<session::TrainLoopCallBack *> *o_cbs,
std::vector<session::TrainLoopCallBack *> *adapter_cbs);
std::vector<lite::TrainLoopCallBack *> *o_cbs,
std::vector<lite::TrainLoopCallBack *> *adapter_cbs);
Status PrepareMetrics(Model *model, std::vector<session::Metrics *> *o_ms,
std::vector<session::Metrics *> *adapter_ms);

@@ -21,7 +21,7 @@
#include "src/runtime/cxx_api/model/model_impl.h"
#include "src/runtime/cxx_api/callback/callback_impl.h"
#include "src/common/log_adapter.h"
#include "include/train/train_loop.h"
#include "src/train/train_loop.h"
#include "include/train/train_loop_callback.h"

namespace mindspore {

@@ -30,7 +30,7 @@ Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vecto
MS_LOG(ERROR) << "Model implement or dataset is null.";
return kLiteUninitializedObj;
}
auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));
auto loop = std::unique_ptr<lite::TrainLoop>(lite::TrainLoop::CreateTrainLoop((impl_->session_).get()));
if (loop == nullptr) {
MS_LOG(ERROR) << "Error during allocation of train loop";
return kLiteNullptr;

@@ -47,8 +47,8 @@ Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vecto
(void)loop->Init(metrics);

// Convert Callbacks to be used by loop
std::vector<session::TrainLoopCallBack *> cbs;
std::vector<session::TrainLoopCallBack *> adapter_cbs;
std::vector<lite::TrainLoopCallBack *> cbs;
std::vector<lite::TrainLoopCallBack *> adapter_cbs;
status = impl_->ConvertCallbacks(this, &i_cbs, &cbs, &adapter_cbs);
if (status != kSuccess) {
MS_LOG(ERROR) << "Error during preparation of callbacks";

@@ -70,7 +70,7 @@ Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCa
return kLiteUninitializedObj;
}

auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));
auto loop = std::unique_ptr<lite::TrainLoop>(lite::TrainLoop::CreateTrainLoop((impl_->session_).get()));
if (loop == nullptr) {
MS_LOG(ERROR) << "Error during allocation of train loop";
return kLiteNullptr;

@@ -87,8 +87,8 @@ Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCa
(void)loop->Init(metrics);

// Convert Callbacks to be used by loop
std::vector<session::TrainLoopCallBack *> cbs;
std::vector<session::TrainLoopCallBack *> adapter_cbs;
std::vector<lite::TrainLoopCallBack *> cbs;
std::vector<lite::TrainLoopCallBack *> adapter_cbs;
status = impl_->ConvertCallbacks(this, &i_cbs, &cbs, &adapter_cbs);
if (status != kSuccess) {
MS_LOG(ERROR) << "Error during preparation of callbacks";

@@ -68,8 +68,8 @@ Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *
}

Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
std::vector<session::TrainLoopCallBack *> *o_cbs,
std::vector<session::TrainLoopCallBack *> *adapter_cbs) {
std::vector<lite::TrainLoopCallBack *> *o_cbs,
std::vector<lite::TrainLoopCallBack *> *adapter_cbs) {
if (i_cbs == nullptr || o_cbs == nullptr || adapter_cbs == nullptr) {
MS_LOG(ERROR) << "Null input callbacks";
return kLiteUninitializedObj;
@@ -1510,7 +1510,7 @@ int LiteSession::InitGPURuntime() {
}
} // namespace lite

session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
lite::LiteSession *lite::LiteSession::CreateSession(const lite::Context *context) {
if (context == nullptr) {
return nullptr;
}

@@ -1536,9 +1536,8 @@ session::LiteSession *session::LiteSession::CreateSession(const lite::Context *c
return session;
}

session::LiteSession *session::LiteSession::CreateSession(const char *model_buf, size_t size,
const lite::Context *context) {
auto *session = LiteSession::CreateSession(context);
lite::LiteSession *lite::LiteSession::CreateSession(const char *model_buf, size_t size, const lite::Context *context) {
auto *session = lite::LiteSession::CreateSession(context);
if (session == nullptr) {
MS_LOG(ERROR) << "Create session failed";
return nullptr;

@@ -1553,8 +1552,8 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
return session;
}

session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) {
auto *session = session::LiteSession::CreateSession(context);
lite::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) {
auto *session = lite::LiteSession::CreateSession(context);
if (session == nullptr) {
MS_LOG(ERROR) << "Create session failed";
return nullptr;
@@ -25,7 +25,6 @@
#include <atomic>
#include "src/runtime/kernel_exec.h"
#include "include/ms_tensor.h"
#include "include/lite_session.h"
#include "src/runtime/lite_model.h"
#include "src/runtime/inner_context.h"
#include "src/runtime/runtime_allocator.h"

@@ -41,11 +40,13 @@

namespace mindspore {
namespace lite {
class LiteSession : public session::LiteSession {
class LiteSession {
public:
LiteSession();
~LiteSession() override;
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
virtual ~LiteSession();
static LiteSession *CreateSession(const lite::Context *context);
static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);
static LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size);
int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size,
const std::shared_ptr<mindspore::Context> &ms_context);

@@ -62,19 +63,19 @@ class LiteSession : public session::LiteSession {
static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
const std::shared_ptr<mindspore::Context> &ms_context);
virtual int Init(InnerContext *context);
void BindThread(bool if_bind) override;
int CompileGraph(Model *model) override;
std::vector<mindspore::tensor::MSTensor *> GetInputs() const override;
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &name) const override;
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
std::vector<mindspore::tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const override;
std::vector<std::string> GetOutputTensorNames() const override;
mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const override;
std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override;
int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
std::map<std::string, unsigned int> *outputGLTexture) override;
int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims) override;
virtual void BindThread(bool if_bind);
virtual int CompileGraph(Model *model);
virtual std::vector<mindspore::tensor::MSTensor *> GetInputs() const;
virtual mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &name) const;
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr);
virtual std::vector<mindspore::tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const;
virtual std::vector<std::string> GetOutputTensorNames() const;
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const;
virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const;
virtual int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
std::map<std::string, unsigned int> *outputGLTexture);
virtual int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims);
void InitExecutionConfig(std::map<std::string, TypeId> *config) { execution_plan_ = config; }
void set_model(Model *model) { this->model_ = model; }
const std::vector<kernel::KernelExec *> &get_kernels() const { return this->kernels_; }

@@ -84,6 +85,41 @@ class LiteSession : public session::LiteSession {
}
const std::vector<Tensor *> &GetTensors() const { return this->tensors_; }

virtual int Train() { return mindspore::lite::RET_ERROR; }
virtual bool IsTrain() { return false; }
virtual int Eval() { return mindspore::lite::RET_OK; }
virtual bool IsEval() { return true; }
virtual int SetLearningRate(float learning_rate) { return mindspore::lite::RET_ERROR; }
virtual float GetLearningRate() { return 0.0; }
virtual int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) {
return mindspore::lite::RET_ERROR;
}
virtual std::vector<tensor::MSTensor *> GetPredictions() const {
std::vector<tensor::MSTensor *> outputs;
return outputs;
}
virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
std::vector<std::string> out_put_tensor_name = {}) {
return mindspore::lite::RET_ERROR;
}
virtual int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) { return mindspore::lite::RET_ERROR; }
virtual std::vector<tensor::MSTensor *> GetFeatureMaps() const {
std::vector<tensor::MSTensor *> features;
return features;
}
virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) { return mindspore::lite::RET_ERROR; }
virtual std::vector<tensor::MSTensor *> GetGradients() const {
std::vector<tensor::MSTensor *> gradients;
return gradients;
}
virtual int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) { return mindspore::lite::RET_ERROR; }
virtual std::vector<tensor::MSTensor *> GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
return params;
}
virtual int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) { return mindspore::lite::RET_ERROR; }

protected:
static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);
int CheckTensorValid(lite::Tensor *dst_tensor);
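The hunks above make lite::LiteSession a standalone class declared in src/runtime/lite_session.h, so call sites only change namespace. A minimal sketch of the post-commit calling convention, mirroring the factory calls and casts shown in the lite_session.cc and JNI hunks; the wrapper names OpenSession and FromHandle are hypothetical, and the raw void* handle stands in for the jlong-carried pointer used by the bindings.

// Illustrative only: call sites after this commit (before, the same calls were
// spelled mindspore::session::LiteSession::...).
#include "src/runtime/lite_session.h"

mindspore::lite::LiteSession *OpenSession(const char *model_buf, size_t size,
                                          const mindspore::lite::Context *context) {
  // Same arguments as before; only the class's namespace changed.
  return mindspore::lite::LiteSession::CreateSession(model_buf, size, context);
}

mindspore::lite::LiteSession *FromHandle(void *handle) {
  // Casts in the JNI layer change the same way (see the hunks above).
  return static_cast<mindspore::lite::LiteSession *>(handle);
}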
@@ -23,17 +23,17 @@
#include <fstream>
#include <memory>
#include "include/errorcode.h"
#include "include/train/train_loop.h"
#include "src/train/train_loop.h"
#include "src/common/utils.h"
#include "src/tensor.h"

namespace mindspore {
namespace lite {
void AccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
void AccuracyMonitor::Begin(const lite::TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) accuracies_.clear();
}

int AccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int AccuracyMonitor::EpochEnd(const lite::TrainLoopCallBackData &cb_data) {
if ((static_cast<int>(cb_data.epoch_) + 1) % check_every_n_ == 0) {
auto ret = cb_data.loop_->Eval(ds_, {}, nullptr, max_steps_);
if (ret != RET_OK) {

@@ -42,7 +42,7 @@ int AccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
}
}
accuracies_.push_back(std::make_pair(cb_data.epoch_, 0.0));
return mindspore::session::RET_CONTINUE;
return RET_CONTINUE;
}
} // namespace lite
} // namespace mindspore

@@ -30,11 +30,11 @@ ClassificationTrainAccuracyMonitor::ClassificationTrainAccuracyMonitor(int print
print_every_n_ = print_every_n;
}

void ClassificationTrainAccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
void ClassificationTrainAccuracyMonitor::Begin(const TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) accuracies_.clear();
}

void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
void ClassificationTrainAccuracyMonitor::EpochBegin(const TrainLoopCallBackData &cb_data) {
if (accuracies_.size() != cb_data.epoch_) {
MS_LOG(WARNING) << "Accuracies array does not match epoch number";
} else {

@@ -42,16 +42,16 @@ void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCall
}
}

int ClassificationTrainAccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int ClassificationTrainAccuracyMonitor::EpochEnd(const TrainLoopCallBackData &cb_data) {
if (cb_data.step_ > 0) accuracies_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
if ((static_cast<int>(cb_data.epoch_) + 1) % print_every_n_ == 0) {
std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tTraining Accuracy is "
<< accuracies_.at(cb_data.epoch_).second << std::endl;
}
return mindspore::session::RET_CONTINUE;
return RET_CONTINUE;
}

void ClassificationTrainAccuracyMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) {
void ClassificationTrainAccuracyMonitor::StepEnd(const TrainLoopCallBackData &cb_data) {
auto inputs = cb_data.session_->GetInputs();
auto outputs = cb_data.session_->GetPredictions();
@@ -26,11 +26,11 @@

namespace mindspore {
namespace lite {
-void LossMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
+void LossMonitor::Begin(const TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) losses_.clear();
}

-void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
+void LossMonitor::EpochBegin(const TrainLoopCallBackData &cb_data) {
if (losses_.size() != cb_data.epoch_) {
MS_LOG(WARNING) << "losses array does not match epoch number";
} else {
@@ -38,15 +38,15 @@ void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
}
}

-int LossMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
+int LossMonitor::EpochEnd(const TrainLoopCallBackData &cb_data) {
if (cb_data.step_ > 0) losses_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
if (print_every_n_ > 0) {
std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl;
}
-return mindspore::session::RET_CONTINUE;
+return RET_CONTINUE;
}

-void LossMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) {
+void LossMonitor::StepEnd(const TrainLoopCallBackData &cb_data) {
auto outputs = cb_data.session_->GetOutputs();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == 1) {

@@ -55,7 +55,7 @@ int StepLRLambda(float *lr, int epoch, void *lr_cb_data) {
LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step)
: lambda_func_(lambda_func), lr_data_(lr_cb_data), step_(step) {}

-int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
+int LRScheduler::EpochEnd(const TrainLoopCallBackData &cb_data) {
if (((static_cast<int>(cb_data.epoch_) + 1) % step_) == 0) {
float lr = cb_data.session_->GetLearningRate();
int update = lambda_func_(&lr, cb_data.epoch_, lr_data_);
@@ -63,11 +63,11 @@ int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
int ret = cb_data.session_->SetLearningRate(lr);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Error setting Leraning rate in train session";
-return mindspore::session::RET_EXIT;
+return RET_EXIT;
}
}
}
-return mindspore::session::RET_CONTINUE;
+return RET_CONTINUE;
}
} // namespace lite
} // namespace mindspore

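The hunks above only touch the callback signatures: the hooks now take the unqualified TrainLoopCallBackData and return the unqualified status codes. A user-defined callback written against the migrated API would plausibly look like the sketch below; the subclass name and the include path are illustrative rather than part of this change, and the sketch assumes TrainLoopCallBack, TrainLoopCallBackData and RET_CONTINUE all resolve inside mindspore::lite, as the diff implies.

// Illustrative only: a custom callback against the post-migration hooks.
#include <iostream>
#include "include/train/train_loop_callback.h"  // assumed header for TrainLoopCallBack

namespace mindspore {
namespace lite {
class EpochPrinter : public TrainLoopCallBack {
 public:
  int EpochEnd(const TrainLoopCallBackData &cb_data) override {
    std::cout << "finished epoch " << (cb_data.epoch_ + 1) << std::endl;
    return RET_CONTINUE;  // same contract as the monitors above: keep training
  }
};
}  // namespace lite
}  // namespace mindspore
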
@@ -29,13 +29,10 @@ namespace lite {
using dataset::Dataset;
using dataset::Iterator;
using dataset::MSTensorVec;
-using session::RET_CONTINUE;
-using session::RET_EXIT;
-using session::RET_STOP_TRAINING;

TrainLoop::~TrainLoop() {}

-int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
+int TrainLoop::Train(int epochs, Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
MS_CHECK_GE(epochs, 0, RET_ERROR);
auto ret = train_session_->Train();
@@ -43,7 +40,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
MS_LOG(ERROR) << "TrainLoop train failed";
return RET_ERROR;
}
-session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
+TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);

if (load_func == nullptr) load_func = TrainLoop::LoadData;

@@ -106,14 +103,14 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
return RET_OK;
}

-int TrainLoop::Eval(Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
+int TrainLoop::Eval(Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
auto ret = train_session_->Eval();
if (ret != RET_OK) {
MS_LOG(ERROR) << "TrainLoop train failed";
return RET_ERROR;
}
-session::TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
+TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);

if (load_func == nullptr) load_func = TrainLoop::LoadData;

@@ -184,7 +181,7 @@ int TrainLoop::LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTenso
}
} // namespace lite

-session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::LiteSession *train_session) {
+lite::TrainLoop *session::TrainLoop::CreateTrainLoop(lite::LiteSession *train_session) {
auto loop = new (std::nothrow) lite::TrainLoop(train_session);
return loop;
}

@@ -32,9 +32,9 @@ namespace lite {

class TrainLoop : virtual public session::TrainLoop {
public:
-explicit TrainLoop(session::LiteSession *session) : train_session_(session) {}
+explicit TrainLoop(lite::LiteSession *session) : train_session_(session) {}

-const session::LiteSession *train_session() override { return train_session_; }
+const lite::LiteSession *train_session() override { return train_session_; }

int Reset() override {
epoch_ = 0;
@@ -54,9 +54,9 @@ class TrainLoop : virtual public session::TrainLoop {
return RET_OK;
}

-int Train(int epochs, dataset::Dataset *dataset, std::vector<session::TrainLoopCallBack *> cbs,
+int Train(int epochs, dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
LoadDataFunc load_func = nullptr) override;
-int Eval(dataset::Dataset *dataset, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func = nullptr,
+int Eval(dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs, LoadDataFunc load_func = nullptr,
int max_steps = 0) override;

std::vector<mindspore::session::Metrics *> GetMetrics() override { return metrics_; }
@@ -64,7 +64,7 @@ class TrainLoop : virtual public session::TrainLoop {
protected:
static int LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *dataset_vec);

-session::LiteSession *train_session_ = nullptr;
+lite::LiteSession *train_session_ = nullptr;
unsigned int epoch_ = 0;
KernelCallBack before_cb_ = nullptr;
KernelCallBack after_cb_ = nullptr;

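Taken together with the train_loop.cc hunks above, the observable change for callers is that the loop is constructed from and driven with lite types. A rough usage sketch, assuming the caller already owns a lite::LiteSession, a dataset::Dataset and an LR lambda (train_session, dataset, my_lr_lambda and my_lr_state below are placeholders):

// Hypothetical driver; the placeholder objects stand in for things the caller already owns.
mindspore::lite::TrainLoop loop(train_session);                          // ctor now takes lite::LiteSession *
mindspore::lite::LRScheduler scheduler(my_lr_lambda, &my_lr_state, 1);   // ctor unchanged (see lr_scheduler hunk)
std::vector<mindspore::lite::TrainLoopCallBack *> cbs = {&scheduler};    // was std::vector<session::TrainLoopCallBack *>
int ret = loop.Train(/*epochs=*/5, dataset, cbs);                        // load_func defaults to TrainLoop::LoadData
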
@@ -1293,8 +1293,8 @@ size_t TrainSession::GetInplaceTensorOffset(kernel::KernelExec *kernel,
}
} // namespace lite

-session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
-bool train_mode, const lite::TrainCfg *cfg) {
+lite::LiteSession *lite::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
+bool train_mode, const lite::TrainCfg *cfg) {
if (context == nullptr) {
MS_LOG(ERROR) << "context cannot be nullptr";
return nullptr;

@@ -22,16 +22,11 @@
#include <memory>
#include <map>
#include "include/train/train_cfg.h"
-#include "include/train/train_session.h"
#include "src/runtime/lite_session.h"

/*
Inheritance Diagram

-+-------------------------------+
-| session::LiteSession |
-+--------------↑----------------+
-|
+--------------+----------------+
| lite::LiteSession |
+--------------↑----------------+
@@ -49,6 +44,11 @@ class TrainSession : virtual public lite::LiteSession {
TrainSession();
~TrainSession();

+static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
+const lite::Context *context, bool train_mode = false,
+const lite::TrainCfg *cfg = nullptr);
+static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
+bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;

int CompileGraph(lite::Model *model) override;

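For callers of the factory functions, the relocation means including src/train/train_session.h and resolving the factories on lite::TrainSession, which now hand back lite::LiteSession *. A minimal sketch, with the model path and the default-constructed context/config as placeholders:

// Sketch only; "model.ms", context and cfg are placeholders, not part of the diff.
mindspore::lite::Context context;
mindspore::lite::TrainCfg cfg;
mindspore::lite::LiteSession *session =
    mindspore::lite::TrainSession::CreateTrainSession("model.ms", &context, /*train_mode=*/true, &cfg);
if (session == nullptr) {
  // creation failed; the definition above returns nullptr for a null context
}
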
@@ -42,7 +42,7 @@ TransferSession::TransferSession(const char *model_buf_backbone, size_t size_bac
if (lite_model_ != nullptr) {
std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_);
backbone_session_ =
-reinterpret_cast<lite::LiteSession *>(session::LiteSession::CreateSession(lite_model_, size_backbone, context));
+reinterpret_cast<LiteSession *>(LiteSession::CreateSession(lite_model_, size_backbone, context));
if (backbone_session_ != nullptr) {
is_valid_ = true;
} else {
@@ -312,10 +312,10 @@ lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size
}
} // namespace lite

-session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
-const std::string &filename_head,
-const lite::Context *ctxt, bool train_mode,
-const lite::TrainCfg *cfg) {
+lite::LiteSession *lite::TrainSession::CreateTransferSession(const std::string &filename_backbone,
+const std::string &filename_head,
+const lite::Context *ctxt, bool train_mode,
+const lite::TrainCfg *cfg) {
size_t size_head = 0;
size_t size_backbone = 0;
std::string filename = filename_head;

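The transfer-learning entry point moves the same way. Under the same assumptions as the sketch above (placeholder file names, caller-owned context and config), creating a transfer session would now look roughly like:

// Sketch only; backbone/head paths are placeholders.
mindspore::lite::LiteSession *transfer =
    mindspore::lite::TrainSession::CreateTransferSession("backbone.ms", "head.ms", &context,
                                                         /*train_mode=*/false, &cfg);
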
@@ -65,7 +65,7 @@ class TransferSession : public lite::TrainSession {
std::vector<std::string> out_put_tensor_name = {}) override;

protected:
-lite::LiteSession *backbone_session_ = nullptr;
+LiteSession *backbone_session_ = nullptr;
char *lite_model_ = nullptr;
std::vector<mindspore::tensor::MSTensor *> combined_inputs_;
std::vector<std::pair<mindspore::tensor::MSTensor *, mindspore::tensor::MSTensor *>> backbone_head_map_;

@@ -46,7 +46,7 @@ TEST_F(GraphTest, UserSetGraphOutput1) {
auto context = std::make_shared<lite::Context>();
ASSERT_NE(context, nullptr);

-session::LiteSession *session = session::LiteSession::CreateSession(context.get());
+lite::LiteSession *session = lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);

int benchmark_ret = session->CompileGraph(model.get());

@@ -25,9 +25,8 @@ class MindrtRuntimeTest : public mindspore::CommonTest {
MindrtRuntimeTest() = default;
};

-int CheckRuntime(mindspore::session::LiteSession *session) {
-mindspore::lite::LiteSession *lite_session = reinterpret_cast<mindspore::lite::LiteSession *>(session);
-auto kernels = lite_session->get_kernels();
+int CheckRuntime(mindspore::lite::LiteSession *session) {
+auto kernels = session->get_kernels();

int cpu_kernel_count = 0;
int gpu_kernel_count = 0;
@@ -69,7 +68,7 @@ TEST_F(MindrtRuntimeTest, Runtime) {
gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = false;
context->device_list_.push_back(gpu_device_ctx);

-mindspore::session::LiteSession *session = mindspore::session::LiteSession::CreateSession(context.get());
+mindspore::lite::LiteSession *session = mindspore::lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);

int benchmark_ret = session->CompileGraph(model.get());
@@ -87,9 +86,8 @@ TEST_F(MindrtRuntimeTest, Runtime) {
delete session;
}

-int CheckRuntime2(mindspore::session::LiteSession *session) {
-mindspore::lite::LiteSession *lite_session = reinterpret_cast<mindspore::lite::LiteSession *>(session);
-auto kernels = lite_session->get_kernels();
+int CheckRuntime2(mindspore::lite::LiteSession *session) {
+auto kernels = session->get_kernels();

for (auto kernel : kernels) {
if (kernel->subgraph_type() != mindspore::kernel::kCpuFP16SubGraph) {
@@ -119,7 +117,7 @@ TEST_F(MindrtRuntimeTest, RuntimeFp16) {
auto &cpu_device_ctx = context->device_list_[0];
cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = true;

-mindspore::session::LiteSession *session = mindspore::session::LiteSession::CreateSession(context.get());
+mindspore::lite::LiteSession *session = mindspore::lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);

int benchmark_ret = session->CompileGraph(model.get());

@@ -31,7 +31,7 @@ class MindrtParallelTest : public mindspore::CommonTest {
MindrtParallelTest() {}
};

-int CheckOffline1(session::LiteSession *session) {
+int CheckOffline1(lite::LiteSession *session) {
/* ----------- start check -------------- */
lite::LiteSession *lite_session = reinterpret_cast<lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
@@ -86,7 +86,7 @@ int CheckOffline1(session::LiteSession *session) {
return lite::RET_OK;
}

-int CheckRuntime1(session::LiteSession *session) {
+int CheckRuntime1(lite::LiteSession *session) {
lite::LiteSession *lite_session = reinterpret_cast<lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
if (kernels.size() != 6) {
@@ -115,7 +115,7 @@ TEST_F(MindrtParallelTest, offline1) {
ASSERT_NE(context, nullptr);
context->enable_parallel_ = true;

-session::LiteSession *session = session::LiteSession::CreateSession(context.get());
+lite::LiteSession *session = lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);

int benchmark_ret = session->CompileGraph(model.get());
@@ -153,7 +153,7 @@ TEST_F(MindrtParallelTest, runtime1) {
ASSERT_NE(context, nullptr);
context->enable_parallel_ = true;

-session::LiteSession *session = session::LiteSession::CreateSession(context.get());
+lite::LiteSession *session = lite::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);

int benchmark_ret = session->CompileGraph(model.get());

@@ -326,7 +326,7 @@ TEST_F(SubGraphTest, RecursiveSubGraphTest) {
auto &cpu_device_ctx = context.device_list_[0];
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
context.thread_num_ = 2;
-auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
+auto session = std::shared_ptr<lite::LiteSession>(lite::LiteSession::CreateSession(&context));
ASSERT_NE(session, nullptr);
auto ret = session->CompileGraph(model.get());
ASSERT_EQ(ret, lite::RET_OK);

@@ -102,7 +102,7 @@ TEST_F(InferTest, TestConvNode) {
device_list.push_back(device_ctx);
context->thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, context->Init());
-auto session = session::LiteSession::CreateSession(context);
+auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(nullptr, session);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);
@@ -198,7 +198,7 @@ TEST_F(InferTest, TestAddNode) {
device_list.push_back(device_ctx);
context->thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, context->Init());
-auto session = session::LiteSession::CreateSession(context);
+auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(nullptr, session);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);
@@ -234,7 +234,7 @@ TEST_F(InferTest, TestModel) {
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context->thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, context->Init());
-auto session = session::LiteSession::CreateSession(context);
+auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(nullptr, session);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);

@@ -27,7 +27,7 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "include/train/train_cfg.h"
-#include "include/train/train_session.h"
+#include "src/train/train_session.h"
#include "src/common/log_adapter.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel_registry.h"
@@ -38,12 +38,12 @@ namespace mindspore {
class NetworkTest : public mindspore::CommonTest {
public:
NetworkTest() {}
-int32_t runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out,
+int32_t runNet(mindspore::lite::LiteSession *session, const std::string &in, const std::string &out,
const char *tensor_name, bool debug = false);
};

-int32_t fileIterator(mindspore::session::LiteSession *session, const std::string &path,
-std::function<int32_t(mindspore::session::LiteSession *session, const std::string &)> cb) {
+int32_t fileIterator(mindspore::lite::LiteSession *session, const std::string &path,
+std::function<int32_t(mindspore::lite::LiteSession *session, const std::string &)> cb) {
int32_t res = 0;
if (auto dir = opendir(path.c_str())) {
while (auto f = readdir(dir)) {
@@ -58,7 +58,7 @@ int32_t fileIterator(mindspore::session::LiteSession *session, const std::string
}
void replaceExt(const std::string &src, std::string *dst) { *dst = src.substr(0, src.find_last_of('.')) + ".emb"; }

-int32_t NetworkTest::runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out,
+int32_t NetworkTest::runNet(mindspore::lite::LiteSession *session, const std::string &in, const std::string &out,
const char *tensor_name, bool debug) {
// setup input
auto inputs = session->GetInputs();
@@ -101,7 +101,7 @@ TEST_F(NetworkTest, efficient_net) {
context->thread_num_ = 1;

std::string net = "./nets/effnetb0_fwd_nofuse.ms";
-auto session = session::TrainSession::CreateTrainSession(net, context, false);
+auto session = lite::TrainSession::CreateTrainSession(net, context, false);
ASSERT_NE(session, nullptr);

std::string in = "./nets/effNet_input_x_1_3_224_224.bin";
@@ -122,7 +122,7 @@ TEST_F(NetworkTest, mobileface_net) {
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context->thread_num_ = 1;

-auto session = session::LiteSession::CreateSession(context);
+auto session = lite::LiteSession::CreateSession(context);
ASSERT_NE(session, nullptr);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);
@@ -147,7 +147,7 @@ TEST_F(NetworkTest, noname) {
lite::TrainCfg cfg;
cfg.loss_name_.clear();
cfg.loss_name_.emplace_back("nhwc");
-auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &cfg);
+auto session = mindspore::lite::TrainSession::CreateTrainSession(net, &context, true, &cfg);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();
auto tensor_names = session->GetOutputTensorNames();
@@ -168,7 +168,7 @@ TEST_F(NetworkTest, setname) {
train_cfg.loss_name_.clear();
train_cfg.loss_name_.emplace_back("nhwc");

-auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg);
+auto session = mindspore::lite::TrainSession::CreateTrainSession(net, &context, true, &train_cfg);
ASSERT_NE(session, nullptr);

auto tensors_map = session->GetOutputs();

@@ -68,7 +68,6 @@ add_executable(benchmark
${CMAKE_CURRENT_SOURCE_DIR}/main.cc
${CMAKE_CURRENT_SOURCE_DIR}/run_benchmark.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_base.cc
-${CMAKE_CURRENT_SOURCE_DIR}/benchmark.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_unified_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_c_api.cc
${COMMON_SRC})

@@ -17,7 +17,6 @@
#include "tools/benchmark/run_benchmark.h"
#include <string>
#include <memory>
-#include "tools/benchmark/benchmark.h"
#include "tools/benchmark/benchmark_unified_api.h"
#include "tools/benchmark/benchmark_c_api.h"

@@ -51,8 +50,6 @@ int RunBenchmark(int argc, const char **argv) {
BenchmarkBase *benchmark = nullptr;
if (api_type == nullptr || std::string(api_type) == "NEW") {
benchmark = new (std::nothrow) BenchmarkUnifiedApi(&flags);
-} else if (std::string(api_type) == "OLD") {
-benchmark = new (std::nothrow) Benchmark(&flags);
} else if (std::string(api_type) == "C") {
benchmark = new (std::nothrow) tools::BenchmarkCApi(&flags);
} else {