forked from mindspore-Ecosystem/mindspore
!46620 add new api GetTrainableParams in lite-training
Merge pull request !46620 from zhangyanhui/code_mas
This commit is contained in:
commit
a381f51732
|
@ -243,6 +243,11 @@ class MS_API Model {
|
|||
/// \return The vector that includes all weights tensors.
|
||||
std::vector<MSTensor> GetFeatureMaps() const;
|
||||
|
||||
/// \brief Obtain all trainable parameters of the model optimizers.
|
||||
///
|
||||
/// \return The vector that includes all trainable parameters.
|
||||
std::vector<MSTensor> GetTrainableParams() const;
|
||||
|
||||
/// \brief Update weights tensors of the model.
|
||||
///
|
||||
/// \param[in] new_weights A vector new weights.
|
||||
|
|
|
@ -540,6 +540,15 @@ std::vector<MSTensor> Model::GetFeatureMaps() const {
|
|||
return impl_->GetFeatureMaps();
|
||||
}
|
||||
|
||||
std::vector<MSTensor> Model::GetTrainableParams() const {
|
||||
std::vector<MSTensor> empty;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return empty;
|
||||
}
|
||||
return impl_->GetTrainableParams();
|
||||
}
|
||||
|
||||
Status Model::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
|
||||
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
|
||||
MS_LOG(ERROR) << "Model is null.";
|
||||
|
|
|
@ -600,6 +600,21 @@ std::vector<MSTensor> ModelImpl::GetFeatureMaps() const {
|
|||
return res;
|
||||
}
|
||||
|
||||
std::vector<MSTensor> ModelImpl::GetTrainableParams() const {
|
||||
std::vector<MSTensor> empty;
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return empty;
|
||||
}
|
||||
auto params = session_->GetTrainableParams();
|
||||
if (params.empty()) {
|
||||
MS_LOG(ERROR) << "No trainable parameters available.";
|
||||
return empty;
|
||||
}
|
||||
std::vector<MSTensor> res = LiteTensorsToMSTensors(params, true);
|
||||
return res;
|
||||
}
|
||||
|
||||
Status ModelImpl::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
|
@ -840,6 +855,9 @@ Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
|
|||
inner_weights[i] = lite_impl->lite_tensor();
|
||||
}
|
||||
auto ret = session_->UpdateWeights(inner_weights);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "UpdateWeights failed, and the origin weights may have been changed.";
|
||||
}
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
|
|
|
@ -91,6 +91,7 @@ class ModelImpl {
|
|||
std::vector<MSTensor> GetGradients() const;
|
||||
Status ApplyGradients(const std::vector<MSTensor> &gradients);
|
||||
std::vector<MSTensor> GetFeatureMaps() const;
|
||||
std::vector<MSTensor> GetTrainableParams() const;
|
||||
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
|
||||
std::vector<MSTensor> GetOptimizerParams() const;
|
||||
Status SetOptimizerParams(const std::vector<MSTensor> ¶ms);
|
||||
|
|
|
@ -112,6 +112,11 @@ std::vector<int> AdamCPUKernel::GetOptimizerParamsIdxs() const {
|
|||
return indices;
|
||||
}
|
||||
|
||||
std::vector<int> AdamCPUKernel::GetTrainableParamsIdxs() const {
|
||||
std::vector<int> indices = {0, 1, 2, 3, 4, 5};
|
||||
return indices;
|
||||
}
|
||||
|
||||
int AdamCPUKernel::OptimizerStep() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_10D - 1);
|
||||
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
|
||||
|
|
|
@ -41,6 +41,7 @@ class AdamCPUKernel : public OptimizerKernel {
|
|||
int DoExecute(int task_id);
|
||||
int OptimizerStep() override;
|
||||
std::vector<int> GetOptimizerParamsIdxs() const override;
|
||||
std::vector<int> GetTrainableParamsIdxs() const override;
|
||||
|
||||
private:
|
||||
int thread_count_;
|
||||
|
|
|
@ -116,6 +116,11 @@ std::vector<int> ApplyMomentumCPUKernel::GetOptimizerParamsIdxs() const {
|
|||
return indices;
|
||||
}
|
||||
|
||||
std::vector<int> ApplyMomentumCPUKernel::GetTrainableParamsIdxs() const {
|
||||
std::vector<int> indices = {0, 1, 2, 4};
|
||||
return indices;
|
||||
}
|
||||
|
||||
int ApplyMomentumCPUKernel::OptimizerStep() {
|
||||
auto weight = reinterpret_cast<float *>(in_tensors_.at(FIRST_INPUT)->data());
|
||||
CHECK_NULL_RETURN(weight);
|
||||
|
|
|
@ -43,6 +43,7 @@ class ApplyMomentumCPUKernel : public OptimizerKernel {
|
|||
int Run() override;
|
||||
int OptimizerStep() override;
|
||||
std::vector<int> GetOptimizerParamsIdxs() const override;
|
||||
std::vector<int> GetTrainableParamsIdxs() const override;
|
||||
|
||||
private:
|
||||
int thread_count_;
|
||||
|
|
|
@ -218,6 +218,11 @@ std::vector<int> SgdCPUKernel::GetOptimizerParamsIdxs() const {
|
|||
return indices;
|
||||
}
|
||||
|
||||
std::vector<int> SgdCPUKernel::GetTrainableParamsIdxs() const {
|
||||
std::vector<int> indices = {0, 2, 3, 4, 5};
|
||||
return indices;
|
||||
}
|
||||
|
||||
int SgdCPUKernel::OptimizerStep() {
|
||||
auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ class SgdCPUKernel : public OptimizerKernel {
|
|||
int DoExecute(int task_id);
|
||||
int OptimizerStep() override;
|
||||
std::vector<int> GetOptimizerParamsIdxs() const override;
|
||||
std::vector<int> GetTrainableParamsIdxs() const override;
|
||||
|
||||
private:
|
||||
int thread_count_;
|
||||
|
|
|
@ -111,6 +111,10 @@ class LiteSession {
|
|||
std::vector<lite::Tensor *> features;
|
||||
return features;
|
||||
}
|
||||
virtual std::vector<lite::Tensor *> GetTrainableParams() const {
|
||||
std::vector<lite::Tensor *> train_params;
|
||||
return train_params;
|
||||
}
|
||||
virtual int UpdateFeatureMaps(const std::vector<lite::Tensor *> &features) { return mindspore::lite::RET_ERROR; }
|
||||
virtual std::vector<lite::Tensor *> GetGradients() const {
|
||||
std::vector<lite::Tensor *> gradients;
|
||||
|
|
|
@ -59,6 +59,11 @@ class OptimizerKernel : public LiteKernel {
|
|||
return indices;
|
||||
}
|
||||
|
||||
virtual std::vector<int> GetTrainableParamsIdxs() const {
|
||||
std::vector<int> indices;
|
||||
return indices;
|
||||
}
|
||||
|
||||
std::vector<lite::Tensor *> GetOptimizerParams() const {
|
||||
std::vector<lite::Tensor *> params;
|
||||
auto indices = GetOptimizerParamsIdxs();
|
||||
|
@ -95,6 +100,19 @@ class OptimizerKernel : public LiteKernel {
|
|||
return found;
|
||||
}
|
||||
|
||||
std::vector<lite::Tensor *> GetTrainableParams() const {
|
||||
std::vector<lite::Tensor *> params;
|
||||
auto indices = GetTrainableParamsIdxs();
|
||||
for (size_t ix = 0; ix < indices.size(); ix++) {
|
||||
auto param = in_tensors_.at(indices[ix]);
|
||||
if (!param->IsConst()) {
|
||||
continue;
|
||||
}
|
||||
params.push_back(param);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
lite::Tensor *GetGradients() {
|
||||
lite::Tensor *grad_sum_tensor = nullptr;
|
||||
if (grad_sum_ != nullptr) {
|
||||
|
|
|
@ -49,6 +49,31 @@ void FreeGradients(const std::vector<lite::Tensor *> &gradients) {
|
|||
delete gradient;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AddNonConstTrainableParams(const std::vector<kernel::KernelExec *> &in_kernels, kernel::OptimizerKernel *optimizer,
|
||||
std::vector<lite::Tensor *> *params) {
|
||||
auto indices = optimizer->GetTrainableParamsIdxs();
|
||||
if (params->size() == indices.size()) {
|
||||
return;
|
||||
}
|
||||
for (size_t ix = 0; ix < indices.size(); ix++) {
|
||||
auto param = optimizer->in_tensors().at(indices[ix]);
|
||||
if (param->IsConst()) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 0; i < in_kernels.size(); i++) {
|
||||
auto out_tensors = in_kernels.at(i)->out_tensors();
|
||||
if (std::find(out_tensors.begin(), out_tensors.end(), param) != out_tensors.end() &&
|
||||
!in_kernels.at(i)->in_tensors().empty()) {
|
||||
auto filtered_tensor = in_kernels.at(i)->in_tensors().at(FIRST_INPUT);
|
||||
if (filtered_tensor->IsConst()) {
|
||||
params->emplace_back(filtered_tensor);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
const char *kGradName = "Gradients";
|
||||
const char *kOptimizerName = "optimizer";
|
||||
|
@ -354,6 +379,7 @@ int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
|
|||
RestoreOps(restore);
|
||||
CompileTrainKernels(); // Prepare a list of train kernels
|
||||
CompileOptimizedKernels(); // Prepare a list of kernels which are optimized (weight update step)
|
||||
CompileTrainableParams(); // Prepare trainable parameters of optimizers
|
||||
CompileTrainOutputs(); // prepare outputs in train mode
|
||||
CompileEvalOutputs(); // prepare outputs in eval mode
|
||||
// Prepare a list of eval kernels
|
||||
|
@ -835,6 +861,25 @@ void TrainSession::CompileOptimizedKernels() {
|
|||
}
|
||||
}
|
||||
|
||||
void TrainSession::CompileTrainableParams() {
|
||||
for (auto kernel : this->train_kernels_) {
|
||||
if (!IsOptimizer(kernel)) {
|
||||
continue;
|
||||
}
|
||||
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||
auto params = optimizer->GetTrainableParams();
|
||||
auto in_kernels = kernel->in_kernels();
|
||||
AddNonConstTrainableParams(in_kernels, optimizer, ¶ms);
|
||||
|
||||
for (auto param : params) {
|
||||
if (std::find(trainable_parameters_.begin(), trainable_parameters_.end(), param) != trainable_parameters_.end()) {
|
||||
continue;
|
||||
}
|
||||
trainable_parameters_.emplace_back(param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int TrainSession::SetLearningRate(float learning_rate) {
|
||||
if (learning_rate < 0.0f) {
|
||||
MS_LOG(ERROR) << "learning rate should more than 0";
|
||||
|
@ -1246,6 +1291,8 @@ std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
|
|||
return features;
|
||||
}
|
||||
|
||||
std::vector<lite::Tensor *> TrainSession::GetTrainableParams() const { return trainable_parameters_; }
|
||||
|
||||
int TrainSession::UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) {
|
||||
for (auto feature : features_map) {
|
||||
bool find = false;
|
||||
|
|
|
@ -101,6 +101,7 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {}) override;
|
||||
std::vector<lite::Tensor *> GetFeatureMaps() const override;
|
||||
std::vector<lite::Tensor *> GetTrainableParams() const override;
|
||||
|
||||
int UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) override;
|
||||
int FindUseInTensorKernel(std::vector<kernel::KernelExec *> *use_in_tensor_kernels,
|
||||
|
@ -123,6 +124,7 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
virtual void CompileTrainKernels();
|
||||
virtual int CompileInferenceKernels();
|
||||
virtual void CompileOptimizedKernels();
|
||||
virtual void CompileTrainableParams();
|
||||
virtual void CompileTrainOutputs();
|
||||
virtual void CompileEvalOutputs();
|
||||
virtual int InitCallBack();
|
||||
|
@ -171,6 +173,7 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {});
|
||||
std::map<Tensor *, Tensor *> restored_origin_tensors_;
|
||||
std::vector<Tensor *> trainable_parameters_;
|
||||
int virtual_batch_idx_ = 0;
|
||||
int virtual_batch_multiplier_ = 0;
|
||||
uint32_t num_of_not_nan_iter_ = 0;
|
||||
|
|
Loading…
Reference in New Issue