add runstep and featuremap api

This commit is contained in:
zhengjun10 2021-12-17 10:02:18 +08:00
parent 4daacb6f7f
commit a11c9d6900
6 changed files with 94 additions and 7 deletions

View File

@ -22,6 +22,7 @@
#include <memory>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
namespace mindspore {

View File

@ -83,6 +83,14 @@ class MS_API Model {
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Train model by step.
///
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
///
/// \return Status.
Status RunStep(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Inference model with preprocess in model.
///
/// \param[in] inputs A vector where model inputs are arranged in sequence.
@ -143,6 +151,17 @@ class MS_API Model {
/// \return Status of operation
Status ApplyGradients(const std::vector<MSTensor> &gradients);
/// \brief Obtains all weights tensors of the model.
///
/// \return The vector that includes all gradient tensors.
std::vector<MSTensor> GetFeatureMaps() const;
/// \brief update weights tensors of the model.
///
/// \param[in] inputs A vector new weights.
/// \return Status of operation
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
/// \brief Obtains optimizer params tensors of the model.
///
/// \return The vector that includes all params tensors.

View File

@ -19,14 +19,14 @@
#include <cuda_runtime.h>
#endif
#include <mutex>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/api/callback/callback.h"
#include "include/api/context.h"
#include "include/api/dual_abi_helper.h"
#include "src/cxx_api/model/model_impl.h"
#include "src/cxx_api/callback/callback_impl.h"
#include "src/cxx_api/callback/callback_adapter.h"
#include "include/api/types.h"
#include "src/common/log_adapter.h"
#include "src/cxx_api/callback/callback_adapter.h"
#include "src/cxx_api/callback/callback_impl.h"
#include "src/cxx_api/model/model_impl.h"
namespace mindspore {
std::mutex g_impl_init_lock;
@ -113,6 +113,16 @@ Status Model::UpdateWeights(const std::vector<MSTensor> &new_weights) {
return impl_->UpdateWeights(new_weights);
}
Status Model::RunStep(const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
}
auto inputs = impl_->GetInputs();
auto outputs = impl_->GetOutputs();
return impl_->Predict(inputs, &outputs, before, after);
}
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (impl_ == nullptr) {
@ -291,6 +301,23 @@ Status Model::ApplyGradients(const std::vector<MSTensor> &gradients) {
return impl_->ApplyGradients(gradients);
}
std::vector<MSTensor> Model::GetFeatureMaps() const {
std::vector<MSTensor> empty;
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return empty;
}
return impl_->GetFeatureMaps();
}
Status Model::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model is null.";
return kLiteUninitializedObj;
}
return impl_->UpdateFeatureMaps(new_weights);
}
std::vector<MSTensor> Model::GetOptimizerParams() const {
std::vector<MSTensor> empty;
if (impl_ == nullptr) {

View File

@ -410,6 +410,44 @@ Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
return static_cast<StatusCode>(ret);
}
std::vector<MSTensor> ModelImpl::GetFeatureMaps() const {
std::vector<MSTensor> empty;
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return empty;
}
auto params = session_->GetFeatureMaps();
if (params.empty()) {
MS_LOG(ERROR) << "No optimizer parameters avelibale.";
return empty;
}
std::vector<MSTensor> res = LiteTensorsToMSTensors(params, false);
return res;
}
Status ModelImpl::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return kLiteNullptr;
}
if (new_weights.empty()) {
MS_LOG(ERROR) << "gradients is null.";
return kLiteInputParamInvalid;
}
std::vector<tensor::MSTensor *> inner_weights;
inner_weights.resize(new_weights.size());
for (size_t i = 0; i < new_weights.size(); i++) {
auto new_weight = new_weights[i];
if (new_weight.impl_ == nullptr || new_weight.impl_->lite_tensor() == nullptr) {
MS_LOG(ERROR) << "gradient tensor " << new_weight.Name() << " is null.";
return kLiteInputTensorError;
}
inner_weights[i] = new_weight.impl_->lite_tensor();
}
auto ret = session_->UpdateFeatureMaps(inner_weights);
return static_cast<StatusCode>(ret);
}
std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
std::vector<MSTensor> empty;
if (session_ == nullptr) {

View File

@ -77,6 +77,8 @@ class ModelImpl {
std::vector<MSTensor> GetOutputs();
std::vector<MSTensor> GetGradients() const;
Status ApplyGradients(const std::vector<MSTensor> &gradients);
std::vector<MSTensor> GetFeatureMaps() const;
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
std::vector<MSTensor> GetOptimizerParams() const;
Status SetOptimizerParams(const std::vector<MSTensor> &params);
MSTensor GetInputByTensorName(const std::string &name);

View File

@ -1 +1 @@
790864
839924