forked from mindspore-Ecosystem/mindspore
Add Train API to CXX_API
This commit is contained in:
parent
e637285abb
commit
fc6d150485
|
@ -0,0 +1,100 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/data_type.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class Model;
|
||||||
|
class ModelImpl;
|
||||||
|
class CallbackImpl;
|
||||||
|
|
||||||
|
struct TrainCallBackData {
|
||||||
|
TrainCallBackData(bool train_mode, int epoch, int step, Model *model): train_mode_(train_mode), epoch_(epoch),
|
||||||
|
step_(step), model_(model) {}
|
||||||
|
|
||||||
|
bool train_mode_; /**< training mode of LiteSession object */
|
||||||
|
unsigned int epoch_; /**< the current training epoch (starts at 0) */
|
||||||
|
unsigned int step_ = 0; /**< the current step within the epoch */
|
||||||
|
Model *model_; /**< pointer to the Model object */
|
||||||
|
};
|
||||||
|
|
||||||
|
enum CallbackRetValue : uint32_t {
|
||||||
|
kContinue = 0,
|
||||||
|
kStopTraining = 1,
|
||||||
|
kExit = 2,
|
||||||
|
kUnknownRetValue = 0xFFFFFFFF
|
||||||
|
};
|
||||||
|
|
||||||
|
class TrainCallBack {
|
||||||
|
public:
|
||||||
|
virtual ~TrainCallBack() = default;
|
||||||
|
|
||||||
|
/// \brief This method is called once before the network executing
|
||||||
|
///
|
||||||
|
/// \param[in] cb_data info about current execution
|
||||||
|
virtual void Begin(const TrainCallBackData &cb_data) {}
|
||||||
|
|
||||||
|
/// \brief This method is called once following the network execution
|
||||||
|
///
|
||||||
|
/// \param[in] cb_data info about current execution
|
||||||
|
virtual void End(const TrainCallBackData &cb_data) {}
|
||||||
|
|
||||||
|
/// \brief This method is called at the beginning of each epoch
|
||||||
|
///
|
||||||
|
/// \param[in] cb_data info about current execution
|
||||||
|
virtual void EpochBegin(const TrainCallBackData &cb_data) {}
|
||||||
|
|
||||||
|
/// \brief This method is called after the run of each epoch
|
||||||
|
///
|
||||||
|
/// \param[in] cb_data info about current execution
|
||||||
|
///
|
||||||
|
/// \return indication if to continue in the train loop:
|
||||||
|
/// RET_CONTINUE -- continue training
|
||||||
|
/// RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy)
|
||||||
|
/// RET_EXIT -- Exit training (due to error of some sort)
|
||||||
|
virtual CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) { return kContinue; }
|
||||||
|
|
||||||
|
/// \brief This method is called at the beginning of each step
|
||||||
|
///
|
||||||
|
/// \param[in] cb_data info about current execution
|
||||||
|
virtual void StepBegin(const TrainCallBackData &cb_data) {}
|
||||||
|
|
||||||
|
/// \brief This method is called after each step is ran
|
||||||
|
///
|
||||||
|
/// \param[in] cb_data info about current execution
|
||||||
|
virtual void StepEnd(const TrainCallBackData &cb_data) {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
friend class Model;
|
||||||
|
friend class ModelImpl;
|
||||||
|
CallbackImpl* callback_impl_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
|
|
@ -0,0 +1,39 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class CkptSaver: public TrainCallBack {
|
||||||
|
public:
|
||||||
|
explicit CkptSaver(int save_every_n, const std::string &filename_prefix);
|
||||||
|
virtual ~CkptSaver();
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_LOSS_MONITOR_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_CALLBACK_LOSS_MONITOR_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
using GraphPoint = std::pair<int, float>;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class LossMonitor: public TrainCallBack {
|
||||||
|
public:
|
||||||
|
explicit LossMonitor(int print_every_n_steps = INT_MAX);
|
||||||
|
virtual ~LossMonitor();
|
||||||
|
const std::vector<GraphPoint> &GetLossPoints();
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_CALLBACK_LOSS_MONITOR_H
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_LR_SCHEDULER_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_CALLBACK_LR_SCHEDULER_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
constexpr int DONT_UPDATE_LR = 0;
|
||||||
|
constexpr int UPDATE_LR = 1;
|
||||||
|
|
||||||
|
using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>;
|
||||||
|
|
||||||
|
/// \brief Multiply the LR by a factor of gamma every epoch
|
||||||
|
int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);
|
||||||
|
|
||||||
|
/// \brief Multiply the LR by a factor of gamma every step_size
|
||||||
|
int StepLRLambda(float *lr, int epoch, void *step_size);
|
||||||
|
struct StepLRLambda {
|
||||||
|
StepLRLambda(int step, float g) : step_size(step), gamma(g) {}
|
||||||
|
|
||||||
|
int step_size; // period of LR decay
|
||||||
|
float gamma; // LR decay factor
|
||||||
|
};
|
||||||
|
|
||||||
|
class LRScheduler: public TrainCallBack {
|
||||||
|
public:
|
||||||
|
explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step = 1);
|
||||||
|
virtual ~LRScheduler();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_CALLBACK_LR_SCHEDULER_H
|
|
@ -0,0 +1,40 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_TIME_MONITOR_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_CALLBACK_TIME_MONITOR_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class TimeMonitor: public TrainCallBack {
|
||||||
|
public:
|
||||||
|
virtual ~TimeMonitor() = default;
|
||||||
|
void EpochBegin(const TrainCallBackData &cb_data) override;
|
||||||
|
CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) override;
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_CALLBACK_TIME_MONITOR_H
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_TRAIN_ACCURACY_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_CALLBACK_TRAIN_ACCURACY_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
|
#include "include/api/metrics/accuracy.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
using GraphPoint = std::pair<int, float>;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class TrainAccuracy: public TrainCallBack {
|
||||||
|
public:
|
||||||
|
explicit TrainAccuracy(int print_every_n = INT_MAX,
|
||||||
|
int accuracy_metrics = METRICS_CLASSIFICATION,
|
||||||
|
const std::vector<int> &input_indexes = {1},
|
||||||
|
const std::vector<int> &output_indexes = {0});
|
||||||
|
virtual ~TrainAccuracy();
|
||||||
|
const std::vector<GraphPoint> &GetAccuracyPoints();
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_CALLBACK_TRAIN_ACCURACY_H
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_CFG_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_CFG_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/data_type.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class MixPrecisionCfg {
|
||||||
|
public:
|
||||||
|
MixPrecisionCfg() {
|
||||||
|
this->dynamic_loss_scale_ = false;
|
||||||
|
this->loss_scale_ = 128.0f;
|
||||||
|
this->num_of_not_nan_iter_th_ = 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
|
||||||
|
float loss_scale_; /**< Initial loss scale factor */
|
||||||
|
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
|
||||||
|
};
|
||||||
|
|
||||||
|
class TrainCfg {
|
||||||
|
public:
|
||||||
|
TrainCfg() { this->loss_name_ = "_loss_fn"; }
|
||||||
|
|
||||||
|
OptimizationLevel optimization_level_ = kO0;
|
||||||
|
std::string loss_name_; /**< Set part of the name that identify a loss kernel */
|
||||||
|
MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_CFG_H
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_METRICS_ACCURACY_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_METRICS_ACCURACY_H
|
||||||
|
#include <vector>
|
||||||
|
#include "include/api/metrics/metrics.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
constexpr int METRICS_CLASSIFICATION = 0;
|
||||||
|
constexpr int METRICS_MULTILABEL = 1;
|
||||||
|
|
||||||
|
class AccuracyMetrics : public Metrics {
|
||||||
|
public:
|
||||||
|
explicit AccuracyMetrics(int accuracy_metrics = METRICS_CLASSIFICATION, const std::vector<int> &input_indexes = {1},
|
||||||
|
const std::vector<int> &output_indexes = {0});
|
||||||
|
virtual ~AccuracyMetrics();
|
||||||
|
void Clear() override;
|
||||||
|
float Eval() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_METRICS_ACCURACY_H
|
|
@ -0,0 +1,40 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_INCLUDE_API_METRICS_METRICS_H
|
||||||
|
#define MINDSPORE_INCLUDE_API_METRICS_METRICS_H
|
||||||
|
#include <vector>
|
||||||
|
#include "include/api/model.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class MetricsImpl;
|
||||||
|
class ModelImpl;
|
||||||
|
class MSTensor;
|
||||||
|
|
||||||
|
class Metrics {
|
||||||
|
public:
|
||||||
|
virtual ~Metrics() = default;
|
||||||
|
virtual void Clear() {}
|
||||||
|
virtual float Eval() { return 0.0; }
|
||||||
|
virtual void Update(std::vector<MSTensor *> inputs, std::vector<MSTensor *> outputs) {}
|
||||||
|
protected:
|
||||||
|
friend class Model;
|
||||||
|
friend class ModelImpl;
|
||||||
|
MetricsImpl* metrics_impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_API_METRICS_METRICS_H
|
|
@ -25,11 +25,19 @@
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/graph.h"
|
#include "include/api/graph.h"
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
#include "include/api/cell.h"
|
#include "include/api/cell.h"
|
||||||
|
#include "include/api/cfg.h"
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelImpl;
|
class ModelImpl;
|
||||||
|
class Metrics;
|
||||||
|
|
||||||
|
namespace dataset {
|
||||||
|
class Dataset;
|
||||||
|
} // namespace dataset
|
||||||
|
|
||||||
|
|
||||||
class MS_API Model {
|
class MS_API Model {
|
||||||
public:
|
public:
|
||||||
|
@ -38,7 +46,8 @@ class MS_API Model {
|
||||||
Model(const Model &) = delete;
|
Model(const Model &) = delete;
|
||||||
void operator=(const Model &) = delete;
|
void operator=(const Model &) = delete;
|
||||||
|
|
||||||
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr);
|
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
|
||||||
|
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
|
||||||
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
||||||
|
|
||||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||||
|
@ -47,6 +56,9 @@ class MS_API Model {
|
||||||
std::vector<MSTensor> GetInputs();
|
std::vector<MSTensor> GetInputs();
|
||||||
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
|
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
|
||||||
|
|
||||||
|
Status InitMetrics(std::vector<Metrics *> metrics);
|
||||||
|
std::vector<Metrics *> GetMetrics();
|
||||||
|
|
||||||
std::vector<MSTensor> GetOutputs();
|
std::vector<MSTensor> GetOutputs();
|
||||||
inline std::vector<std::string> GetOutputTensorNames();
|
inline std::vector<std::string> GetOutputTensorNames();
|
||||||
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
|
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
|
||||||
|
@ -54,11 +66,16 @@ class MS_API Model {
|
||||||
|
|
||||||
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
|
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
|
||||||
|
|
||||||
|
Status SetTrainMode(bool train);
|
||||||
|
bool GetTrainMode() const;
|
||||||
|
Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
|
||||||
|
Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
|
||||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
||||||
const std::string &dec_mode = "AES-GCM");
|
const std::string &dec_mode = "AES-GCM");
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class Serialization;
|
||||||
// api without std::string
|
// api without std::string
|
||||||
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
|
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
|
||||||
std::vector<std::vector<char>> GetOutputTensorNamesChar();
|
std::vector<std::vector<char>> GetOutputTensorNamesChar();
|
||||||
|
|
|
@ -30,18 +30,16 @@ namespace mindspore {
|
||||||
|
|
||||||
class MS_API Serialization {
|
class MS_API Serialization {
|
||||||
public:
|
public:
|
||||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
|
|
||||||
inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||||
const Key &dec_key, const std::string &dec_mode);
|
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
|
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
|
||||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
const std::string &dec_mode = kDecModeAesGcm);
|
||||||
const std::string &dec_mode);
|
|
||||||
inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
|
inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
|
||||||
const Key &dec_key = {}, const std::string &dec_mode = "AES-GCM");
|
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
|
||||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
||||||
|
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
|
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||||
|
@ -58,10 +56,6 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
||||||
return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
|
return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) {
|
|
||||||
return Load(StringToChar(file), model_type, graph);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||||
const std::string &dec_mode) {
|
const std::string &dec_mode) {
|
||||||
return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
|
return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
|
||||||
|
|
|
@ -78,6 +78,7 @@ enum StatusCode : uint32_t {
|
||||||
kLiteMemoryFailed = kLite | (0x0FFFFFFF & -6), /**< Fail to create memory. */
|
kLiteMemoryFailed = kLite | (0x0FFFFFFF & -6), /**< Fail to create memory. */
|
||||||
kLiteNotSupport = kLite | (0x0FFFFFFF & -7), /**< Fail to support. */
|
kLiteNotSupport = kLite | (0x0FFFFFFF & -7), /**< Fail to support. */
|
||||||
kLiteThreadPoolError = kLite | (0x0FFFFFFF & -8), /**< Error occur in thread pool. */
|
kLiteThreadPoolError = kLite | (0x0FFFFFFF & -8), /**< Error occur in thread pool. */
|
||||||
|
kLiteUninitializedObj = kLite | (0x0FFFFFFF & -9), /**< Object is not initialized. */
|
||||||
|
|
||||||
// Executor error code, range: [-100,-200)
|
// Executor error code, range: [-100,-200)
|
||||||
kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
|
kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
|
||||||
|
|
|
@ -36,10 +36,26 @@ enum ModelType : uint32_t {
|
||||||
kAIR = 1,
|
kAIR = 1,
|
||||||
kOM = 2,
|
kOM = 2,
|
||||||
kONNX = 3,
|
kONNX = 3,
|
||||||
|
kFlatBuffer = 4,
|
||||||
// insert new data type here
|
// insert new data type here
|
||||||
kUnknownType = 0xFFFFFFFF
|
kUnknownType = 0xFFFFFFFF
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum QuantizationType : uint32_t {
|
||||||
|
kNoQuant = 0,
|
||||||
|
kWeightQuant = 1,
|
||||||
|
kFullQuant = 2,
|
||||||
|
kUnknownQuantType = 0xFFFFFFFF
|
||||||
|
};
|
||||||
|
|
||||||
|
enum OptimizationLevel : uint32_t {
|
||||||
|
kO0 = 0, // Do not change
|
||||||
|
kO2 = 2, // Cast network to float16, keep batchnorm and loss in float32,
|
||||||
|
kO3 = 3, // Cast network to float16, including bacthnorm
|
||||||
|
kAuto = 4, // Choose optimization based on device
|
||||||
|
kOptimizationType = 0xFFFFFFFF
|
||||||
|
};
|
||||||
|
|
||||||
class MS_API MSTensor {
|
class MS_API MSTensor {
|
||||||
public:
|
public:
|
||||||
class Impl;
|
class Impl;
|
||||||
|
@ -149,7 +165,11 @@ using Key = struct Key {
|
||||||
size_t len;
|
size_t len;
|
||||||
unsigned char key[32];
|
unsigned char key[32];
|
||||||
Key() : len(0) {}
|
Key() : len(0) {}
|
||||||
|
explicit Key(const char *dec_key, size_t key_len);
|
||||||
};
|
};
|
||||||
|
constexpr char kDecModeAesGcm[] = "AES-GCM";
|
||||||
|
|
||||||
|
|
||||||
/// \brief CallBackParam defined input arguments for callBack function.
|
/// \brief CallBackParam defined input arguments for callBack function.
|
||||||
struct MSCallBackParam {
|
struct MSCallBackParam {
|
||||||
std::string node_name_; /**< node name argument */
|
std::string node_name_; /**< node name argument */
|
||||||
|
|
|
@ -34,7 +34,8 @@ std::string GetDeviceTypeString(enum DeviceType type) {
|
||||||
return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
|
return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context) {
|
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
|
||||||
|
const std::shared_ptr<TrainCfg> &) {
|
||||||
if (graph_cell.GetGraph() == nullptr) {
|
if (graph_cell.GetGraph() == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid graph input.";
|
MS_LOG(ERROR) << "Invalid graph input.";
|
||||||
return kMCInvalidInput;
|
return kMCInvalidInput;
|
||||||
|
|
|
@ -79,8 +79,20 @@ static Buffer ReadFile(const std::string &file) {
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
|
Key::Key(const char *dec_key, size_t key_len) {
|
||||||
return Load(model_data, data_size, model_type, graph, Key{}, StringToChar("AES-GCM"));
|
len = 0;
|
||||||
|
if (key_len >= max_key_len) {
|
||||||
|
MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len);
|
||||||
|
if (sec_ret != EOK) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
len = key_len;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||||
|
@ -137,7 +149,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
||||||
return Load(file, model_type, graph, Key{}, StringToChar("AES-GCM"));
|
return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||||
|
@ -256,11 +268,6 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
|
||||||
return Status(kMEInvalidInput, err_msg.str());
|
return Status(kMEInvalidInput, err_msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::LoadCheckPoint(const std::string &, std::map<std::string, Buffer> *) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
|
||||||
return kMEFailed;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model *) {
|
Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model *) {
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
MS_LOG(ERROR) << "Unsupported feature.";
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
|
@ -271,7 +278,7 @@ Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::ExportModel(const Model &, ModelType, const std::string &) {
|
Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool) {
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
MS_LOG(ERROR) << "Unsupported feature.";
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
echo "============Exporting=========="
|
echo "============Exporting=========="
|
||||||
|
rm -f lenet_tod.mindir
|
||||||
if [ -n "$2" ]; then
|
if [ -n "$2" ]; then
|
||||||
DOCKER_IMG=$2
|
DOCKER_IMG=$2
|
||||||
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py '$1'; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
|
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py '$1'; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
|
||||||
|
|
|
@ -133,6 +133,7 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
|
||||||
|
|
||||||
NetRunner::~NetRunner() {
|
NetRunner::~NetRunner() {
|
||||||
if (loop_ != nullptr) delete loop_;
|
if (loop_ != nullptr) delete loop_;
|
||||||
|
if (session_ != nullptr) delete session_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NetRunner::InitAndFigureInputs() {
|
void NetRunner::InitAndFigureInputs() {
|
||||||
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
*.mindir
|
||||||
|
*.ms
|
||||||
|
msl
|
||||||
|
package-*
|
|
@ -0,0 +1,57 @@
|
||||||
|
BASE_DIR=$(realpath ../../../../)
|
||||||
|
APP:=bin/net_runner
|
||||||
|
INF_APP:=bin/infer
|
||||||
|
LMSTLIB:=-lmindspore-lite-train -lminddata-lite
|
||||||
|
LMSLIB:=-lmindspore-lite
|
||||||
|
MSDIR:=$(realpath package-$(TARGET)/lib)
|
||||||
|
ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")
|
||||||
|
LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai
|
||||||
|
else
|
||||||
|
LHIAILIB:=
|
||||||
|
endif
|
||||||
|
|
||||||
|
SRC:=src/net_runner.cc
|
||||||
|
OBJ:=$(SRC:.cc=.o)
|
||||||
|
|
||||||
|
INF_SRC:=src/inference.cc
|
||||||
|
INF_OBJ:=$(INF_SRC:.cc=.o)
|
||||||
|
|
||||||
|
CFLAGS := -Ofast -std=c++17 \
|
||||||
|
-I . \
|
||||||
|
-I ./msl/runtime \
|
||||||
|
-I ./msl/runtime/minddata \
|
||||||
|
-I ./msl/tools/third_party/flatbuffers/include
|
||||||
|
|
||||||
|
|
||||||
|
ifeq ($(TARGET),arm64)
|
||||||
|
CXX := ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin/clang++
|
||||||
|
CFLAGS += --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -fdata-sections -ffunction-sections
|
||||||
|
LDFLAGS := --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -Wl,--gc-sections
|
||||||
|
LDFLAGS += -L$(MSDIR) $(LMSLIB) -pthread -llog -latomic -lm $(LHIAILIB) -Wl,-rpath,$(MSDIR)
|
||||||
|
else
|
||||||
|
CFLAGS += -g
|
||||||
|
LDFLAGS := -L$(MSDIR) $(LMSLIB) -lpthread -Wl,-rpath,$(MSDIR)
|
||||||
|
endif
|
||||||
|
LD := ${CXX}
|
||||||
|
|
||||||
|
|
||||||
|
all:$(APP) $(INF_APP)
|
||||||
|
|
||||||
|
$(APP): $(OBJ)
|
||||||
|
@mkdir -p bin
|
||||||
|
$(LD) $(OBJ) $(LMSTLIB) $(LDFLAGS) -o $@
|
||||||
|
|
||||||
|
$(INF_APP): $(INF_OBJ)
|
||||||
|
@mkdir -p bin
|
||||||
|
$(LD) $(INF_OBJ) $(LDFLAGS) -o $@
|
||||||
|
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf src/*.o bin/
|
||||||
|
|
||||||
|
|
||||||
|
mrproper:
|
||||||
|
rm -rf package* msl src/*.o bin/ model/*.mindir model/*.ms model/*.so* model/converter_lite
|
||||||
|
|
||||||
|
%.o:%.cc
|
||||||
|
$(CXX) $(CFLAGS) -c $< -o $@
|
|
@ -0,0 +1,145 @@
|
||||||
|
# Content
|
||||||
|
|
||||||
|
<!-- TOC -->
|
||||||
|
|
||||||
|
- [Overview](#overview)
|
||||||
|
- [Model Architecture](#model-architecture)
|
||||||
|
- [Dataset](#dataset)
|
||||||
|
- [Environment Requirements](#environment-requirements)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Script Detailed Description](#script-detailed-description)
|
||||||
|
|
||||||
|
<!-- /TOC -->
|
||||||
|
|
||||||
|
# Overview
|
||||||
|
|
||||||
|
This folder holds code for Training-on-Device of a LeNet model. Part of the code runs on a server using MindSpore infrastructure, another part uses MindSpore Lite conversion utility, and the last part is the actual training of the model on some android-based device.
|
||||||
|
|
||||||
|
# Model Architecture
|
||||||
|
|
||||||
|
LeNet is a very simple network which is composed of only 5 layers, 2 of which are convolutional layers and the remaining 3 are fully connected layers. Such a small network can be fully trained (from scratch) on a device in a short time. Therefore, it is a good example.
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
|
||||||
|
In this example we use the MNIST dataset of handwritten digits as published in [THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)
|
||||||
|
|
||||||
|
- Dataset size:52.4M,60,000 28*28 in 10 classes
|
||||||
|
- Test:10,000 images
|
||||||
|
- Train:60,000 images
|
||||||
|
- Data format:binary files
|
||||||
|
- Note:Data will be processed in dataset.cc
|
||||||
|
|
||||||
|
- The dataset directory structure is as follows:
|
||||||
|
|
||||||
|
```text
|
||||||
|
mnist/
|
||||||
|
├── test
|
||||||
|
│ ├── t10k-images-idx3-ubyte
|
||||||
|
│ └── t10k-labels-idx1-ubyte
|
||||||
|
└── train
|
||||||
|
├── train-images-idx3-ubyte
|
||||||
|
└── train-labels-idx1-ubyte
|
||||||
|
```
|
||||||
|
|
||||||
|
# Environment Requirements
|
||||||
|
|
||||||
|
- Server side
|
||||||
|
- [MindSpore Framework](https://www.mindspore.cn/install/en): it is recommended to install a docker image
|
||||||
|
- MindSpore ToD Framework
|
||||||
|
- [Downloads](https://www.mindspore.cn/tutorial/lite/en/master/use/downloads.html)
|
||||||
|
- [Build](https://www.mindspore.cn/tutorial/lite/en/master/use/build.html)
|
||||||
|
- [Android NDK r20b](https://dl.google.com/android/repository/android-ndk-r20b-linux-x86_64.zip)
|
||||||
|
- [Android SDK](https://developer.android.com/studio?hl=zh-cn#cmdline-tools)
|
||||||
|
- A connected Android device
|
||||||
|
|
||||||
|
# Quick Start
|
||||||
|
|
||||||
|
After installing all the above mentioned, the script in the home directory could be run with the following arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sh ./prepare_and_run.sh -D DATASET_PATH [-d MINDSPORE_DOCKER] [-r RELEASE.tar.gz] [-t arm64|x86]
|
||||||
|
```
|
||||||
|
|
||||||
|
where:
|
||||||
|
|
||||||
|
- DATASET_PATH is the path to the [dataset](#dataset),
|
||||||
|
- MINDSPORE_DOCKER is the image name of the docker that runs [MindSpore](#environment-requirements). If not provided MindSpore will be run locally
|
||||||
|
- RELEASE.tar.gz is a pointer to the MindSpore ToD release tar ball. If not provided, the script will attempt to find MindSpore ToD compilation output
|
||||||
|
- target is defaulted to arm64, i.e., on-device. If x86 is provided, the demo will be run locally. Note that infrastructure is not optimized for running on x86. Also, note that user needs to call "make clean" when switching between targets.
|
||||||
|
|
||||||
|
# Script Detailed Description
|
||||||
|
|
||||||
|
The provided `prepare_and_run.sh` script is performing the following:
|
||||||
|
|
||||||
|
- Prepare the trainable lenet model in a `.ms` format
|
||||||
|
- Prepare the folder that should be pushed into the device
|
||||||
|
- Copy this folder into the device and run the scripts on the device
|
||||||
|
|
||||||
|
See how to run the script and parameters definitions in the [Quick Start Section](#quick-start)
|
||||||
|
|
||||||
|
## Preparing the model
|
||||||
|
|
||||||
|
Within the model folder a `prepare_model.sh` script uses MindSpore infrastructure to export the model into a `.mindir` file. The user can specify a docker image on which MindSpore is installed. Otherwise, the python script will be run locally.
|
||||||
|
The script then converts the `.mindir` to a `.ms` format using the MindSpore ToD converter.
|
||||||
|
The script accepts a tar ball where the converter resides. Otherwise, the script will attempt to find the converter in the MindSpore ToD build output directory.
|
||||||
|
|
||||||
|
## Preparing the Folder
|
||||||
|
|
||||||
|
The `lenet_tod.ms` model file is then copied into the `package` folder as well as scripts, the MindSpore ToD library and the MNIST dataset.
|
||||||
|
Finally, the code (in src) is compiled for arm64 and the binary is copied into the `package` folder.
|
||||||
|
|
||||||
|
### Running the code on the device
|
||||||
|
|
||||||
|
To run the code on the device the script first uses `adb` tool to push the `package` folder into the device. It then runs training (which takes some time) and finally runs evaluation of the trained model using the test data.
|
||||||
|
|
||||||
|
# Folder Directory tree
|
||||||
|
|
||||||
|
``` python
|
||||||
|
train_lenet/
|
||||||
|
├── Makefile # Makefile of src code
|
||||||
|
├── model
|
||||||
|
│ ├── lenet_export.py # Python script that exports the LeNet model to .mindir
|
||||||
|
│ ├── prepare_model.sh # script that export model (using docker) then converts it
|
||||||
|
│ └── train_utils.py # utility function used during the export
|
||||||
|
├── prepare_and_run.sh # main script that creates model, compiles it and send to device for running
|
||||||
|
├── README.md # English manual
|
||||||
|
├── README_CN.md # Chinese manual
|
||||||
|
├── scripts
|
||||||
|
│ ├── eval.sh # on-device script that load the train model and evaluates its accuracy
|
||||||
|
│ └── train.sh # on-device script that load the initial model and train it
|
||||||
|
├── src
|
||||||
|
│ ├── net_runner.cc # program that runs training/evaluation of models
|
||||||
|
│ ├── net_runner.h # net_runner header
|
||||||
|
│ └── utils.h # general utilities
|
||||||
|
```
|
||||||
|
|
||||||
|
When the `prepare_and_run.sh` script is run, the following folder is prepared. It is pushed to the device and then training runs
|
||||||
|
|
||||||
|
``` python
|
||||||
|
├── package
|
||||||
|
│ ├── bin
|
||||||
|
│ │ └── net_runner # the executable that performs the training/evaluation
|
||||||
|
│ ├── dataset
|
||||||
|
│ │ ├── test
|
||||||
|
│ │ │ ├── t10k-images-idx3-ubyte # test images
|
||||||
|
│ │ │ └── t10k-labels-idx1-ubyte # test labels
|
||||||
|
│ │ └── train
|
||||||
|
│ │ ├── train-images-idx3-ubyte # train images
|
||||||
|
│ │ └── train-labels-idx1-ubyte # train labels
|
||||||
|
│ ├── eval.sh # on-device script that load the train model and evaluates its accuracy
|
||||||
|
│ ├── lib
|
||||||
|
│ │ ├── libjpeg.so.62
|
||||||
|
│ │ ├── libminddata-lite.a
|
||||||
|
│ │ ├── libminddata-lite.so
|
||||||
|
│ │ ├── libmindspore-lite.a
|
||||||
|
│ │ ├── libmindspore-lite-jni.so
|
||||||
|
│ │ ├── libmindspore-lite.so
|
||||||
|
│ │ ├── libmindspore-lite-train.a
|
||||||
|
│ │ ├── libmindspore-lite-train-jni.so
|
||||||
|
│ │ ├── libmindspore-lite-train.so
|
||||||
|
│ │ ├── libturbojpeg.so.0
|
||||||
|
│ │ └── mindspore-lite-java.jar
|
||||||
|
│ ├── model
|
||||||
|
│ │ └── lenet_tod.ms # model to train
|
||||||
|
│ └── train.sh # on-device script that load the initial model and train it
|
||||||
|
```
|
|
@ -0,0 +1,134 @@
|
||||||
|
# 目录
|
||||||
|
|
||||||
|
<!-- TOC -->
|
||||||
|
|
||||||
|
- [目录](#目录)
|
||||||
|
- [概述](#概述)
|
||||||
|
- [数据集](#数据集)
|
||||||
|
- [环境要求](#环境要求)
|
||||||
|
- [快速入门](#快速入门)
|
||||||
|
- [脚本详述](#脚本详述)
|
||||||
|
- [模型准备](#模型准备)
|
||||||
|
- [模型训练](#模型训练)
|
||||||
|
- [工程目录](#工程目录)
|
||||||
|
|
||||||
|
<!-- /TOC -->
|
||||||
|
|
||||||
|
# 概述
|
||||||
|
|
||||||
|
本文主要讲解如何在端侧进行LeNet模型训练。首先在服务器或个人笔记本上进行模型转换;然后在安卓设备上训练模型。LeNet由2层卷积和3层全连接层组成,模型结构简单,因此可以在设备上快速训练。
|
||||||
|
|
||||||
|
# 数据集
|
||||||
|
|
||||||
|
本例使用[MNIST手写字数据集](http://yann.lecun.com/exdb/mnist/)
|
||||||
|
|
||||||
|
- 数据集大小:52.4M, 60,000 28*28 10类
|
||||||
|
- 测试集:10,000 images
|
||||||
|
- 训练集:60,000 images
|
||||||
|
|
||||||
|
- 数据格式:二进制文件
|
||||||
|
- 注意:数据处理会在dataset.cc中进行。
|
||||||
|
|
||||||
|
- 数据集目录结构如下:
|
||||||
|
|
||||||
|
```text
|
||||||
|
mnist/
|
||||||
|
├── test
|
||||||
|
│ ├── t10k-images-idx3-ubyte
|
||||||
|
│ └── t10k-labels-idx1-ubyte
|
||||||
|
└── train
|
||||||
|
├── train-images-idx3-ubyte
|
||||||
|
└── train-labels-idx1-ubyte
|
||||||
|
```
|
||||||
|
|
||||||
|
# 环境要求
|
||||||
|
|
||||||
|
- 服务器或个人笔记本
|
||||||
|
- [MindSpore Framework](https://www.mindspore.cn/install): 建议使用Docker安装
|
||||||
|
- [MindSpore ToD Download](https://www.mindspore.cn/tutorial/lite/zh-CN/master/use/downloads.html)
|
||||||
|
- [MindSpore ToD Build](https://www.mindspore.cn/tutorial/lite/zh-CN/master/use/build.html)
|
||||||
|
- [Android NDK r20b](https://dl.google.com/android/repository/android-ndk-r20b-linux-x86_64.zip)
|
||||||
|
- [Android SDK](https://developer.android.com/studio?hl=zh-cn#cmdline-tools)
|
||||||
|
- Android移动设备
|
||||||
|
|
||||||
|
# 快速入门
|
||||||
|
|
||||||
|
安装完毕,在`./mindspore/mindspore/lite/examples/train_lenet`目录下执行脚本,命令如下:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sh ./prepare_and_run.sh -D DATASET_PATH [-d MINDSPORE_DOCKER] [-r RELEASE.tar.gz] [-t arm64|x86]
|
||||||
|
```
|
||||||
|
|
||||||
|
其中,`DATASET_PATH`是数据集路径;`MINDSPORE_DOCKER`是运行MindSpore的docker镜像,如果没有使用docker环境,则使用本地运行;`REALEASE.tar.gz`为端侧运行时训练工具压缩包绝对路径;`-t`选项为设备处理器架构,默认为`arm64`,如果输入`x86`则本地运行。注意:若在不同平台执行训练,需在先执行脚本前运行`make clean`指令。
|
||||||
|
|
||||||
|
# 脚本详述
|
||||||
|
|
||||||
|
`prepare_and_run.sh`脚本的功能如下:
|
||||||
|
|
||||||
|
- 将Python模型文件转换为`.ms`文件。
|
||||||
|
- 编译训练源码并将相关文件传输到设备端
|
||||||
|
- 设备端执行训练
|
||||||
|
|
||||||
|
运行命令参见[快速入门](#快速入门)
|
||||||
|
|
||||||
|
## 模型准备
|
||||||
|
|
||||||
|
脚本`prepare_model.sh`会基于MIndSpore架构将Python模型转换为`lenet_tod.mindir`模型;然后,使用MindSpore ToD 模型转换工具将`lenet_tod.mindir`文件转换为`lenet_tod.ms`文件。如果没有docker环境,则本地执行转换。
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
将`lenet_tod.ms`模型文件、训练脚本、MindSpore ToD库文件和`MNIST`数据集拷贝到`package`文件夹。`/src`文件夹中代码将会被编译成arm64架构版本,生成的二进制文件会被拷贝至`package`文件夹。最后使用`adb`工具将`package`文件夹传输至设备端,并执行训练。
|
||||||
|
|
||||||
|
# 工程目录
|
||||||
|
|
||||||
|
``` txt
|
||||||
|
train_lenet/
|
||||||
|
├── Makefile # Makefile of src code
|
||||||
|
├── model
|
||||||
|
│ ├── lenet_export.py # Python script that exports the LeNet model to .mindir
|
||||||
|
│ ├── prepare_model.sh # script that export model (using docker) then converts it
|
||||||
|
│ └── train_utils.py # utility function used during the export
|
||||||
|
├── prepare_and_run.sh # main script that creates model, compiles it and send to device for running
|
||||||
|
├── README.md # this manual
|
||||||
|
├── scripts
|
||||||
|
│ ├── eval.sh # on-device script that load the train model and evaluates its accuracy
|
||||||
|
│ ├── run_eval.sh # adb script that launches eval.sh
|
||||||
|
│ ├── run_train.sh # adb script that launches train.sh
|
||||||
|
│ └── train.sh # on-device script that load the initial model and train it
|
||||||
|
├── src
|
||||||
|
│ ├── dataset.cc # dataset handler
|
||||||
|
│ ├── dataset.h # dataset class header
|
||||||
|
│ ├── net_runner.cc # program that runs training/evaluation of models
|
||||||
|
│ └── net_runner.h # net_runner header
|
||||||
|
```
|
||||||
|
|
||||||
|
在脚本`prepare_and_run.sh`运行前,必须确保以下目录结构正确,这些文件将被传入设备用于训练。
|
||||||
|
|
||||||
|
``` txt
|
||||||
|
├── package
|
||||||
|
│ ├── bin
|
||||||
|
│ │ └── net_runner # the executable that performs the training/evaluation
|
||||||
|
│ ├── dataset
|
||||||
|
│ │ ├── test
|
||||||
|
│ │ │ ├── t10k-images-idx3-ubyte # test images
|
||||||
|
│ │ │ └── t10k-labels-idx1-ubyte # test labels
|
||||||
|
│ │ └── train
|
||||||
|
│ │ ├── train-images-idx3-ubyte # train images
|
||||||
|
│ │ └── train-labels-idx1-ubyte # train labels
|
||||||
|
│ ├── eval.sh # on-device script that load the train model and evaluates its accuracy
|
||||||
|
│ ├── lib
|
||||||
|
│ │ ├── libjpeg.so.62
|
||||||
|
│ │ ├── libminddata-lite.a
|
||||||
|
│ │ ├── libminddata-lite.so
|
||||||
|
│ │ ├── libmindspore-lite.a
|
||||||
|
│ │ ├── libmindspore-lite-jni.so
|
||||||
|
│ │ ├── libmindspore-lite.so
|
||||||
|
│ │ ├── libmindspore-lite-train.a
|
||||||
|
│ │ ├── libmindspore-lite-train-jni.so
|
||||||
|
│ │ ├── libmindspore-lite-train.so
|
||||||
|
│ │ ├── libturbojpeg.so.0
|
||||||
|
│ │ └── mindspore-lite-java.jar
|
||||||
|
│ ├── model
|
||||||
|
│ │ └── lenet_tod.ms # model to train
|
||||||
|
│ └── train.sh # on-device script that load the initial model and train it
|
||||||
|
```
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""lenet_export."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import context, Tensor
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.train.serialization import export
|
||||||
|
from lenet import LeNet5
|
||||||
|
from train_utils import train_wrap
|
||||||
|
|
||||||
|
|
||||||
|
n = LeNet5()
|
||||||
|
n.set_train()
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)
|
||||||
|
|
||||||
|
BATCH_SIZE = int(sys.argv[1])
|
||||||
|
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
|
||||||
|
label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
|
||||||
|
net = train_wrap(n)
|
||||||
|
export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
|
||||||
|
|
||||||
|
print("finished exporting")
|
|
@ -0,0 +1,39 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if [[ -z ${EXPORT} ]]; then
|
||||||
|
echo "============Exporting=========="
|
||||||
|
rm -f lenet_tod.mindir
|
||||||
|
if [ -n "$2" ]; then
|
||||||
|
DOCKER_IMG=$2
|
||||||
|
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py '$1'; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
|
||||||
|
else
|
||||||
|
echo "MindSpore docker was not provided, attempting to run locally"
|
||||||
|
PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py $1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
CONVERTER="../../../build/tools/converter/converter_lite"
|
||||||
|
if [ ! -f "$CONVERTER" ]; then
|
||||||
|
if ! command -v converter_lite &> /dev/null
|
||||||
|
then
|
||||||
|
tar -xzf ../../../../../output/mindspore-lite-*-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin.so
|
||||||
|
if [ -f ./converter_lite ]; then
|
||||||
|
CONVERTER=./converter_lite
|
||||||
|
else
|
||||||
|
echo "converter_lite could not be found in MindSpore build directory nor in system path"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
CONVERTER=converter_lite
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "============Converting========="
|
||||||
|
QUANT_OPTIONS=""
|
||||||
|
if [[ ! -z ${QUANTIZE} ]]; then
|
||||||
|
echo "Quantizing weights"
|
||||||
|
QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15"
|
||||||
|
fi
|
||||||
|
LD_LIBRARY_PATH=./ $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""train_utils."""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
|
||||||
|
|
||||||
|
def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
|
||||||
|
"""
|
||||||
|
train_wrap
|
||||||
|
"""
|
||||||
|
if loss_fn is None:
|
||||||
|
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)
|
||||||
|
loss_net = nn.WithLossCell(net, loss_fn)
|
||||||
|
loss_net.set_train()
|
||||||
|
if weights is None:
|
||||||
|
weights = ParameterTuple(net.trainable_params())
|
||||||
|
if optimizer is None:
|
||||||
|
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
|
||||||
|
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
|
||||||
|
train_net = nn.TrainOneStepCell(loss_net, optimizer)
|
||||||
|
return train_net
|
|
@ -0,0 +1,159 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
display_usage()
|
||||||
|
{
|
||||||
|
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-o] [-b virtual_batch] [-m mindir]\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts()
|
||||||
|
{
|
||||||
|
TARGET="arm64"
|
||||||
|
DOCKER=""
|
||||||
|
MNIST_DATA_PATH=""
|
||||||
|
ENABLEFP16=""
|
||||||
|
VIRTUAL_BATCH=-1
|
||||||
|
MINDIR_FILE=""
|
||||||
|
while getopts 'D:d:m:r:t:ob:' opt
|
||||||
|
do
|
||||||
|
case "${opt}" in
|
||||||
|
D)
|
||||||
|
MNIST_DATA_PATH=$OPTARG
|
||||||
|
;;
|
||||||
|
d)
|
||||||
|
DOCKER=$OPTARG
|
||||||
|
;;
|
||||||
|
m)
|
||||||
|
MINDIR_FILE=$OPTARG
|
||||||
|
;;
|
||||||
|
t)
|
||||||
|
if [ "$OPTARG" == "arm64" ] || [ "$OPTARG" == "x86" ]; then
|
||||||
|
TARGET=$OPTARG
|
||||||
|
else
|
||||||
|
echo "No such target " $OPTARG
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
r)
|
||||||
|
TARBALL=$OPTARG
|
||||||
|
;;
|
||||||
|
o)
|
||||||
|
ENABLEFP16="-o"
|
||||||
|
;;
|
||||||
|
b)
|
||||||
|
VIRTUAL_BATCH=$OPTARG
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option ${opt}!"
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts "$@"
|
||||||
|
if [ "$MNIST_DATA_PATH" == "" ]; then
|
||||||
|
echo "MNIST Dataset directory path was not provided"
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$TARBALL" == "" ]; then
|
||||||
|
if [ "${TARGET}" == "arm64" ]; then
|
||||||
|
file=$(ls ../../../../output/mindspore-lite-*-android-aarch64.tar.gz)
|
||||||
|
else
|
||||||
|
file=$(ls ../../../../output/mindspore-lite-*-linux-x64.tar.gz)
|
||||||
|
fi
|
||||||
|
if [[ ${file} != "" ]] && [[ -f ${file} ]]; then
|
||||||
|
TARBALL=${file}
|
||||||
|
else
|
||||||
|
echo "release.tar.gz was not found"
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
# Prepare the model
|
||||||
|
if [[ "${VIRTUAL_BATCH}" == "-1" ]]; then
|
||||||
|
BATCH=32
|
||||||
|
else
|
||||||
|
BATCH=1
|
||||||
|
fi
|
||||||
|
|
||||||
|
EXPORT=""
|
||||||
|
if [ "$MINDIR_FILE" != "" ]; then
|
||||||
|
cp -f $MINDIR_FILE model/lenet_tod.mindir
|
||||||
|
EXPORT="DONT_EXPORT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
cd model/ || exit 1
|
||||||
|
rm -f *.ms
|
||||||
|
EXPORT=$EXPORT ./prepare_model.sh $BATCH $DOCKER || exit 1
|
||||||
|
cd ../
|
||||||
|
|
||||||
|
# Copy the .ms model to the package folder
|
||||||
|
|
||||||
|
PACKAGE=package-${TARGET}
|
||||||
|
|
||||||
|
rm -rf ${PACKAGE}
|
||||||
|
mkdir -p ${PACKAGE}/model
|
||||||
|
cp model/*.ms ${PACKAGE}/model || exit 1
|
||||||
|
|
||||||
|
# Copy the running script to the package
|
||||||
|
cp scripts/*.sh ${PACKAGE}/
|
||||||
|
|
||||||
|
# Copy the shared MindSpore ToD library
|
||||||
|
tar -xzf ${TARBALL}
|
||||||
|
mv mindspore-*/runtime/lib ${PACKAGE}/
|
||||||
|
mv mindspore-*/runtime/third_party/libjpeg-turbo/lib/* ${PACKAGE}/lib/
|
||||||
|
if [ "${TARGET}" == "arm64" ]; then
|
||||||
|
tar -xzf ${TARBALL} --wildcards --no-anchored hiai_ddk
|
||||||
|
mv mindspore-*/runtime/third_party/hiai_ddk/lib/* ${PACKAGE}/lib/
|
||||||
|
fi
|
||||||
|
|
||||||
|
rm -rf msl
|
||||||
|
mv mindspore-* msl/
|
||||||
|
|
||||||
|
# Copy the dataset to the package
|
||||||
|
cp -r $MNIST_DATA_PATH ${PACKAGE}/dataset || exit 1
|
||||||
|
cp scripts/*.dat ${PACKAGE}/dataset
|
||||||
|
|
||||||
|
echo "==========Compiling============"
|
||||||
|
make clean
|
||||||
|
make TARGET=${TARGET}
|
||||||
|
|
||||||
|
# Copy the executable to the package
|
||||||
|
mv bin ${PACKAGE}/ || exit 1
|
||||||
|
|
||||||
|
if [ "${TARGET}" == "arm64" ]; then
|
||||||
|
cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${PACKAGE}/lib/ || exit 1
|
||||||
|
|
||||||
|
echo "=======Pushing to device======="
|
||||||
|
adb push ${PACKAGE} /data/local/tmp/
|
||||||
|
|
||||||
|
echo "========Training on Device====="
|
||||||
|
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh ${ENABLEFP16} -b ${VIRTUAL_BATCH}"
|
||||||
|
echo
|
||||||
|
|
||||||
|
echo "===Evaluating trained Model====="
|
||||||
|
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh ${ENABLEFP16}"
|
||||||
|
echo
|
||||||
|
|
||||||
|
echo "====Running Inference Model====="
|
||||||
|
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh infer.sh"
|
||||||
|
echo
|
||||||
|
else
|
||||||
|
cd ${PACKAGE} || exit 1
|
||||||
|
echo "======Training Locally========="
|
||||||
|
./train.sh
|
||||||
|
|
||||||
|
echo "===Evaluating trained Model====="
|
||||||
|
./eval.sh
|
||||||
|
|
||||||
|
echo "====Running Inference Model====="
|
||||||
|
./infer.sh
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
fi
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# an simple tutorial as follows, more parameters can be setting
|
||||||
|
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod_trained.ms -e 0 -d dataset $1
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# an simple tutorial as follows, more parameters can be setting
|
||||||
|
LD_LIBRARY_PATH=./lib/ bin/infer -f model/lenet_tod_infer.ms
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# an simple tutorial as follows, more parameters can be setting
|
||||||
|
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod.ms -e 5 -d dataset $1 $2 $3
|
|
@ -0,0 +1,88 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <getopt.h>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include "src/utils.h"
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/graph.h"
|
||||||
|
#include "include/api/serialization.h"
|
||||||
|
|
||||||
|
static void Usage() { std::cout << "Usage: infer -f <.ms model file>" << std::endl; }
|
||||||
|
|
||||||
|
static std::string ReadArgs(int argc, char *argv[]) {
|
||||||
|
std::string infer_model_fn;
|
||||||
|
int opt;
|
||||||
|
while ((opt = getopt(argc, argv, "f:h")) != -1) {
|
||||||
|
switch (opt) {
|
||||||
|
case 'f':
|
||||||
|
infer_model_fn = std::string(optarg);
|
||||||
|
break;
|
||||||
|
case 'h':
|
||||||
|
default:
|
||||||
|
Usage();
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return infer_model_fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
std::string infer_model_fn = ReadArgs(argc, argv);
|
||||||
|
|
||||||
|
auto context = std::make_shared<mindspore::Context>();
|
||||||
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
cpu_context->SetEnableFP16(false);
|
||||||
|
context->MutableDeviceInfo().push_back(cpu_context);
|
||||||
|
|
||||||
|
mindspore::Graph graph;
|
||||||
|
auto status = mindspore::Serialization::Load(infer_model_fn, mindspore::kFlatBuffer, &graph);
|
||||||
|
if (status != mindspore::kSuccess) {
|
||||||
|
std::cout << "Error " << status << " during serialization of graph " << infer_model_fn;
|
||||||
|
MS_ASSERT(status != mindspore::kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::Model model;
|
||||||
|
status = model.Build(mindspore::GraphCell(graph), context);
|
||||||
|
if (status != mindspore::kSuccess) {
|
||||||
|
std::cout << "Error " << status << " during build of model " << infer_model_fn;
|
||||||
|
MS_ASSERT(status != mindspore::kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputs = model.GetInputs();
|
||||||
|
MS_ASSERT(inputs.size() >= 1);
|
||||||
|
|
||||||
|
auto *input_data = reinterpret_cast<float *>(inputs.at(0).MutableData());
|
||||||
|
std::ifstream in;
|
||||||
|
in.open("dataset/batch_of32.dat", std::ios::in | std::ios::binary);
|
||||||
|
in.read(reinterpret_cast<char *>(&input_data), inputs.at(0).ElementNum() * sizeof(float));
|
||||||
|
in.close();
|
||||||
|
|
||||||
|
std::vector<mindspore::MSTensor> outputs;
|
||||||
|
status = model.Predict(inputs, &outputs);
|
||||||
|
if (status != mindspore::kSuccess) {
|
||||||
|
std::cout << "Error " << status << " during running predict of model " << infer_model_fn;
|
||||||
|
MS_ASSERT(status != mindspore::kSuccess);
|
||||||
|
}
|
||||||
|
std::cout << "Got Vector of size: " << outputs.size() << std::endl;
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
std::cout << "[ " << tensor.Shape().at(0) << ", " << tensor.Shape().at(1) << "]\n";
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,304 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/net_runner.h"
|
||||||
|
#include <math.h>
|
||||||
|
#include <getopt.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <malloc.h>
|
||||||
|
#include <cstring>
|
||||||
|
#include <chrono>
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <utility>
|
||||||
|
#include "include/context.h"
|
||||||
|
#include "include/api/serialization.h"
|
||||||
|
#include "include/api/callback/loss_monitor.h"
|
||||||
|
#include "include/api/metrics/accuracy.h"
|
||||||
|
#include "include/api/callback/ckpt_saver.h"
|
||||||
|
#include "include/api/callback/train_accuracy.h"
|
||||||
|
#include "include/api/callback/lr_scheduler.h"
|
||||||
|
#include "src/utils.h"
|
||||||
|
#include "include/dataset/datasets.h"
|
||||||
|
#include "include/dataset/vision_lite.h"
|
||||||
|
#include "include/dataset/transforms.h"
|
||||||
|
|
||||||
|
using mindspore::AccuracyMetrics;
|
||||||
|
using mindspore::Model;
|
||||||
|
using mindspore::TrainAccuracy;
|
||||||
|
using mindspore::TrainCallBack;
|
||||||
|
using mindspore::TrainCallBackData;
|
||||||
|
using mindspore::dataset::Dataset;
|
||||||
|
using mindspore::dataset::Mnist;
|
||||||
|
using mindspore::dataset::TensorOperation;
|
||||||
|
using mindspore::dataset::transforms::TypeCast;
|
||||||
|
using mindspore::dataset::vision::Normalize;
|
||||||
|
using mindspore::dataset::vision::Resize;
|
||||||
|
|
||||||
|
constexpr int kPrintNum = 10;
|
||||||
|
constexpr float kScalePoint = 255.0f;
|
||||||
|
constexpr int kBatchSize = 2;
|
||||||
|
constexpr int kNCHWDims = 4;
|
||||||
|
constexpr int kNCHWCDim = 2;
|
||||||
|
constexpr int kPrintTimes = 100;
|
||||||
|
constexpr int kSaveEpochs = 3;
|
||||||
|
constexpr float kGammaFactor = 0.7f;
|
||||||
|
class Rescaler : public mindspore::TrainCallBack {
|
||||||
|
public:
|
||||||
|
explicit Rescaler(float scale) : scale_(scale) {
|
||||||
|
if (scale_ == 0) {
|
||||||
|
scale_ = 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
~Rescaler() override = default;
|
||||||
|
void StepBegin(const mindspore::TrainCallBackData &cb_data) override {
|
||||||
|
auto inputs = cb_data.model_->GetInputs();
|
||||||
|
auto *input_data = reinterpret_cast<float *>(inputs.at(0).MutableData());
|
||||||
|
for (int k = 0; k < inputs.at(0).ElementNum(); k++) input_data[k] /= scale_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
float scale_ = 1.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* This is an example of a user defined Callback to measure memory and latency of execution */
|
||||||
|
class Measurement : public mindspore::TrainCallBack {
|
||||||
|
public:
|
||||||
|
explicit Measurement(unsigned int epochs)
|
||||||
|
: epochs_(epochs), time_avg_(std::chrono::duration<double, std::milli>(0)) {}
|
||||||
|
~Measurement() override = default;
|
||||||
|
void EpochBegin(const mindspore::TrainCallBackData &cb_data) override {
|
||||||
|
start_time_ = std::chrono::high_resolution_clock::now();
|
||||||
|
}
|
||||||
|
mindspore::CallbackRetValue EpochEnd(const mindspore::TrainCallBackData &cb_data) override {
|
||||||
|
end_time_ = std::chrono::high_resolution_clock::now();
|
||||||
|
auto time = std::chrono::duration<double, std::milli>(end_time_ - start_time_);
|
||||||
|
time_avg_ += time;
|
||||||
|
return mindspore::kContinue;
|
||||||
|
}
|
||||||
|
void End(const mindspore::TrainCallBackData &cb_data) override {
|
||||||
|
if (epochs_ > 0) {
|
||||||
|
std::cout << "AvgRunTime: " << time_avg_.count() / epochs_ << " ms" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct mallinfo info = mallinfo();
|
||||||
|
std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::chrono::time_point<std::chrono::high_resolution_clock> start_time_;
|
||||||
|
std::chrono::time_point<std::chrono::high_resolution_clock> end_time_;
|
||||||
|
std::chrono::duration<double, std::milli> time_avg_;
|
||||||
|
unsigned int epochs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Definition of verbose callback function after forwarding operator.
|
||||||
|
bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
|
||||||
|
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
|
||||||
|
const mindspore::CallBackParam &call_param) {
|
||||||
|
printf("%s\n", call_param.node_name.c_str());
|
||||||
|
for (size_t i = 0; i < after_inputs.size(); i++) {
|
||||||
|
int num2p = (after_inputs.at(i)->ElementsNum());
|
||||||
|
printf("in%zu(%d): ", i, num2p);
|
||||||
|
if (num2p > kPrintNum) num2p = kPrintNum;
|
||||||
|
if (after_inputs.at(i)->data_type() == mindspore::kNumberTypeInt32) {
|
||||||
|
auto d = reinterpret_cast<int *>(after_inputs.at(i)->MutableData());
|
||||||
|
for (int j = 0; j < num2p; j++) printf("%d, ", d[j]);
|
||||||
|
} else {
|
||||||
|
auto d = reinterpret_cast<float *>(after_inputs.at(i)->MutableData());
|
||||||
|
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < after_outputs.size(); i++) {
|
||||||
|
auto d = reinterpret_cast<float *>(after_outputs.at(i)->MutableData());
|
||||||
|
int num2p = (after_outputs.at(i)->ElementsNum());
|
||||||
|
printf("ou%zu(%d): ", i, num2p);
|
||||||
|
if (num2p > 10) num2p = 10;
|
||||||
|
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
NetRunner::~NetRunner() {}
|
||||||
|
|
||||||
|
void NetRunner::InitAndFigureInputs() {
|
||||||
|
auto context = std::make_shared<mindspore::Context>();
|
||||||
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
cpu_context->SetEnableFP16(enable_fp16_);
|
||||||
|
context->MutableDeviceInfo().push_back(cpu_context);
|
||||||
|
|
||||||
|
graph_ = new mindspore::Graph();
|
||||||
|
auto status = mindspore::Serialization::Load(ms_file_, mindspore::kFlatBuffer, graph_);
|
||||||
|
if (status != mindspore::kSuccess) {
|
||||||
|
std::cout << "Error " << status << " during serialization of graph " << ms_file_;
|
||||||
|
MS_ASSERT(status != mindspore::kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cfg = std::make_shared<mindspore::TrainCfg>();
|
||||||
|
if (enable_fp16_) {
|
||||||
|
cfg.get()->optimization_level_ = mindspore::kO2;
|
||||||
|
}
|
||||||
|
|
||||||
|
model_ = new mindspore::Model();
|
||||||
|
status = model_->Build(mindspore::GraphCell(*graph_), context, cfg);
|
||||||
|
if (status != mindspore::kSuccess) {
|
||||||
|
std::cout << "Error " << status << " during build of model " << ms_file_;
|
||||||
|
MS_ASSERT(status != mindspore::kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
// if (verbose_) {
|
||||||
|
// loop_->SetKernelCallBack(nullptr, after_callback);
|
||||||
|
//}
|
||||||
|
acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
|
||||||
|
model_->InitMetrics({acc_metrics_.get()});
|
||||||
|
|
||||||
|
auto inputs = model_->GetInputs();
|
||||||
|
MS_ASSERT(inputs.size() >= 1);
|
||||||
|
auto nhwc_input_dims = inputs.at(0).Shape();
|
||||||
|
// MS_ASSERT(nhwc_input_dims.size() == kNCHWDims);
|
||||||
|
batch_size_ = nhwc_input_dims.at(0);
|
||||||
|
h_ = nhwc_input_dims.at(1);
|
||||||
|
w_ = nhwc_input_dims.at(kNCHWCDim);
|
||||||
|
}
|
||||||
|
|
||||||
|
float NetRunner::CalculateAccuracy(int max_tests) {
|
||||||
|
test_ds_ = Mnist(data_dir_ + "/test", "all");
|
||||||
|
TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
|
||||||
|
Resize resize({h_, w_});
|
||||||
|
test_ds_ = test_ds_->Map({&resize, &typecast_f}, {"image"});
|
||||||
|
|
||||||
|
TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
|
||||||
|
test_ds_ = test_ds_->Map({&typecast}, {"label"});
|
||||||
|
test_ds_ = test_ds_->Batch(batch_size_, true);
|
||||||
|
|
||||||
|
// Rescaler rescale(kScalePoint);
|
||||||
|
|
||||||
|
model_->Evaluate(test_ds_, {});
|
||||||
|
std::cout << "Accuracy is " << acc_metrics_->Eval() << std::endl;
|
||||||
|
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NetRunner::InitDB() {
|
||||||
|
train_ds_ = Mnist(data_dir_ + "/train", "all");
|
||||||
|
|
||||||
|
TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
|
||||||
|
Resize resize({h_, w_});
|
||||||
|
train_ds_ = train_ds_->Map({&resize, &typecast_f}, {"image"});
|
||||||
|
|
||||||
|
TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
|
||||||
|
train_ds_ = train_ds_->Map({&typecast}, {"label"});
|
||||||
|
|
||||||
|
train_ds_ = train_ds_->Batch(batch_size_, true);
|
||||||
|
|
||||||
|
if (verbose_) {
|
||||||
|
std::cout << "DatasetSize is " << train_ds_->GetDatasetSize() << std::endl;
|
||||||
|
}
|
||||||
|
if (train_ds_->GetDatasetSize() == 0) {
|
||||||
|
std::cout << "No relevant data was found in " << data_dir_ << std::endl;
|
||||||
|
MS_ASSERT(train_ds_->GetDatasetSize() != 0);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NetRunner::TrainLoop() {
|
||||||
|
mindspore::LossMonitor lm(kPrintTimes);
|
||||||
|
mindspore::TrainAccuracy am(1);
|
||||||
|
|
||||||
|
mindspore::CkptSaver cs(kSaveEpochs, std::string("lenet"));
|
||||||
|
Rescaler rescale(kScalePoint);
|
||||||
|
Measurement measure(epochs_);
|
||||||
|
|
||||||
|
if (virtual_batch_ > 0) {
|
||||||
|
model_->Train(epochs_, train_ds_, {&rescale, &lm, &cs, &measure});
|
||||||
|
} else {
|
||||||
|
struct mindspore::StepLRLambda step_lr_lambda(1, kGammaFactor);
|
||||||
|
mindspore::LRScheduler step_lr_sched(mindspore::StepLRLambda, static_cast<void *>(&step_lr_lambda), 1);
|
||||||
|
model_->Train(epochs_, train_ds_, {&rescale, &lm, &cs, &am, &step_lr_sched, &measure});
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NetRunner::Main() {
|
||||||
|
InitAndFigureInputs();
|
||||||
|
|
||||||
|
InitDB();
|
||||||
|
|
||||||
|
TrainLoop();
|
||||||
|
|
||||||
|
CalculateAccuracy();
|
||||||
|
|
||||||
|
if (epochs_ > 0) {
|
||||||
|
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
|
||||||
|
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, false);
|
||||||
|
trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_infer.ms";
|
||||||
|
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, true);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NetRunner::Usage() {
|
||||||
|
std::cout << "Usage: net_runner -f <.ms model file> -d <data_dir> [-e <num of training epochs>] "
|
||||||
|
<< "[-v (verbose mode)] [-s <save checkpoint every X iterations>]" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NetRunner::ReadArgs(int argc, char *argv[]) {
|
||||||
|
int opt;
|
||||||
|
while ((opt = getopt(argc, argv, "f:e:d:s:ihc:vob:")) != -1) {
|
||||||
|
switch (opt) {
|
||||||
|
case 'f':
|
||||||
|
ms_file_ = std::string(optarg);
|
||||||
|
break;
|
||||||
|
case 'e':
|
||||||
|
epochs_ = atoi(optarg);
|
||||||
|
break;
|
||||||
|
case 'd':
|
||||||
|
data_dir_ = std::string(optarg);
|
||||||
|
break;
|
||||||
|
case 'v':
|
||||||
|
verbose_ = true;
|
||||||
|
break;
|
||||||
|
case 's':
|
||||||
|
save_checkpoint_ = atoi(optarg);
|
||||||
|
break;
|
||||||
|
case 'o':
|
||||||
|
enable_fp16_ = true;
|
||||||
|
break;
|
||||||
|
case 'b':
|
||||||
|
virtual_batch_ = atoi(optarg);
|
||||||
|
break;
|
||||||
|
case 'h':
|
||||||
|
default:
|
||||||
|
Usage();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
NetRunner nr;
|
||||||
|
|
||||||
|
if (nr.ReadArgs(argc, argv)) {
|
||||||
|
nr.Main();
|
||||||
|
} else {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_EXAMPLES_UNIFIED_API_SRC_NET_RUNNER_H_
|
||||||
|
#define MINDSPORE_LITE_EXAMPLES_UNIFIED_API_SRC_NET_RUNNER_H_
|
||||||
|
|
||||||
|
#include <tuple>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/graph.h"
|
||||||
|
#include "include/api/metrics/accuracy.h"
|
||||||
|
#include "include/dataset/datasets.h"
|
||||||
|
|
||||||
|
using mindspore::AccuracyMetrics;
|
||||||
|
using mindspore::dataset::Dataset;
|
||||||
|
|
||||||
|
class NetRunner {
|
||||||
|
public:
|
||||||
|
int Main();
|
||||||
|
bool ReadArgs(int argc, char *argv[]);
|
||||||
|
~NetRunner();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Usage();
|
||||||
|
void InitAndFigureInputs();
|
||||||
|
int InitDB();
|
||||||
|
int TrainLoop();
|
||||||
|
float CalculateAccuracy(int max_tests = 0);
|
||||||
|
float GetLoss() const;
|
||||||
|
|
||||||
|
mindspore::Model *model_ = nullptr;
|
||||||
|
mindspore::Graph *graph_ = nullptr;
|
||||||
|
|
||||||
|
std::shared_ptr<Dataset> train_ds_;
|
||||||
|
std::shared_ptr<Dataset> test_ds_;
|
||||||
|
std::shared_ptr<AccuracyMetrics> acc_metrics_;
|
||||||
|
|
||||||
|
std::string ms_file_ = "";
|
||||||
|
std::string data_dir_ = "";
|
||||||
|
unsigned int epochs_ = 10;
|
||||||
|
bool verbose_ = false;
|
||||||
|
bool enable_fp16_ = false;
|
||||||
|
int virtual_batch_ = -1;
|
||||||
|
int save_checkpoint_ = 0;
|
||||||
|
int batch_size_ = 32;
|
||||||
|
int h_ = 32;
|
||||||
|
int w_ = 32;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_EXAMPLES_UNIFIED_API_SRC_NET_RUNNER_H_
|
|
@ -0,0 +1,31 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_EXAMPLES_UNIFIED_API_SRC_UTILS_H_
|
||||||
|
#define MINDSPORE_LITE_EXAMPLES_UNIFIED_API_SRC_UTILS_H_
|
||||||
|
|
||||||
|
// DEBUG should be defined because the source code use assert to test exception
|
||||||
|
// but the Code Analysis Tool not allow us to use assert directly
|
||||||
|
#define DEBUG TRUE
|
||||||
|
|
||||||
|
#ifdef DEBUG
|
||||||
|
#include <cassert>
|
||||||
|
#define MS_ASSERT(f) assert(f)
|
||||||
|
#else
|
||||||
|
#define MS_ASSERT(f) ((void)0)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_EXAMPLES_UNIFIED_API_SRC_UTILS_H_
|
|
@ -35,17 +35,25 @@ else()
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
file(GLOB CXX_API_SRCS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/*.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/*.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/graph/*.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor/*.cc
|
||||||
|
)
|
||||||
|
|
||||||
set(API_SRC
|
set(API_SRC
|
||||||
${CORE_DIR}/utils/status.cc
|
${CORE_DIR}/utils/status.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
|
${CXX_API_SRCS}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
|
)
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor_utils.cc
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/types.cc
|
file(GLOB CXX_API_TRAIN_SRCS
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/context.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/*.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/metrics/*.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model_impl.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/callback/*.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/graph/graph.cc
|
)
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor/tensor_impl.cc
|
set(API_TRAIN_SRC
|
||||||
|
${CXX_API_TRAIN_SRCS}
|
||||||
)
|
)
|
||||||
|
|
||||||
if(SUPPORT_NPU)
|
if(SUPPORT_NPU)
|
||||||
|
@ -121,6 +129,7 @@ if(MSLITE_GPU_BACKEND STREQUAL cuda)
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
set(TRAIN_SRC
|
set(TRAIN_SRC
|
||||||
|
${API_TRAIN_SRC}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/cell.h"
|
||||||
|
#include "include/lite_session.h"
|
||||||
|
#include "include/train/train_loop_callback.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class TrainLoopCallBackAdapter : public session::TrainLoopCallBack {
|
||||||
|
public:
|
||||||
|
explicit TrainLoopCallBackAdapter(Model *model, TrainCallBack *call_back) : model_(model), call_back_(call_back) {}
|
||||||
|
TrainLoopCallBackAdapter() = delete;
|
||||||
|
|
||||||
|
void Begin(const session::TrainLoopCallBackData &i_cb_data) override {
|
||||||
|
call_back_->Begin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
|
||||||
|
};
|
||||||
|
|
||||||
|
void End(const session::TrainLoopCallBackData &i_cb_data) override {
|
||||||
|
call_back_->End(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
|
||||||
|
};
|
||||||
|
|
||||||
|
void EpochBegin(const session::TrainLoopCallBackData &i_cb_data) override {
|
||||||
|
call_back_->EpochBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
|
||||||
|
};
|
||||||
|
|
||||||
|
int EpochEnd(const session::TrainLoopCallBackData &i_cb_data) override {
|
||||||
|
return call_back_->EpochEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
|
||||||
|
};
|
||||||
|
|
||||||
|
void StepBegin(const session::TrainLoopCallBackData &i_cb_data) override {
|
||||||
|
call_back_->StepBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
|
||||||
|
};
|
||||||
|
|
||||||
|
void StepEnd(const session::TrainLoopCallBackData &i_cb_data) override {
|
||||||
|
call_back_->StepEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_));
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
Model *model_;
|
||||||
|
TrainCallBack *call_back_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_
|
|
@ -0,0 +1,46 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_IMPL_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_IMPL_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/cell.h"
|
||||||
|
#include "include/lite_session.h"
|
||||||
|
#include "include/train/train_loop_callback.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class CallbackImpl {
|
||||||
|
public:
|
||||||
|
CallbackImpl() = delete;
|
||||||
|
explicit CallbackImpl(session::TrainLoopCallBack *cb) : internal_call_back_(cb) {}
|
||||||
|
session::TrainLoopCallBack *GetInternalCallback() { return internal_call_back_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
session::TrainLoopCallBack *internal_call_back_ = nullptr;
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_IMPL_H_
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/train/ckpt_saver.h"
|
||||||
|
#include "include/api/callback/ckpt_saver.h"
|
||||||
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix) {
|
||||||
|
callback_impl_ = new CallbackImpl(new lite::CkptSaver(save_every_n, filename_prefix));
|
||||||
|
}
|
||||||
|
|
||||||
|
CkptSaver::~CkptSaver() {
|
||||||
|
if (callback_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Callback implement is null.";
|
||||||
|
}
|
||||||
|
auto internal_call_back = callback_impl_->GetInternalCallback();
|
||||||
|
if (internal_call_back != nullptr) {
|
||||||
|
delete internal_call_back;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/train/loss_monitor.h"
|
||||||
|
#include "include/api/callback/loss_monitor.h"
|
||||||
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
LossMonitor::LossMonitor(int print_every_n_steps) {
|
||||||
|
callback_impl_ = new CallbackImpl(new lite::LossMonitor(print_every_n_steps));
|
||||||
|
}
|
||||||
|
|
||||||
|
LossMonitor::~LossMonitor() {
|
||||||
|
if (callback_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Callback implement is null.";
|
||||||
|
}
|
||||||
|
auto internal_call_back = callback_impl_->GetInternalCallback();
|
||||||
|
if (internal_call_back != nullptr) {
|
||||||
|
delete internal_call_back;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<GraphPoint> &LossMonitor::GetLossPoints() {
|
||||||
|
static std::vector<GraphPoint> empty_vector;
|
||||||
|
if (callback_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Callback implement is null.";
|
||||||
|
return empty_vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
session::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
|
||||||
|
if (internal_call_back == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Internal callback is null.";
|
||||||
|
return empty_vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (reinterpret_cast<lite::LossMonitor *>(internal_call_back))->GetLossPoints();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/train/lr_scheduler.h"
|
||||||
|
#include "include/api/callback/lr_scheduler.h"
|
||||||
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
int StepLRLambda(float *lr, int epoch, void *lr_cb_data) {
|
||||||
|
if ((lr == nullptr) || (lr_cb_data == nullptr)) {
|
||||||
|
MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
|
||||||
|
return DONT_UPDATE_LR;
|
||||||
|
}
|
||||||
|
struct StepLRLambda *step_lr_data = (static_cast<struct StepLRLambda *>(lr_cb_data));
|
||||||
|
if (((epoch + 1) % step_lr_data->step_size) == 0) {
|
||||||
|
*lr = *lr * step_lr_data->gamma;
|
||||||
|
return UPDATE_LR;
|
||||||
|
}
|
||||||
|
return DONT_UPDATE_LR;
|
||||||
|
}
|
||||||
|
|
||||||
|
LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step) {
|
||||||
|
callback_impl_ = new CallbackImpl(new lite::LRScheduler(lambda_func, lr_cb_data, step));
|
||||||
|
}
|
||||||
|
|
||||||
|
LRScheduler::~LRScheduler() {
|
||||||
|
if (callback_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Callback implement is null.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto internal_call_back = callback_impl_->GetInternalCallback();
|
||||||
|
if (internal_call_back != nullptr) {
|
||||||
|
delete internal_call_back;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/train/classification_train_accuracy_monitor.h"
|
||||||
|
#include "include/api/callback/train_accuracy.h"
|
||||||
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
TrainAccuracy::TrainAccuracy(int print_every_n, int accuracy_metrics, const std::vector<int> &input_indexes,
|
||||||
|
const std::vector<int> &output_indexes) {
|
||||||
|
callback_impl_ = new CallbackImpl(
|
||||||
|
new lite::ClassificationTrainAccuracyMonitor(print_every_n, accuracy_metrics, input_indexes, output_indexes));
|
||||||
|
}
|
||||||
|
|
||||||
|
TrainAccuracy::~TrainAccuracy() {
|
||||||
|
if (callback_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Callback implement is null.";
|
||||||
|
}
|
||||||
|
auto internal_call_back = callback_impl_->GetInternalCallback();
|
||||||
|
if (internal_call_back != nullptr) {
|
||||||
|
delete internal_call_back;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<GraphPoint> &TrainAccuracy::GetAccuracyPoints() {
|
||||||
|
static std::vector<GraphPoint> empty_vector;
|
||||||
|
if (callback_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Callback implement is null.";
|
||||||
|
return empty_vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
session::TrainLoopCallBack *internal_call_back = callback_impl_->GetInternalCallback();
|
||||||
|
if (internal_call_back == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Internal callback is null.";
|
||||||
|
return empty_vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (reinterpret_cast<lite::ClassificationTrainAccuracyMonitor *>(internal_call_back))->GetAccuracyPoints();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -347,4 +347,5 @@ enum DataType Ascend310DeviceInfo::GetOutputType() const {
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
return DataType::kTypeUnknown;
|
return DataType::kTypeUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/cxx_api/converters.h"
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/context.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "src/runtime/inner_allocator.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
|
||||||
|
if ((a_context == nullptr) || (l_context == nullptr)) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context pointers.";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto device_list = a_context->MutableDeviceInfo();
|
||||||
|
if (device_list.size() == 0) {
|
||||||
|
MS_LOG(ERROR) << "Invalid device list.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
|
if (device_list.size() > 2) {
|
||||||
|
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
|
l_context->thread_num_ = a_context->GetThreadNum();
|
||||||
|
l_context->enable_parallel_ = a_context->GetEnableParallel();
|
||||||
|
l_context->affinity_core_list_ = a_context->GetThreadAffinityCoreList();
|
||||||
|
l_context->device_list_.clear();
|
||||||
|
if (device_list[0]->GetDeviceType() != kCPU) {
|
||||||
|
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
|
||||||
|
l_context->allocator = cpu_context->GetAllocator();
|
||||||
|
if (l_context->allocator == nullptr) {
|
||||||
|
l_context->allocator = Allocator::Create();
|
||||||
|
if (l_context->allocator == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create Allocator failed.";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Set new allocator.";
|
||||||
|
cpu_context->SetAllocator(l_context->allocator);
|
||||||
|
}
|
||||||
|
|
||||||
|
lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->GetThreadAffinityMode());
|
||||||
|
|
||||||
|
lite::DeviceInfo cpu_info = {0};
|
||||||
|
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
||||||
|
l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
|
||||||
|
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
|
||||||
|
if (device_list.size() == 2) {
|
||||||
|
lite::DeviceInfo device_info = {0};
|
||||||
|
if (device_list[1]->GetDeviceType() == kMaliGPU) {
|
||||||
|
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
|
||||||
|
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
||||||
|
l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
|
||||||
|
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
|
||||||
|
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
||||||
|
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
||||||
|
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
||||||
|
l_context->device_list_.push_back({lite::DT_NPU, device_info});
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Invalid device.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_
|
||||||
|
|
||||||
|
#include <limits.h>
|
||||||
|
#include "include/api/status.h"
|
||||||
|
#include "include/api/types.h"
|
||||||
|
#include "include/lite_types.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
namespace lite {
|
||||||
|
struct Context;
|
||||||
|
class TrainCfg;
|
||||||
|
} // namespace lite
|
||||||
|
|
||||||
|
class Context;
|
||||||
|
class TrainCfg;
|
||||||
|
|
||||||
|
inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
|
||||||
|
if (qt == kNoQuant) {
|
||||||
|
return lite::QT_NONE;
|
||||||
|
}
|
||||||
|
if (qt == kWeightQuant) {
|
||||||
|
return lite::QT_WEIGHT;
|
||||||
|
}
|
||||||
|
return lite::QT_DEFAULT;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline lite::CpuBindMode A2L_ConvertAffinityMode(int affinity_mode) {
|
||||||
|
switch (affinity_mode) {
|
||||||
|
case 0:
|
||||||
|
return lite::NO_BIND;
|
||||||
|
case 1:
|
||||||
|
return lite::HIGHER_CPU;
|
||||||
|
case 2:
|
||||||
|
return lite::MID_CPU;
|
||||||
|
default:
|
||||||
|
return lite::NO_BIND;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status A2L_ConvertContext(Context *a_context, lite::Context *l_context);
|
||||||
|
|
||||||
|
Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg);
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_CXX_API_GRAPH_GRAPH_DATA_H
|
#ifndef MINDSPORE_LITE_SRC_CXX_API_GRAPH_GRAPH_DATA_H_
|
||||||
#define MINDSPORE_LITE_SRC_CXX_API_GRAPH_GRAPH_DATA_H
|
#define MINDSPORE_LITE_SRC_CXX_API_GRAPH_GRAPH_DATA_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -36,9 +36,12 @@ class Graph::GraphData {
|
||||||
|
|
||||||
std::shared_ptr<lite::Model> lite_model() { return lite_model_; }
|
std::shared_ptr<lite::Model> lite_model() { return lite_model_; }
|
||||||
|
|
||||||
|
bool IsTrainModel() { return true; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<lite::Model> lite_model_ = nullptr;
|
std::shared_ptr<lite::Model> lite_model_ = nullptr;
|
||||||
|
std::string file_name_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_CXX_API_GRAPH_GRAPH_DATA_H
|
#endif // MINDSPORE_LITE_SRC_CXX_API_GRAPH_GRAPH_DATA_H_
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/train/accuracy_metrics.h"
|
||||||
|
#include "include/api/metrics/accuracy.h"
|
||||||
|
#include "src/cxx_api/metrics/metrics_impl.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
AccuracyMetrics::AccuracyMetrics(int accuracy_metrics, const std::vector<int> &input_indexes,
|
||||||
|
const std::vector<int> &output_indexes) {
|
||||||
|
metrics_impl_ = new MetricsImpl(new lite::AccuracyMetrics(accuracy_metrics, input_indexes, output_indexes));
|
||||||
|
}
|
||||||
|
|
||||||
|
AccuracyMetrics::~AccuracyMetrics() {
|
||||||
|
if (metrics_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Metrics implement is null.";
|
||||||
|
}
|
||||||
|
auto internal_metrics = metrics_impl_->GetInternalMetrics();
|
||||||
|
if (internal_metrics != nullptr) {
|
||||||
|
delete internal_metrics;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AccuracyMetrics::Clear() {
|
||||||
|
if (metrics_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Metrics implement is null.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto internal_metrics = metrics_impl_->GetInternalMetrics();
|
||||||
|
(reinterpret_cast<lite::AccuracyMetrics *>(internal_metrics))->Clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
float AccuracyMetrics::Eval() {
|
||||||
|
if (metrics_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Metrics implement is null.";
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
auto internal_metrics = metrics_impl_->GetInternalMetrics();
|
||||||
|
return (reinterpret_cast<lite::AccuracyMetrics *>(internal_metrics))->Eval();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_ADAPTER_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_ADAPTER_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/cell.h"
|
||||||
|
#include "include/api/metrics/metrics.h"
|
||||||
|
#include "include/lite_session.h"
|
||||||
|
#include "include/train/metrics.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class MetricsAdapter : public session::Metrics {
|
||||||
|
public:
|
||||||
|
explicit MetricsAdapter(mindspore::Metrics *metrics) : metrics_(metrics) {}
|
||||||
|
MetricsAdapter() = delete;
|
||||||
|
|
||||||
|
void Clear() override { metrics_->Clear(); }
|
||||||
|
|
||||||
|
float Eval() override { return metrics_->Eval(); }
|
||||||
|
void Update(std::vector<tensor::MSTensor *> inputs, std::vector<tensor::MSTensor *> outputs) override {
|
||||||
|
// metrics_->Update(inputs, outputs); TODO need to implement
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mindspore::Metrics *metrics_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_ADAPTER_H_
|
|
@ -0,0 +1,43 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_IMPL_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_IMPL_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/train/metrics.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class MetricsImpl {
|
||||||
|
public:
|
||||||
|
MetricsImpl() = delete;
|
||||||
|
explicit MetricsImpl(session::Metrics *metrics) : metrics_(metrics) {}
|
||||||
|
session::Metrics *GetInternalMetrics() { return metrics_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
session::Metrics *metrics_ = nullptr;
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_CXX_API_METRICS_METRICS_IMPL_H_
|
|
@ -17,11 +17,15 @@
|
||||||
#include "include/api/model.h"
|
#include "include/api/model.h"
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
#include "src/cxx_api/model/model_impl.h"
|
#include "src/cxx_api/model/model_impl.h"
|
||||||
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
|
#include "src/cxx_api/callback/callback_adapter.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
|
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||||
|
@ -35,26 +39,33 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
|
||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context) {
|
|
||||||
|
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
|
||||||
|
const std::shared_ptr<TrainCfg> &train_cfg) {
|
||||||
|
std::stringstream err_msg;
|
||||||
if (impl_ != nullptr) {
|
if (impl_ != nullptr) {
|
||||||
MS_LOG(DEBUG) << "Model has been already built.";
|
MS_LOG(DEBUG) << "Model has been already built.";
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
err_msg << "Model implement is null.";
|
||||||
return kLiteNullptr;
|
MS_LOG(ERROR) << err_msg.str();
|
||||||
|
return Status(kLiteNullptr, err_msg.str());
|
||||||
}
|
}
|
||||||
if (graph.GetGraph() == nullptr) {
|
if (graph.GetGraph() == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid graph.";
|
err_msg << "Invalid null graph.";
|
||||||
return kLiteNullptr;
|
MS_LOG(ERROR) << err_msg.str();
|
||||||
|
return Status(kLiteNullptr, err_msg.str());
|
||||||
}
|
}
|
||||||
if (model_context == nullptr) {
|
if (model_context == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
err_msg << "Invalid null context.";
|
||||||
return kLiteNullptr;
|
MS_LOG(ERROR) << err_msg.str();
|
||||||
|
return Status(kLiteNullptr, err_msg.str());
|
||||||
}
|
}
|
||||||
impl_->SetContext(model_context);
|
impl_->SetContext(model_context);
|
||||||
impl_->SetGraph(graph.GetGraph());
|
impl_->SetGraph(graph.GetGraph());
|
||||||
|
impl_->SetConfig(train_cfg);
|
||||||
return impl_->Build();
|
return impl_->Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,10 +97,12 @@ bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type)
|
||||||
|
|
||||||
std::vector<MSTensor> Model::GetInputs() {
|
std::vector<MSTensor> Model::GetInputs() {
|
||||||
std::vector<MSTensor> empty;
|
std::vector<MSTensor> empty;
|
||||||
|
// std::cout << "Model::GetInputs " << std::endl;
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return empty;
|
return empty;
|
||||||
}
|
}
|
||||||
|
// std::cout << "Model2::GetInputs " << std::endl;
|
||||||
return impl_->GetInputs();
|
return impl_->GetInputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,4 +148,33 @@ std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_
|
||||||
}
|
}
|
||||||
return impl_->GetOutputsByNodeName(CharToString(node_name));
|
return impl_->GetOutputsByNodeName(CharToString(node_name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Model::SetTrainMode(bool train) {
|
||||||
|
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
|
||||||
|
MS_LOG(ERROR) << "Model is null.";
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
auto ret = (train) ? impl_->session_->Train() : impl_->session_->Eval();
|
||||||
|
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Model::GetTrainMode() const { return ((impl_ != nullptr) && (impl_->session_) && (impl_->session_->IsTrain())); }
|
||||||
|
|
||||||
|
Status Model::InitMetrics(std::vector<Metrics *> metrics) {
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
return impl_->InitMetrics(metrics);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Metrics *> Model::GetMetrics() {
|
||||||
|
std::vector<Metrics *> empty;
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
return impl_->GetMetrics();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,89 +22,29 @@
|
||||||
#include "include/lite_session.h"
|
#include "include/lite_session.h"
|
||||||
#include "include/context.h"
|
#include "include/context.h"
|
||||||
#include "src/runtime/inner_allocator.h"
|
#include "src/runtime/inner_allocator.h"
|
||||||
|
#include "src/cxx_api/converters.h"
|
||||||
#include "src/cxx_api/graph/graph_data.h"
|
#include "src/cxx_api/graph/graph_data.h"
|
||||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||||
#include "src/cxx_api/tensor_utils.h"
|
#include "src/cxx_api/tensor_utils.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "src/train/train_session.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
using mindspore::lite::RET_ERROR;
|
using mindspore::lite::RET_ERROR;
|
||||||
using mindspore::lite::RET_OK;
|
using mindspore::lite::RET_OK;
|
||||||
|
|
||||||
lite::CpuBindMode ModelImpl::GetCpuBindMode() {
|
CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto) {
|
||||||
auto affinity_mode = context_->GetThreadAffinityMode();
|
static CreateTrainSessionProto *proto_ = nullptr;
|
||||||
switch (affinity_mode) {
|
if (proto != nullptr) {
|
||||||
case 0:
|
proto_ = proto;
|
||||||
return lite::NO_BIND;
|
|
||||||
case 1:
|
|
||||||
return lite::HIGHER_CPU;
|
|
||||||
case 2:
|
|
||||||
return lite::MID_CPU;
|
|
||||||
default:
|
|
||||||
return lite::NO_BIND;
|
|
||||||
}
|
}
|
||||||
}
|
return proto_;
|
||||||
|
|
||||||
Status ModelImpl::ConverterContext(const std::shared_ptr<Context> &context, lite::Context *model_context) {
|
|
||||||
auto device_list = context->MutableDeviceInfo();
|
|
||||||
if (device_list.size() == 0) {
|
|
||||||
MS_LOG(ERROR) << "Invalid device list.";
|
|
||||||
return kLiteInputParamInvalid;
|
|
||||||
}
|
|
||||||
if (device_list.size() > 2) {
|
|
||||||
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
|
|
||||||
return kLiteInputParamInvalid;
|
|
||||||
}
|
|
||||||
|
|
||||||
model_context->thread_num_ = context->GetThreadNum();
|
|
||||||
model_context->enable_parallel_ = context->GetEnableParallel();
|
|
||||||
model_context->affinity_core_list_ = context->GetThreadAffinityCoreList();
|
|
||||||
model_context->device_list_.clear();
|
|
||||||
if (device_list[0]->GetDeviceType() != kCPU) {
|
|
||||||
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
|
|
||||||
return kLiteInputParamInvalid;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
|
|
||||||
model_context->allocator = cpu_context->GetAllocator();
|
|
||||||
if (model_context->allocator == nullptr) {
|
|
||||||
model_context->allocator = Allocator::Create();
|
|
||||||
if (model_context->allocator == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Create Allocator failed.";
|
|
||||||
return kLiteNullptr;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "Set new allocator.";
|
|
||||||
cpu_context->SetAllocator(model_context->allocator);
|
|
||||||
}
|
|
||||||
|
|
||||||
lite::CpuBindMode mode = GetCpuBindMode();
|
|
||||||
lite::DeviceInfo cpu_info = {0};
|
|
||||||
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
|
||||||
model_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
|
|
||||||
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
|
|
||||||
if (device_list.size() == 2) {
|
|
||||||
lite::DeviceInfo device_info = {0};
|
|
||||||
if (device_list[1]->GetDeviceType() == kMaliGPU) {
|
|
||||||
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
|
|
||||||
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
|
||||||
model_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
|
|
||||||
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
|
|
||||||
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
|
||||||
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
|
||||||
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
|
||||||
model_context->device_list_.push_back({lite::DT_NPU, device_info});
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Invalid device.";
|
|
||||||
return kLiteInputParamInvalid;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return kSuccess;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &ms_context) {
|
const std::shared_ptr<Context> &ms_context) {
|
||||||
lite::Context lite_context;
|
lite::Context lite_context;
|
||||||
auto status = ConverterContext(ms_context, &lite_context);
|
auto status = A2L_ConvertContext(ms_context.get(), &lite_context);
|
||||||
if (status != kSuccess) {
|
if (status != kSuccess) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -123,25 +63,38 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
|
||||||
|
|
||||||
Status ModelImpl::Build() {
|
Status ModelImpl::Build() {
|
||||||
MS_LOG(DEBUG) << "Start build model.";
|
MS_LOG(DEBUG) << "Start build model.";
|
||||||
auto model = graph_->graph_data_->lite_model();
|
if (graph_ == nullptr || graph_->graph_data_ == nullptr) {
|
||||||
if (graph_ == nullptr || graph_->graph_data_ == nullptr || model == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Invalid graph.";
|
MS_LOG(ERROR) << "Invalid graph.";
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
if (model->buf == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Lite model has been freed.";
|
|
||||||
return kLiteError;
|
|
||||||
}
|
|
||||||
if (context_ == nullptr) {
|
if (context_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
lite::Context model_context;
|
lite::Context model_context;
|
||||||
auto status = ConverterContext(context_, &model_context);
|
auto status = A2L_ConvertContext(context_.get(), &model_context);
|
||||||
if (status != kSuccess) {
|
if (status != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Failed to convert Context to Lite Context";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto create_callback = CreateTrainSessionCallbackHolder();
|
||||||
|
if (create_callback != nullptr) {
|
||||||
|
auto session = create_callback(graph_->graph_data_, cfg_, &model_context);
|
||||||
|
if (session != nullptr) {
|
||||||
|
session_ = session;
|
||||||
|
MS_LOG(DEBUG) << "Build model success.";
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto model = graph_->graph_data_->lite_model();
|
||||||
|
if (model == nullptr || model->buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Lite model has been freed.";
|
||||||
|
return kLiteError;
|
||||||
|
}
|
||||||
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
|
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
|
||||||
if (session == nullptr) {
|
if (session == nullptr) {
|
||||||
MS_LOG(ERROR) << "Allocate session failed.";
|
MS_LOG(ERROR) << "Allocate session failed.";
|
||||||
|
@ -192,6 +145,9 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac
|
||||||
auto ret = session_->RunGraph(before_call_back, after_call_back);
|
auto ret = session_->RunGraph(before_call_back, after_call_back);
|
||||||
return static_cast<StatusCode>(ret);
|
return static_cast<StatusCode>(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); }
|
||||||
|
|
||||||
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||||
if (outputs == nullptr) {
|
if (outputs == nullptr) {
|
||||||
|
@ -459,4 +415,5 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
|
||||||
auto ret = session_->Resize(inner_input, truncated_shape);
|
auto ret = session_->Resize(inner_input, truncated_shape);
|
||||||
return static_cast<StatusCode>(ret);
|
return static_cast<StatusCode>(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H
|
#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
|
||||||
#define MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H
|
#define MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
@ -28,8 +28,30 @@
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
#include "include/api/cell.h"
|
#include "include/api/cell.h"
|
||||||
#include "include/lite_session.h"
|
#include "include/lite_session.h"
|
||||||
|
#include "src/cxx_api/graph/graph_data.h"
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void clearVectorOfPointers(std::vector<T> *v) {
|
||||||
|
if (v != nullptr) {
|
||||||
|
for (typename std::vector<T>::iterator it = v->begin(); it != v->end(); ++it) {
|
||||||
|
delete (*it);
|
||||||
|
}
|
||||||
|
v->clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
|
typedef std::shared_ptr<session::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
|
||||||
|
std::shared_ptr<TrainCfg> cfg,
|
||||||
|
lite::Context *context);
|
||||||
|
CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
|
||||||
|
|
||||||
|
namespace session {
|
||||||
|
class Metrics;
|
||||||
|
class TrainLoopCallBack;
|
||||||
|
} // namespace session
|
||||||
|
|
||||||
class ModelImpl {
|
class ModelImpl {
|
||||||
public:
|
public:
|
||||||
ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
|
ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
|
||||||
|
@ -51,18 +73,36 @@ class ModelImpl {
|
||||||
std::vector<MSTensor> GetOutputsByNodeName(const std::string &name);
|
std::vector<MSTensor> GetOutputsByNodeName(const std::string &name);
|
||||||
|
|
||||||
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||||
|
bool IsTrainModel();
|
||||||
|
|
||||||
|
Status InitMetrics(std::vector<Metrics *> metrics) {
|
||||||
|
metrics_ = metrics;
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
std::vector<Metrics *> GetMetrics() { return metrics_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Utility methods
|
||||||
|
Status ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
|
||||||
|
std::vector<session::TrainLoopCallBack *> *o_cbs,
|
||||||
|
std::vector<session::TrainLoopCallBack *> *adapter_cbs);
|
||||||
|
Status PrepareMetrics(Model *model, std::vector<session::Metrics *> *o_ms,
|
||||||
|
std::vector<session::Metrics *> *adapter_ms);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Model;
|
friend class Model;
|
||||||
std::shared_ptr<Graph> graph_;
|
friend class Serialization;
|
||||||
std::shared_ptr<session::LiteSession> session_;
|
std::shared_ptr<Graph> graph_ = nullptr;
|
||||||
std::shared_ptr<Context> context_;
|
std::shared_ptr<session::LiteSession> session_ = nullptr;
|
||||||
|
std::shared_ptr<Context> context_ = nullptr;
|
||||||
|
std::shared_ptr<TrainCfg> cfg_ = nullptr;
|
||||||
|
std::vector<Metrics *> metrics_;
|
||||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||||
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
|
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
|
||||||
|
void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
|
||||||
lite::CpuBindMode GetCpuBindMode();
|
lite::CpuBindMode GetCpuBindMode();
|
||||||
Status ConverterContext(const std::shared_ptr<Context> &context, lite::Context *model_context);
|
|
||||||
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
|
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H
|
#endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
|
||||||
|
|
|
@ -21,10 +21,29 @@
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/model.h"
|
#include "include/model.h"
|
||||||
#include "src/cxx_api/graph/graph_data.h"
|
#include "src/cxx_api/graph/graph_data.h"
|
||||||
|
#include "src/cxx_api/model/model_impl.h"
|
||||||
|
#include "src/cxx_api/converters.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
|
Key::Key(const char *dec_key, size_t key_len) {
|
||||||
|
len = 0;
|
||||||
|
if (key_len >= max_key_len) {
|
||||||
|
MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(key, dec_key, key_len);
|
||||||
|
len = key_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||||
|
const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||||
|
if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return kLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
if (model_data == nullptr) {
|
if (model_data == nullptr) {
|
||||||
MS_LOG(ERROR) << "model data is nullptr.";
|
MS_LOG(ERROR) << "model data is nullptr.";
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
|
@ -37,6 +56,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
||||||
MS_LOG(ERROR) << "Unsupported IR.";
|
MS_LOG(ERROR) << "Unsupported IR.";
|
||||||
return kLiteInputParamInvalid;
|
return kLiteInputParamInvalid;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
|
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "New model failed.";
|
MS_LOG(ERROR) << "New model failed.";
|
||||||
|
@ -51,26 +71,45 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
|
||||||
const Key &dec_key, const std::vector<char> &dec_mode) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
|
||||||
return kLiteError;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
|
||||||
return kLiteError;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||||
const std::vector<char> &dec_mode) {
|
const std::vector<char> &dec_mode) {
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
|
||||||
return kLiteError;
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return kLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "graph is nullptr.";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
if (model_type != kFlatBuffer) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported IR.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string filename(file.data(), file.size());
|
||||||
|
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
|
||||||
|
filename = filename + ".ms";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
|
||||||
|
if (model == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New model failed.";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
|
||||||
|
if (graph_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New graph data failed.";
|
||||||
|
return kLiteMemoryFailed;
|
||||||
|
}
|
||||||
|
*graph = Graph(graph_data);
|
||||||
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||||
return kMEFailed;
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return kLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model) {
|
Status Serialization::SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model) {
|
||||||
|
@ -83,8 +122,23 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) {
|
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
QuantizationType quantization_type, bool export_inference_only) {
|
||||||
return kMEFailed;
|
if (model.impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
if (!model.impl_->IsTrainModel()) {
|
||||||
|
MS_LOG(ERROR) << "Model is not TrainModel.";
|
||||||
|
return kLiteError;
|
||||||
|
}
|
||||||
|
if (model_type != kFlatBuffer) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
|
||||||
|
return kLiteParamInvalid;
|
||||||
|
}
|
||||||
|
auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
|
||||||
|
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS);
|
||||||
|
|
||||||
|
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/cxx_api/converters.h"
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/train/train_cfg.h"
|
||||||
|
#include "include/api/cfg.h"
|
||||||
|
#include "src/runtime/inner_allocator.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg) {
|
||||||
|
if ((a_train_cfg == nullptr) || (l_train_cfg == nullptr)) {
|
||||||
|
MS_LOG(ERROR) << "Invalid train_cfg pointers";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
l_train_cfg->loss_name_ = a_train_cfg->loss_name_;
|
||||||
|
l_train_cfg->mix_precision_cfg_.dynamic_loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
|
||||||
|
l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
|
||||||
|
l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3);
|
||||||
|
l_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_ = a_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_;
|
||||||
|
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,109 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/types.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
#include "src/cxx_api/model/model_impl.h"
|
||||||
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
|
#include "src/cxx_api/callback/callback_adapter.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "include/train/train_loop.h"
|
||||||
|
#include "include/train/train_loop_callback.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) {
|
||||||
|
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));
|
||||||
|
if (loop == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Error during allocation of train loop";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert Metrics to MS Lite and init loop
|
||||||
|
std::vector<session::Metrics *> metrics;
|
||||||
|
std::vector<session::Metrics *> adapter_metrics;
|
||||||
|
auto status = impl_->PrepareMetrics(this, &metrics, &adapter_metrics);
|
||||||
|
if (status != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Error during preparation of metrics";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
loop->Init(metrics);
|
||||||
|
|
||||||
|
// Convert Callbacks to be used by loop
|
||||||
|
std::vector<session::TrainLoopCallBack *> cbs;
|
||||||
|
std::vector<session::TrainLoopCallBack *> adapter_cbs;
|
||||||
|
status = impl_->ConvertCallbacks(this, &i_cbs, &cbs, &adapter_cbs);
|
||||||
|
if (status != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Error during preparation of callbacks";
|
||||||
|
clearVectorOfPointers(&adapter_metrics);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = loop->Train(epochs, ds.get(), cbs);
|
||||||
|
|
||||||
|
clearVectorOfPointers(&adapter_metrics);
|
||||||
|
clearVectorOfPointers(&adapter_cbs);
|
||||||
|
|
||||||
|
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) {
|
||||||
|
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));
|
||||||
|
if (loop == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Error during allocation of train loop";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert Metrics to MS Lite and init loop
|
||||||
|
std::vector<session::Metrics *> metrics;
|
||||||
|
std::vector<session::Metrics *> adapter_metrics;
|
||||||
|
auto status = impl_->PrepareMetrics(this, &metrics, &adapter_metrics);
|
||||||
|
if (status != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Error during preparation of metrics";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
loop->Init(metrics);
|
||||||
|
|
||||||
|
// Convert Callbacks to be used by loop
|
||||||
|
std::vector<session::TrainLoopCallBack *> cbs;
|
||||||
|
std::vector<session::TrainLoopCallBack *> adapter_cbs;
|
||||||
|
status = impl_->ConvertCallbacks(this, &i_cbs, &cbs, &adapter_cbs);
|
||||||
|
if (status != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Error during preparation of callbacks";
|
||||||
|
clearVectorOfPointers(&adapter_metrics);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = loop->Eval(ds.get(), cbs);
|
||||||
|
|
||||||
|
clearVectorOfPointers(&adapter_metrics);
|
||||||
|
clearVectorOfPointers(&adapter_cbs);
|
||||||
|
|
||||||
|
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,161 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/cxx_api/model/model_impl.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "include/api/types.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
#include "include/lite_session.h"
|
||||||
|
#include "include/context.h"
|
||||||
|
#include "include/api/callback/callback.h"
|
||||||
|
#include "include/api/metrics/metrics.h"
|
||||||
|
#include "src/lite_model.h"
|
||||||
|
#include "src/runtime/inner_allocator.h"
|
||||||
|
#include "src/common/string_util.h"
|
||||||
|
#include "src/cxx_api/converters.h"
|
||||||
|
#include "src/cxx_api/graph/graph_data.h"
|
||||||
|
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||||
|
#include "src/cxx_api/tensor_utils.h"
|
||||||
|
#include "src/cxx_api/metrics/metrics_adapter.h"
|
||||||
|
#include "src/cxx_api/metrics/metrics_impl.h"
|
||||||
|
#include "src/cxx_api/callback/callback_adapter.h"
|
||||||
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "src/train/train_session.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
using mindspore::lite::RET_ERROR;
|
||||||
|
using mindspore::lite::RET_OK;
|
||||||
|
|
||||||
|
std::shared_ptr<session::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
|
||||||
|
std::shared_ptr<TrainCfg> cfg, lite::Context *context) {
|
||||||
|
bool is_train_session = graph_data->IsTrainModel();
|
||||||
|
|
||||||
|
if (is_train_session) {
|
||||||
|
auto model = graph_data->lite_model();
|
||||||
|
if (model == nullptr || model->buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Lite model has been freed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::shared_ptr<session::LiteSession> shared_session;
|
||||||
|
lite::TrainSession *session = new lite::TrainSession();
|
||||||
|
if (session == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "create session failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
shared_session.reset(session);
|
||||||
|
|
||||||
|
lite::TrainCfg train_cfg;
|
||||||
|
if (cfg != nullptr) {
|
||||||
|
auto status = A2L_ConvertConfig(cfg.get(), &train_cfg);
|
||||||
|
if (status != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Failed to convert Config to Lite Config";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = session->Init(context, &train_cfg);
|
||||||
|
if (ret != mindspore::lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "init session failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = session->CompileTrainGraph(model);
|
||||||
|
if (ret != mindspore::lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Compiling Train Graph session failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return shared_session;
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Session is not a train session.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
class UnifiedAPISupportTrain {
|
||||||
|
public:
|
||||||
|
UnifiedAPISupportTrain() { CreateTrainSessionCallbackHolder(CreateTrainSession); }
|
||||||
|
};
|
||||||
|
|
||||||
|
UnifiedAPISupportTrain support_train_api;
|
||||||
|
|
||||||
|
Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
|
||||||
|
std::vector<session::Metrics *> *adapter_ms) {
|
||||||
|
if (out_ms == nullptr || adapter_ms == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Null input callbacks";
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
auto model_metrics = GetMetrics();
|
||||||
|
for (auto m : model_metrics) {
|
||||||
|
if (m->metrics_impl_) {
|
||||||
|
// For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation
|
||||||
|
auto internal_m = m->metrics_impl_->GetInternalMetrics();
|
||||||
|
if (internal_m == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Internal metric is null.";
|
||||||
|
clearVectorOfPointers(adapter_ms);
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
out_ms->push_back(internal_m);
|
||||||
|
} else {
|
||||||
|
// For custom metric we use the metric adapter to mediate between MSLite level to API level
|
||||||
|
auto adapter_m = new MetricsAdapter(m);
|
||||||
|
if (adapter_m == nullptr) { // Error during allocation
|
||||||
|
MS_LOG(ERROR) << "Error during allocation";
|
||||||
|
clearVectorOfPointers(adapter_ms);
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
out_ms->push_back(adapter_m);
|
||||||
|
adapter_ms->push_back(adapter_m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
|
||||||
|
std::vector<session::TrainLoopCallBack *> *o_cbs,
|
||||||
|
std::vector<session::TrainLoopCallBack *> *adapter_cbs) {
|
||||||
|
if (i_cbs == nullptr || o_cbs == nullptr || adapter_cbs == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Null input callbacks";
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
for (auto cb : *i_cbs) {
|
||||||
|
if (cb->callback_impl_) {
|
||||||
|
// For off-the-shelf callback it is guaranteed that we have also an MSLite implementation
|
||||||
|
auto internal_cb = cb->callback_impl_->GetInternalCallback();
|
||||||
|
if (internal_cb == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Internal callback is null";
|
||||||
|
clearVectorOfPointers(adapter_cbs);
|
||||||
|
return kLiteUninitializedObj;
|
||||||
|
}
|
||||||
|
o_cbs->push_back(internal_cb);
|
||||||
|
} else {
|
||||||
|
// For custom callbacks we use the callback adapter to mediate between MSLite level to API level
|
||||||
|
auto adapter_cb = new TrainLoopCallBackAdapter(model, cb);
|
||||||
|
if (adapter_cb == nullptr) { // Error during allocation
|
||||||
|
MS_LOG(ERROR) << "Error during allocation";
|
||||||
|
clearVectorOfPointers(adapter_cbs);
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
o_cbs->push_back(adapter_cb);
|
||||||
|
adapter_cbs->push_back(adapter_cb);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -32,9 +32,7 @@ using session::RET_CONTINUE;
|
||||||
using session::RET_EXIT;
|
using session::RET_EXIT;
|
||||||
using session::RET_STOP_TRAINING;
|
using session::RET_STOP_TRAINING;
|
||||||
|
|
||||||
TrainLoop::~TrainLoop() {
|
TrainLoop::~TrainLoop() {}
|
||||||
if (train_session_ != nullptr) delete train_session_;
|
|
||||||
}
|
|
||||||
|
|
||||||
int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
|
int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
|
||||||
train_session_->Train();
|
train_session_->Train();
|
||||||
|
|
|
@ -129,7 +129,7 @@ int TrainSession::InitCallBack() {
|
||||||
auto in_size = node->input_indices_.size();
|
auto in_size = node->input_indices_.size();
|
||||||
bool force_fp16 = false;
|
bool force_fp16 = false;
|
||||||
for (std::size_t k = 0; k < in_size; k++) {
|
for (std::size_t k = 0; k < in_size; k++) {
|
||||||
schema::Tensor *tensor = model_->all_tensors_.at(node->input_indices_[k]);
|
schema::Tensor *tensor = model_.get()->all_tensors_.at(node->input_indices_[k]);
|
||||||
if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
|
if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
|
||||||
force_fp16 = true;
|
force_fp16 = true;
|
||||||
break;
|
break;
|
||||||
|
@ -161,7 +161,7 @@ int TrainSession::InitCallBack() {
|
||||||
|
|
||||||
int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }
|
int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }
|
||||||
|
|
||||||
int TrainSession::CompileTrainGraph(mindspore::lite::Model *model) {
|
int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
|
||||||
model_ = model;
|
model_ = model;
|
||||||
auto restore = ReplaceOps();
|
auto restore = ReplaceOps();
|
||||||
sched_cb_ = std::make_unique<SchedulerCb>(sched_mix_precision_callback_);
|
sched_cb_ = std::make_unique<SchedulerCb>(sched_mix_precision_callback_);
|
||||||
|
@ -170,7 +170,7 @@ int TrainSession::CompileTrainGraph(mindspore::lite::Model *model) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ret = lite::LiteSession::CompileGraph(model);
|
auto ret = lite::LiteSession::CompileGraph(model_.get());
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "failed to compile train model";
|
MS_LOG(ERROR) << "failed to compile train model";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -193,13 +193,7 @@ int TrainSession::CompileTrainGraph(mindspore::lite::Model *model) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
TrainSession::~TrainSession() {
|
TrainSession::~TrainSession() { FreeWorkSpace(); }
|
||||||
FreeWorkSpace();
|
|
||||||
if (model_ != nullptr) {
|
|
||||||
delete model_;
|
|
||||||
model_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
|
int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
|
||||||
const std::vector<kernel::LiteKernel *> &run_kernels) {
|
const std::vector<kernel::LiteKernel *> &run_kernels) {
|
||||||
|
@ -690,14 +684,14 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
|
||||||
bool orig_train_state = IsTrain();
|
bool orig_train_state = IsTrain();
|
||||||
Eval();
|
Eval();
|
||||||
TrainExport texport(file_name);
|
TrainExport texport(file_name);
|
||||||
int status = texport.ExportInit(model_->name_, model_->version_);
|
int status = texport.ExportInit(model_.get()->name_, model_.get()->version_);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "cannot init export";
|
MS_LOG(ERROR) << "cannot init export";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
|
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
|
||||||
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, model_,
|
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
|
||||||
quant_type);
|
model_.get(), quant_type);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "cannot export Network";
|
MS_LOG(ERROR) << "cannot export Network";
|
||||||
return status;
|
return status;
|
||||||
|
@ -766,7 +760,7 @@ session::LiteSession *session::LiteSession::CreateTrainSession(const std::string
|
||||||
filename = filename + ".ms";
|
filename = filename + ".ms";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto *model = mindspore::lite::Model::Import(filename.c_str());
|
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "create model for train session failed " << filename;
|
MS_LOG(ERROR) << "create model for train session failed " << filename;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -51,7 +51,7 @@ class TrainSession : virtual public lite::LiteSession {
|
||||||
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
||||||
|
|
||||||
int CompileGraph(lite::Model *model) override;
|
int CompileGraph(lite::Model *model) override;
|
||||||
virtual int CompileTrainGraph(lite::Model *model);
|
virtual int CompileTrainGraph(std::shared_ptr<Model> model);
|
||||||
|
|
||||||
virtual int Init(const Context *context, const TrainCfg *train_cfg);
|
virtual int Init(const Context *context, const TrainCfg *train_cfg);
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ class TrainSession : virtual public lite::LiteSession {
|
||||||
virtual void CompileTrainOutputs();
|
virtual void CompileTrainOutputs();
|
||||||
virtual void CompileEvalOutputs();
|
virtual void CompileEvalOutputs();
|
||||||
virtual int InitCallBack();
|
virtual int InitCallBack();
|
||||||
Model *model_ = nullptr;
|
std::shared_ptr<Model> model_ = nullptr;
|
||||||
// TrainCfg train_cfg_;
|
// TrainCfg train_cfg_;
|
||||||
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_node_map_;
|
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_node_map_;
|
||||||
std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
|
std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
|
||||||
|
|
|
@ -207,7 +207,7 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_, quant_type);
|
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "cannot serialize head";
|
MS_LOG(ERROR) << "cannot serialize head";
|
||||||
return status;
|
return status;
|
||||||
|
@ -257,7 +257,7 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto model = lite::Model::Import(model_buf_head, size_head);
|
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(model_buf_head, size_head));
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "create model for head train session failed";
|
MS_LOG(ERROR) << "create model for head train session failed";
|
||||||
delete session;
|
delete session;
|
||||||
|
|
|
@ -60,9 +60,18 @@ echo 'run common ut tests'
|
||||||
# test cases specific for train
|
# test cases specific for train
|
||||||
|
|
||||||
echo 'run train ut tests'
|
echo 'run train ut tests'
|
||||||
## ./lite-test --gtest_filter=NetworkTest.efficient_net
|
# ./lite-test --gtest_filter="TestConvolutionGradFp32*"
|
||||||
## ./lite-test --gtest_filter="NetworkTest.tuning_layer"
|
# ./lite-test --gtest_filter="TestActGradFp32*"
|
||||||
## ./lite-test --gtest_filter="NetworkTest.lenetnet"
|
# ./lite-test --gtest_filter="TestSoftmaxGradFp32*"
|
||||||
|
# ./lite-test --gtest_filter="TestSoftmaxCrossEntropyFp32*"
|
||||||
|
# ./lite-test --gtest_filter="TestDeConvolutionGradFp32*"
|
||||||
|
# ./lite-test --gtest_filter="TestBiasGradFp32*"
|
||||||
|
|
||||||
|
# test cases specific for CXX_API
|
||||||
|
|
||||||
|
# ./lite-test --gtest_filter="TestCxxApiLiteModel*"
|
||||||
|
# ./lite-test --gtest_filter="TestCxxApiLiteSerialization*"
|
||||||
|
|
||||||
echo 'run inference ut tests'
|
echo 'run inference ut tests'
|
||||||
./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel"
|
./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel"
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,120 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <memory>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/serialization.h"
|
||||||
|
#include "include/api/metrics/accuracy.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestCxxApiLiteModel : public mindspore::CommonTest {
|
||||||
|
public:
|
||||||
|
TestCxxApiLiteModel() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_build_context_uninitialized_FAILED) {
|
||||||
|
Model model;
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
ASSERT_TRUE(Serialization::Load("./test_data/nets/conv_train_model.ms", ModelType::kFlatBuffer, &graph) == kSuccess);
|
||||||
|
auto status = model.Build(GraphCell(graph), nullptr, nullptr);
|
||||||
|
ASSERT_TRUE(status != kSuccess);
|
||||||
|
auto err_mst = status.GetErrDescription();
|
||||||
|
ASSERT_TRUE(err_mst.find("null") != std::string::npos);
|
||||||
|
}
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_build_graph_uninitialized_FAILED) {
|
||||||
|
Model model;
|
||||||
|
GraphCell graph_cell;
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
context->MutableDeviceInfo().push_back(cpu_context);
|
||||||
|
|
||||||
|
ASSERT_TRUE(model.Build(graph_cell, context, nullptr) != kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_build_SUCCES) {
|
||||||
|
Model model;
|
||||||
|
Graph graph;
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
context->MutableDeviceInfo().push_back(cpu_context);
|
||||||
|
|
||||||
|
ASSERT_TRUE(Serialization::Load("./test_data/nets/conv_train_model.ms", ModelType::kFlatBuffer, &graph) == kSuccess);
|
||||||
|
ASSERT_TRUE(model.Build(GraphCell(graph), context, nullptr) == kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_train_mode_FAILURE) {
|
||||||
|
Model model;
|
||||||
|
ASSERT_TRUE(model.SetTrainMode(true) != kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_train_mode_SUCCES) {
|
||||||
|
Model model;
|
||||||
|
Graph graph;
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
context->MutableDeviceInfo().push_back(cpu_context);
|
||||||
|
|
||||||
|
ASSERT_TRUE(Serialization::Load("./test_data/nets/conv_train_model.ms", ModelType::kFlatBuffer, &graph) == kSuccess);
|
||||||
|
ASSERT_TRUE(model.Build(GraphCell(graph), context, nullptr) == kSuccess);
|
||||||
|
ASSERT_TRUE(model.SetTrainMode(true) == kSuccess);
|
||||||
|
ASSERT_TRUE(model.GetTrainMode() == true);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_outputs_FAILURE) {
|
||||||
|
Model model;
|
||||||
|
auto outputs = model.GetOutputs();
|
||||||
|
ASSERT_EQ(outputs.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_outputs_SUCCESS) {
|
||||||
|
Model model;
|
||||||
|
Graph graph;
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
context->MutableDeviceInfo().push_back(cpu_context);
|
||||||
|
|
||||||
|
ASSERT_TRUE(Serialization::Load("./test_data/nets/conv_train_model.ms", ModelType::kFlatBuffer, &graph) == kSuccess);
|
||||||
|
ASSERT_TRUE(model.Build(GraphCell(graph), context, nullptr) == kSuccess);
|
||||||
|
auto outputs = model.GetOutputs();
|
||||||
|
ASSERT_GT(outputs.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_metrics_FAILURE) {
|
||||||
|
Model model;
|
||||||
|
AccuracyMetrics ac;
|
||||||
|
ASSERT_TRUE(model.InitMetrics({&ac}) != kSuccess);
|
||||||
|
auto metrics = model.GetMetrics();
|
||||||
|
ASSERT_EQ(metrics.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteModel, test_metrics_SUCCESS) {
|
||||||
|
Model model;
|
||||||
|
Graph graph;
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
context->MutableDeviceInfo().push_back(cpu_context);
|
||||||
|
|
||||||
|
ASSERT_TRUE(Serialization::Load("./test_data/nets/conv_train_model.ms", ModelType::kFlatBuffer, &graph) == kSuccess);
|
||||||
|
ASSERT_TRUE(model.Build(GraphCell(graph), context, nullptr) == kSuccess);
|
||||||
|
AccuracyMetrics ac;
|
||||||
|
ASSERT_TRUE(model.InitMetrics({&ac}) == kSuccess);
|
||||||
|
auto metrics = model.GetMetrics();
|
||||||
|
ASSERT_EQ(metrics.size(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <memory>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "include/api/serialization.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestCxxApiLiteSerialization : public mindspore::CommonTest {
|
||||||
|
public:
|
||||||
|
TestCxxApiLiteSerialization() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteSerialization, test_load_no_encrpty_mindir_SUCCESS) {
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load("./test_data/nets/retinaface1.ms", ModelType::kFlatBuffer, &graph) == kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_FAILED) {
|
||||||
|
Graph graph;
|
||||||
|
auto status = Serialization::Load("./test_data/nets/file_not_exist.mindir", ModelType::kMindIR, &graph);
|
||||||
|
ASSERT_TRUE(status != kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_x2_FAILED) {
|
||||||
|
std::vector<Graph> graphs;
|
||||||
|
auto status = Serialization::Load(std::vector<std::string>(2, "./data/mindir/file_not_exist.mindir"),
|
||||||
|
ModelType::kMindIR, &graphs);
|
||||||
|
ASSERT_TRUE(status != kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) {
|
||||||
|
Model model;
|
||||||
|
ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kFlatBuffer, "./test_data/nets/export.ms") != kSuccess);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -47,10 +47,10 @@ TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_SUCCESS) {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
std::string key_str = "0123456789ABCDEF";
|
std::string key_str = "0123456789ABCDEF";
|
||||||
Key key;
|
Key key;
|
||||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
memcpy(key.key, key_str.data(), key_str.size());
|
||||||
key.len = key_str.size();
|
key.len = key_str.size();
|
||||||
ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
||||||
key, "AES-GCM") == kSuccess);
|
key, kDecModeAesGcm) == kSuccess);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
|
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
|
||||||
|
@ -66,10 +66,10 @@ TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_with_wrong_key_FAILED)
|
||||||
Graph graph;
|
Graph graph;
|
||||||
std::string key_str = "WRONG_KEY";
|
std::string key_str = "WRONG_KEY";
|
||||||
Key key;
|
Key key;
|
||||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
memcpy(key.key, key_str.data(), key_str.size());
|
||||||
key.len = key_str.size();
|
key.len = key_str.size();
|
||||||
auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
||||||
key, "AES-GCM");
|
key, kDecModeAesGcm);
|
||||||
ASSERT_TRUE(status != kSuccess);
|
ASSERT_TRUE(status != kSuccess);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,9 +77,10 @@ TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_with_wrong_key_FAILE
|
||||||
Graph graph;
|
Graph graph;
|
||||||
std::string key_str = "WRONG_KEY";
|
std::string key_str = "WRONG_KEY";
|
||||||
Key key;
|
Key key;
|
||||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
memcpy(key.key, key_str.data(), key_str.size());
|
||||||
key.len = key_str.size();
|
key.len = key_str.size();
|
||||||
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, key, "AES-GCM");
|
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph,
|
||||||
|
key, kDecModeAesGcm);
|
||||||
ASSERT_TRUE(status != kSuccess);
|
ASSERT_TRUE(status != kSuccess);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue