add new API GetTrainableParams in lite-training

zhangyanhui 2022-12-08 21:44:13 +08:00
parent 847c0abea2
commit 254e1fbea8
14 changed files with 123 additions and 0 deletions

View File

@@ -243,6 +243,11 @@ class MS_API Model {
/// \return The vector that includes all weights tensors.
std::vector<MSTensor> GetFeatureMaps() const;
/// \brief Obtain all trainable parameters of the model's optimizers.
///
/// \return The vector that includes all trainable parameters.
std::vector<MSTensor> GetTrainableParams() const;
/// \brief Update weights tensors of the model.
///
/// \param[in] new_weights A vector of new weights.

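A minimal caller-side sketch of the new accessor (not part of this commit; the helper name is an assumption, while Name() and ElementNum() are existing MSTensor accessors):

#include <iostream>
#include <vector>
#include "include/api/model.h"

// Hypothetical helper: print the optimizer-owned trainable tensors of a model
// built and trained through the lite-training API.
void DumpTrainableParams(const mindspore::Model &model) {
  std::vector<mindspore::MSTensor> params = model.GetTrainableParams();
  if (params.empty()) {
    std::cout << "no trainable parameters (model not compiled for training?)" << std::endl;
    return;
  }
  for (auto &p : params) {
    std::cout << p.Name() << ": " << p.ElementNum() << " elements" << std::endl;
  }
}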
View File

@@ -540,6 +540,15 @@ std::vector<MSTensor> Model::GetFeatureMaps() const {
return impl_->GetFeatureMaps();
}
std::vector<MSTensor> Model::GetTrainableParams() const {
std::vector<MSTensor> empty;
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return empty;
}
return impl_->GetTrainableParams();
}
Status Model::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model is null.";

View File

@@ -600,6 +600,21 @@ std::vector<MSTensor> ModelImpl::GetFeatureMaps() const {
return res;
}
std::vector<MSTensor> ModelImpl::GetTrainableParams() const {
std::vector<MSTensor> empty;
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return empty;
}
auto params = session_->GetTrainableParams();
if (params.empty()) {
MS_LOG(ERROR) << "No trainable parameters available.";
return empty;
}
std::vector<MSTensor> res = LiteTensorsToMSTensors(params, true);
return res;
}
Status ModelImpl::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
@@ -840,6 +855,9 @@ Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
inner_weights[i] = lite_impl->lite_tensor();
}
auto ret = session_->UpdateWeights(inner_weights);
if (ret != kSuccess) {
MS_LOG(ERROR) << "UpdateWeights failed, and the origin weights may have been changed.";
}
return static_cast<StatusCode>(ret);
}

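Given the new log message above, a failed UpdateWeights cannot be assumed to leave the model untouched. A hedged recovery sketch (RestoreFromCheckpoint is a hypothetical helper, not a MindSpore API):

#include <vector>
#include "include/api/model.h"

void RestoreFromCheckpoint(mindspore::Model *model);  // hypothetical recovery helper

// Sketch: on failure some weights may already have been replaced, so fall back
// to a known-good state instead of training on a half-updated model.
bool TryUpdateWeights(mindspore::Model *model, const std::vector<mindspore::MSTensor> &new_weights) {
  mindspore::Status st = model->UpdateWeights(new_weights);
  if (!st.IsOk()) {
    RestoreFromCheckpoint(model);
    return false;
  }
  return true;
}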
View File

@@ -91,6 +91,7 @@ class ModelImpl {
std::vector<MSTensor> GetGradients() const;
Status ApplyGradients(const std::vector<MSTensor> &gradients);
std::vector<MSTensor> GetFeatureMaps() const;
std::vector<MSTensor> GetTrainableParams() const;
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
std::vector<MSTensor> GetOptimizerParams() const;
Status SetOptimizerParams(const std::vector<MSTensor> &params);

View File

@@ -112,6 +112,11 @@ std::vector<int> AdamCPUKernel::GetOptimizerParamsIdxs() const {
return indices;
}
std::vector<int> AdamCPUKernel::GetTrainableParamsIdxs() const {
std::vector<int> indices = {0, 1, 2, 3, 4, 5};
return indices;
}
int AdamCPUKernel::OptimizerStep() {
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_10D - 1);
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());

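These index lists are positions in the optimizer kernel's input tensor vector: Adam here, and ApplyMomentum and SGD below, each name the slots that may hold trainable tensors. The base-class filter in optimizer_kernel.h keeps only the const tensors among them; non-const ones are resolved in train_session.cc from their producer kernels.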
View File

@@ -41,6 +41,7 @@ class AdamCPUKernel : public OptimizerKernel {
int DoExecute(int task_id);
int OptimizerStep() override;
std::vector<int> GetOptimizerParamsIdxs() const override;
std::vector<int> GetTrainableParamsIdxs() const override;
private:
int thread_count_;

View File

@@ -116,6 +116,11 @@ std::vector<int> ApplyMomentumCPUKernel::GetOptimizerParamsIdxs() const {
return indices;
}
std::vector<int> ApplyMomentumCPUKernel::GetTrainableParamsIdxs() const {
std::vector<int> indices = {0, 1, 2, 4};
return indices;
}
int ApplyMomentumCPUKernel::OptimizerStep() {
auto weight = reinterpret_cast<float *>(in_tensors_.at(FIRST_INPUT)->data());
CHECK_NULL_RETURN(weight);

View File

@@ -43,6 +43,7 @@ class ApplyMomentumCPUKernel : public OptimizerKernel {
int Run() override;
int OptimizerStep() override;
std::vector<int> GetOptimizerParamsIdxs() const override;
std::vector<int> GetTrainableParamsIdxs() const override;
private:
int thread_count_;

View File

@@ -218,6 +218,11 @@ std::vector<int> SgdCPUKernel::GetOptimizerParamsIdxs() const {
return indices;
}
std::vector<int> SgdCPUKernel::GetTrainableParamsIdxs() const {
std::vector<int> indices = {0, 2, 3, 4, 5};
return indices;
}
int SgdCPUKernel::OptimizerStep() {
auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());

View File

@@ -43,6 +43,7 @@ class SgdCPUKernel : public OptimizerKernel {
int DoExecute(int task_id);
int OptimizerStep() override;
std::vector<int> GetOptimizerParamsIdxs() const override;
std::vector<int> GetTrainableParamsIdxs() const override;
private:
int thread_count_;

View File

@@ -111,6 +111,10 @@ class LiteSession {
std::vector<lite::Tensor *> features;
return features;
}
virtual std::vector<lite::Tensor *> GetTrainableParams() const {
std::vector<lite::Tensor *> train_params;
return train_params;
}
virtual int UpdateFeatureMaps(const std::vector<lite::Tensor *> &features) { return mindspore::lite::RET_ERROR; }
virtual std::vector<lite::Tensor *> GetGradients() const {
std::vector<lite::Tensor *> gradients;

View File

@@ -59,6 +59,11 @@ class OptimizerKernel : public LiteKernel {
return indices;
}
virtual std::vector<int> GetTrainableParamsIdxs() const {
std::vector<int> indices;
return indices;
}
std::vector<lite::Tensor *> GetOptimizerParams() const {
std::vector<lite::Tensor *> params;
auto indices = GetOptimizerParamsIdxs();
@@ -95,6 +100,19 @@ class OptimizerKernel : public LiteKernel {
return found;
}
std::vector<lite::Tensor *> GetTrainableParams() const {
std::vector<lite::Tensor *> params;
auto indices = GetTrainableParamsIdxs();
for (size_t ix = 0; ix < indices.size(); ix++) {
auto param = in_tensors_.at(indices[ix]);
if (!param->IsConst()) {
continue;  // non-const inputs are resolved later from their producer kernels
}
params.push_back(param);
}
return params;
}
lite::Tensor *GetGradients() {
lite::Tensor *grad_sum_tensor = nullptr;
if (grad_sum_ != nullptr) {

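The pattern shared by the Adam, ApplyMomentum, and SGD kernels above, as a hedged sketch for a hypothetical new optimizer kernel (class name and slot numbers are assumptions; includes elided):

// A new optimizer kernel opts into GetTrainableParams() by naming the input
// slots that may hold trainable tensors; the base class keeps the const ones
// and the train session resolves the rest through their producer kernels.
class MyOptimizerCPUKernel : public OptimizerKernel {
 public:
  using OptimizerKernel::OptimizerKernel;
  std::vector<int> GetTrainableParamsIdxs() const override {
    return {0, 1};  // assumed: weight at slot 0, optimizer state at slot 1
  }
};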
View File

@@ -49,6 +49,31 @@ void FreeGradients(const std::vector<lite::Tensor *> &gradients) {
delete gradient;
}
}
void AddNonConstTrainableParams(const std::vector<kernel::KernelExec *> &in_kernels, kernel::OptimizerKernel *optimizer,
std::vector<lite::Tensor *> *params) {
auto indices = optimizer->GetTrainableParamsIdxs();
if (params->size() == indices.size()) {
return;  // every indexed input was const and is already in params
}
for (size_t ix = 0; ix < indices.size(); ix++) {
auto param = optimizer->in_tensors().at(indices[ix]);
if (param->IsConst()) {
continue;  // const inputs were already collected by GetTrainableParams()
}
// The input is produced by another kernel: find that producer among the
// optimizer's input kernels and record its first input if it is const.
for (size_t i = 0; i < in_kernels.size(); i++) {
auto out_tensors = in_kernels.at(i)->out_tensors();
if (std::find(out_tensors.begin(), out_tensors.end(), param) != out_tensors.end() &&
!in_kernels.at(i)->in_tensors().empty()) {
auto filtered_tensor = in_kernels.at(i)->in_tensors().at(FIRST_INPUT);
if (filtered_tensor->IsConst()) {
params->emplace_back(filtered_tensor);
break;
}
}
}
}
}
} // namespace
const char *kGradName = "Gradients";
const char *kOptimizerName = "optimizer";
@@ -354,6 +379,7 @@ int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
RestoreOps(restore);
CompileTrainKernels(); // Prepare a list of train kernels
CompileOptimizedKernels(); // Prepare a list of kernels which are optimized (weight update step)
CompileTrainableParams(); // Prepare trainable parameters of optimizers
CompileTrainOutputs(); // prepare outputs in train mode
CompileEvalOutputs(); // prepare outputs in eval mode
// Prepare a list of eval kernels
@@ -835,6 +861,25 @@ void TrainSession::CompileOptimizedKernels() {
}
}
void TrainSession::CompileTrainableParams() {
for (auto kernel : this->train_kernels_) {
if (!IsOptimizer(kernel)) {
continue;
}
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto params = optimizer->GetTrainableParams();
auto in_kernels = kernel->in_kernels();
AddNonConstTrainableParams(in_kernels, optimizer, &params);
for (auto param : params) {
// several optimizers may share a tensor; keep the session-level list unique
if (std::find(trainable_parameters_.begin(), trainable_parameters_.end(), param) != trainable_parameters_.end()) {
continue;
}
trainable_parameters_.emplace_back(param);
}
}
}
int TrainSession::SetLearningRate(float learning_rate) {
if (learning_rate < 0.0f) {
MS_LOG(ERROR) << "learning rate should more than 0";
@@ -1246,6 +1291,8 @@ std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
return features;
}
std::vector<lite::Tensor *> TrainSession::GetTrainableParams() const { return trainable_parameters_; }
int TrainSession::UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) {
for (auto feature : features_map) {
bool find = false;

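Why AddNonConstTrainableParams walks producer kernels: an optimizer input may be the output of another kernel that transforms a stored weight before it reaches the optimizer, so the tensor at the indexed slot is not const itself. In that case the stored weight is the producer's first input, and that is what gets recorded. CompileTrainableParams then accumulates a deduplicated list across all optimizer kernels, which TrainSession::GetTrainableParams returns as-is.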
View File

@@ -101,6 +101,7 @@ class TrainSession : virtual public lite::LiteSession {
int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
std::vector<std::string> out_put_tensor_name = {}) override;
std::vector<lite::Tensor *> GetFeatureMaps() const override;
std::vector<lite::Tensor *> GetTrainableParams() const override;
int UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) override;
int FindUseInTensorKernel(std::vector<kernel::KernelExec *> *use_in_tensor_kernels,
@@ -123,6 +124,7 @@ class TrainSession : virtual public lite::LiteSession {
virtual void CompileTrainKernels();
virtual int CompileInferenceKernels();
virtual void CompileOptimizedKernels();
virtual void CompileTrainableParams();
virtual void CompileTrainOutputs();
virtual void CompileEvalOutputs();
virtual int InitCallBack();
@@ -171,6 +173,7 @@ class TrainSession : virtual public lite::LiteSession {
int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
std::vector<std::string> out_put_tensor_name = {});
std::map<Tensor *, Tensor *> restored_origin_tensors_;
std::vector<Tensor *> trainable_parameters_;
int virtual_batch_idx_ = 0;
int virtual_batch_multiplier_ = 0;
uint32_t num_of_not_nan_iter_ = 0;