modify train api

This commit is contained in:
jianghui58 2022-05-20 17:32:24 +08:00
parent df9c0d2ff5
commit ba796f95de
4 changed files with 68 additions and 20 deletions

View File

@ -34,7 +34,12 @@ class MixPrecisionCfg {
this->keep_batchnorm_fp32_ = true;
this->num_of_not_nan_iter_th_ = iter_th;
}
MixPrecisionCfg(const MixPrecisionCfg &rhs) {
this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
this->loss_scale_ = rhs.loss_scale_;
this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
}
~MixPrecisionCfg() = default;
bool dynamic_loss_scale_ = false; /**< Enable/disable dynamic loss scale during mix precision training */
@ -46,13 +51,18 @@ class MixPrecisionCfg {
/// \brief Training configuration: loss-kernel name filters, optimization level,
/// mixed-precision settings and the gradient-accumulation switch.
///
/// NOTE(review): the source contained both the pre- and post-change lines of a
/// diff (two TrainCfg() definitions and two loss_name_ members), which does not
/// compile; the post-change versions are kept here. The hand-written copy
/// constructor also skipped optimization_level_, so copies silently reverted to
/// kO0 — the defaulted copy constructor performs the full memberwise copy.
class TrainCfg {
 public:
  TrainCfg() = default;
  TrainCfg(const TrainCfg &rhs) = default;
  ~TrainCfg() = default;

  OptimizationLevel optimization_level_ = kO0;
  std::vector<std::string> loss_name_ = {"loss_fct",
                                         "_loss_fn"}; /**< Set part of the name that identify a loss kernel */
  MixPrecisionCfg mix_precision_cfg_;                 /**< Mix precision configuration */
  bool accumulate_gradients_ = false; /**< presumably accumulates gradients across steps — confirm with impl */
};
} // namespace mindspore

View File

@ -233,6 +233,7 @@ class MS_API Model {
/// \brief update gradient tensors of the model.
///
/// \param[in] gradients A vector of new gradient tensors.
///
/// \return Status of operation
Status ApplyGradients(const std::vector<MSTensor> &gradients);
@ -244,6 +245,7 @@ class MS_API Model {
/// \brief update weights tensors of the model.
///
/// \param[in] new_weights A vector of new weight tensors.
///
/// \return Status of operation
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
@ -252,32 +254,44 @@ class MS_API Model {
/// \return The vector that includes all params tensors.
std::vector<MSTensor> GetOptimizerParams() const;
/// \brief update the optimizer parameters
/// \brief update the optimizer parameters.
///
/// \param[in] params A vector of new optimizer parameters.
/// \return Status of operation
///
/// \return Status of operation.
Status SetOptimizerParams(const std::vector<MSTensor> &params);
/// \brief Setup training with virtual batches
/// \brief Setup training with virtual batches.
///
/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable
/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration
/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration
/// \return Status of operation
/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable.
/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration.
/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration.
///
/// \return Status of operation.
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);
/// \brief Sets the Learning Rate of the training
/// \brief Sets the Learning Rate of the training.
///
/// \param[in] learning_rate to set
/// \return Status of operation
/// \param[in] learning_rate to set.
///
/// \return Status of operation.
Status SetLearningRate(float learning_rate);
/// \brief Gets the Learning Rate of the optimizer
/// \brief Gets the Learning Rate of the optimizer.
///
/// \return learning rate. 0.0 if no optimizer was found
/// \return Learning rate. 0.0 if no optimizer was found.
float GetLearningRate();
/// \brief Initialize object with metrics.
///
/// \param[in] metrics A vector of metrics objects.
///
/// \return Status of operation (the declaration returns Status, not an int code).
Status InitMetrics(std::vector<Metrics *> metrics);
/// \brief Accessor to TrainLoop metric objects
///
/// \return A vector of metrics
std::vector<Metrics *> GetMetrics();
/// \brief Obtains all output tensors of the model.
@ -322,9 +336,33 @@ class MS_API Model {
/// \return Is supported or not.
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
/// \brief Set the model running mode.
///
/// \param[in] train True means model runs in Train Mode, otherwise Eval Mode.
///
/// \return Status of operation.
Status SetTrainMode(bool train);
/// \brief Get the model running mode.
///
/// \return Is Train Mode or not.
bool GetTrainMode() const;
/// \brief Performs the training Loop in Train Mode.
///
/// \param[in] epochs The number of epoch to run.
/// \param[in] ds A smart pointer to a MindData Dataset object.
/// \param[in] cbs A vector of TrainLoopCallBack objects.
///
/// \return Status of operation.
Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
/// \brief Performs the training loop over all data in Eval Mode.
///
/// \param[in] ds A smart pointer to a MindData Dataset object.
/// \param[in] cbs A vector of TrainLoopCallBack objects.
///
/// \return Status of operation.
Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
private:

View File

@ -79,7 +79,7 @@ class TrainLoop {
/// \brief Performs the training Loop
///
/// \param[in] epoch The number of epochs to run
/// \param[in] epochs The number of epochs to run
/// \param[in] dataset Pointer to MindData Dataset object
/// \param[in] cbs A vector of TrainLoopCallBack objects
/// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model

View File

@ -81,7 +81,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<TrainLoopCallBack *> c
return RET_ERROR;
}
}
int break_loop = false;
bool break_loop = false;
for (auto cb : cbs) {
ret = cb->EpochEnd(cb_data);
if (ret != RET_CONTINUE) {
@ -166,7 +166,7 @@ int TrainLoop::LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTenso
return RET_STOP_TRAINING;
}
for (unsigned int i = 0; i < num_of_inputs; i++) {
for (size_t i = 0; i < num_of_inputs; i++) {
auto *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
const auto *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
auto data_size = row_vec->at(i).DataSize();