forked from mindspore-Ecosystem/mindspore
modify train api
This commit is contained in:
parent df9c0d2ff5
commit ba796f95de
@@ -34,7 +34,12 @@ class MixPrecisionCfg {
     this->keep_batchnorm_fp32_ = true;
     this->num_of_not_nan_iter_th_ = iter_th;
   }
+  MixPrecisionCfg(const MixPrecisionCfg &rhs) {
+    this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
+    this->loss_scale_ = rhs.loss_scale_;
+    this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
+    this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
+  }
   ~MixPrecisionCfg() = default;
 
   bool dynamic_loss_scale_ = false; /**< Enable/disable dynamic loss scale during mix precision training */
@@ -46,13 +51,18 @@ class MixPrecisionCfg {
 
 class TrainCfg {
  public:
-  TrainCfg() { this->loss_name_.emplace_back("_loss_fn"); }
+  TrainCfg() = default;
+  TrainCfg(const TrainCfg &rhs) {
+    this->loss_name_ = rhs.loss_name_;
+    this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
+    this->accumulate_gradients_ = rhs.accumulate_gradients_;
+  }
   ~TrainCfg() = default;
 
   OptimizationLevel optimization_level_ = kO0;
-  std::vector<std::string> loss_name_; /**< Set part of the name that identify a loss kernel */
-  MixPrecisionCfg mix_precision_cfg_;  /**< Mix precision configuration */
+  std::vector<std::string> loss_name_ = {"loss_fct",
+                                         "_loss_fn"}; /**< Set part of the name that identify a loss kernel */
+  MixPrecisionCfg mix_precision_cfg_;                 /**< Mix precision configuration */
+  bool accumulate_gradients_ = false;
 };
 }  // namespace mindspore
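Taken together, the two hunks above give both configuration structs copy constructors, replace TrainCfg's emplace-based constructor with in-class defaults, and add an accumulate_gradients_ flag. A minimal usage sketch under those definitions; the header path and the field values below are illustrative assumptions, not taken from this commit:

#include "include/api/cfg.h"  // assumed header path for TrainCfg / MixPrecisionCfg

// Sketch only: the settings below are illustrative.
mindspore::TrainCfg train_cfg;                             // loss_name_ now defaults to {"loss_fct", "_loss_fn"}
train_cfg.accumulate_gradients_ = true;                    // flag introduced by this commit
train_cfg.mix_precision_cfg_.dynamic_loss_scale_ = true;   // enable dynamic loss scaling
train_cfg.mix_precision_cfg_.keep_batchnorm_fp32_ = true;  // keep batch-norm layers in FP32

mindspore::TrainCfg cfg_copy(train_cfg);                   // exercises the new copy constructor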
@@ -233,6 +233,7 @@ class MS_API Model {
   /// \brief update gradient tensors of the model.
+  ///
   /// \param[in] gradients A vector of new gradients.
   ///
   /// \return Status of operation
   Status ApplyGradients(const std::vector<MSTensor> &gradients);
 
@@ -244,6 +245,7 @@ class MS_API Model {
   /// \brief update weights tensors of the model.
+  ///
   /// \param[in] new_weights A vector of new weights.
   ///
   /// \return Status of operation
   Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
 
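Both ApplyGradients and UpdateFeatureMaps take a std::vector<MSTensor> and push externally computed values back into the model, the pattern used when gradients or weights are produced outside the current process (for example on other workers). A hedged sketch of that flow; the helper name, the kSuccess comparison, and the idea that the tensors arrive from elsewhere are assumptions — only the two Model calls come from this diff:

#include <vector>
#include "include/api/model.h"  // assumed header path for mindspore::Model

// Sketch only: remote_grads / averaged_weights are assumed to be produced elsewhere.
mindspore::Status PushRemoteUpdate(mindspore::Model *model,
                                   const std::vector<mindspore::MSTensor> &remote_grads,
                                   const std::vector<mindspore::MSTensor> &averaged_weights) {
  auto status = model->ApplyGradients(remote_grads);      // update the model's gradient tensors
  if (status != mindspore::kSuccess) {
    return status;
  }
  return model->UpdateFeatureMaps(averaged_weights);      // update the model's weight tensors
}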
@@ -252,32 +254,44 @@ class MS_API Model {
   /// \return The vector that includes all params tensors.
   std::vector<MSTensor> GetOptimizerParams() const;
 
-  /// \brief update the optimizer parameters
+  /// \brief update the optimizer parameters.
+  ///
   /// \param[in] params A vector of new optimizer params.
-  /// \return Status of operation
+  ///
+  /// \return Status of operation.
   Status SetOptimizerParams(const std::vector<MSTensor> &params);
 
-  /// \brief Setup training with virtual batches
-  /// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable
-  /// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration
-  /// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration
-  /// \return Status of operation
+  /// \brief Setup training with virtual batches.
+  ///
+  /// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable.
+  /// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration.
+  /// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration.
+  ///
+  /// \return Status of operation.
   Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);
 
-  /// \brief Sets the Learning Rate of the training
-  /// \param[in] learning_rate to set
-  /// \return Status of operation
+  /// \brief Sets the Learning Rate of the training.
+  ///
+  /// \param[in] learning_rate to set.
+  ///
+  /// \return Status of operation.
   Status SetLearningRate(float learning_rate);
 
-  /// \brief Gets the Learning Rate of the optimizer
-  /// \return learning rate. 0.0 if no optimizer was found
+  /// \brief Gets the Learning Rate of the optimizer.
+  ///
+  /// \return Learning rate. 0.0 if no optimizer was found.
   float GetLearningRate();
 
   /// \brief Initialize object with metrics.
   ///
   /// \param[in] metrics A vector of metrics objects.
   ///
   /// \return 0 on success or -1 in case of error
   Status InitMetrics(std::vector<Metrics *> metrics);
 
   /// \brief Accessor to TrainLoop metric objects
   ///
   /// \return A vector of metrics
   std::vector<Metrics *> GetMetrics();
 
   /// \brief Obtains all output tensors of the model.
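Beyond the comment cleanup, the declarations above form the optimizer-facing part of the train API. A short sketch of how they compose on a model already built for training; the function name, the 0.01f learning rate, and the kSuccess status comparison are assumptions made for illustration:

#include <vector>

// Sketch only: `model` is assumed to point at a mindspore::Model built for training.
mindspore::Status ConfigureOptimizer(mindspore::Model *model) {
  auto status = model->SetLearningRate(0.01f);
  if (status != mindspore::kSuccess) {
    return status;
  }
  float lr = model->GetLearningRate();  // 0.0 if no optimizer was found
  (void)lr;

  // Virtual batching: a multiplier < 1 disables it; lr/momentum of -1 keep internal defaults.
  status = model->SetupVirtualBatch(4);
  if (status != mindspore::kSuccess) {
    return status;
  }

  // Optimizer parameters can be read, adjusted, and written back.
  std::vector<mindspore::MSTensor> params = model->GetOptimizerParams();
  return model->SetOptimizerParams(params);
}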
@@ -322,9 +336,33 @@ class MS_API Model {
   /// \return Is supported or not.
   static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
 
+  /// \brief Set the model running mode.
+  ///
+  /// \param[in] train True means model runs in Train Mode, otherwise Eval Mode.
+  ///
+  /// \return Status of operation.
+  Status SetTrainMode(bool train);
+
+  /// \brief Get the model running mode.
+  ///
+  /// \return Is Train Mode or not.
+  bool GetTrainMode() const;
+
+  /// \brief Performs the training Loop in Train Mode.
+  ///
+  /// \param[in] epochs The number of epochs to run.
+  /// \param[in] dataset A smart pointer to MindData Dataset object.
+  /// \param[in] cbs A vector of TrainLoopCallBack objects.
+  ///
+  /// \return Status of operation.
+  Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
+
+  /// \brief Performs the training loop over all data in Eval Mode.
+  ///
+  /// \param[in] dataset A smart pointer to MindData Dataset object.
+  /// \param[in] cbs A vector of TrainLoopCallBack objects.
+  ///
+  /// \return Status of operation.
+  Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
+
  private:
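The block above is the core of the new train API: a run-mode switch plus Train/Evaluate loops driven by a MindData dataset and TrainCallBack objects. A minimal sketch of the intended call sequence; the header path, the epoch count, and the kSuccess comparison are assumptions, and dataset/callback construction is not part of this diff:

#include <memory>
#include <vector>
#include "include/api/model.h"  // assumed header path for mindspore::Model

// Sketch only: datasets and callbacks are assumed to be built elsewhere.
mindspore::Status RunTraining(mindspore::Model *model,
                              std::shared_ptr<mindspore::dataset::Dataset> train_ds,
                              std::shared_ptr<mindspore::dataset::Dataset> eval_ds,
                              std::vector<mindspore::TrainCallBack *> cbs) {
  model->SetTrainMode(true);                                // run the loop in Train Mode
  auto status = model->Train(/*epochs=*/5, train_ds, cbs);
  if (status != mindspore::kSuccess) {
    return status;
  }
  model->SetTrainMode(false);                               // switch to Eval Mode
  return model->Evaluate(eval_ds, {});                      // single pass over the eval data
}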
@@ -79,7 +79,7 @@ class TrainLoop {
 
   /// \brief Performs the training Loop
   ///
-  /// \param[in] epoch The number of epochs to run
+  /// \param[in] epochs The number of epochs to run
   /// \param[in] dataset Pointer to MindData Dataset object
   /// \param[in] cbs A vector of TrainLoopCallBack objects
   /// \param[in] load_func a function that loads (and can manipulate) data from the MindData Dataset array into the model
@@ -81,7 +81,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<TrainLoopCallBack *> c
       return RET_ERROR;
     }
   }
-  int break_loop = false;
+  bool break_loop = false;
   for (auto cb : cbs) {
     ret = cb->EpochEnd(cb_data);
     if (ret != RET_CONTINUE) {
@@ -166,7 +166,7 @@ int TrainLoop::LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTenso
     return RET_STOP_TRAINING;
   }
 
-  for (unsigned int i = 0; i < num_of_inputs; i++) {
+  for (size_t i = 0; i < num_of_inputs; i++) {
     auto *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
     const auto *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
     auto data_size = row_vec->at(i).DataSize();