add runstep and featuremap api
This commit is contained in:
parent
4daacb6f7f
commit
a11c9d6900
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
|
|
|
@ -83,6 +83,14 @@ class MS_API Model {
|
|||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
|
||||
|
||||
/// \brief Train model by step.
|
||||
///
|
||||
/// \param[in] before CallBack before predict.
|
||||
/// \param[in] after CallBack after predict.
|
||||
///
|
||||
/// \return Status.
|
||||
Status RunStep(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
|
||||
|
||||
/// \brief Inference model with preprocess in model.
|
||||
///
|
||||
/// \param[in] inputs A vector where model inputs are arranged in sequence.
|
||||
|
@ -143,6 +151,17 @@ class MS_API Model {
|
|||
/// \return Status of operation
|
||||
Status ApplyGradients(const std::vector<MSTensor> &gradients);
|
||||
|
||||
/// \brief Obtains all weights tensors of the model.
|
||||
///
|
||||
/// \return The vector that includes all gradient tensors.
|
||||
std::vector<MSTensor> GetFeatureMaps() const;
|
||||
|
||||
/// \brief update weights tensors of the model.
|
||||
///
|
||||
/// \param[in] inputs A vector new weights.
|
||||
/// \return Status of operation
|
||||
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
|
||||
|
||||
/// \brief Obtains optimizer params tensors of the model.
|
||||
///
|
||||
/// \return The vector that includes all params tensors.
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
#include <mutex>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/callback/callback.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "src/cxx_api/model/model_impl.h"
|
||||
#include "src/cxx_api/callback/callback_impl.h"
|
||||
#include "src/cxx_api/callback/callback_adapter.h"
|
||||
#include "include/api/types.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/cxx_api/callback/callback_adapter.h"
|
||||
#include "src/cxx_api/callback/callback_impl.h"
|
||||
#include "src/cxx_api/model/model_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::mutex g_impl_init_lock;
|
||||
|
@ -113,6 +113,16 @@ Status Model::UpdateWeights(const std::vector<MSTensor> &new_weights) {
|
|||
return impl_->UpdateWeights(new_weights);
|
||||
}
|
||||
|
||||
Status Model::RunStep(const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
auto inputs = impl_->GetInputs();
|
||||
auto outputs = impl_->GetOutputs();
|
||||
return impl_->Predict(inputs, &outputs, before, after);
|
||||
}
|
||||
|
||||
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
if (impl_ == nullptr) {
|
||||
|
@ -291,6 +301,23 @@ Status Model::ApplyGradients(const std::vector<MSTensor> &gradients) {
|
|||
return impl_->ApplyGradients(gradients);
|
||||
}
|
||||
|
||||
std::vector<MSTensor> Model::GetFeatureMaps() const {
|
||||
std::vector<MSTensor> empty;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return empty;
|
||||
}
|
||||
return impl_->GetFeatureMaps();
|
||||
}
|
||||
|
||||
Status Model::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
|
||||
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
|
||||
MS_LOG(ERROR) << "Model is null.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
return impl_->UpdateFeatureMaps(new_weights);
|
||||
}
|
||||
|
||||
std::vector<MSTensor> Model::GetOptimizerParams() const {
|
||||
std::vector<MSTensor> empty;
|
||||
if (impl_ == nullptr) {
|
||||
|
|
|
@ -410,6 +410,44 @@ Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
|
|||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
std::vector<MSTensor> ModelImpl::GetFeatureMaps() const {
|
||||
std::vector<MSTensor> empty;
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return empty;
|
||||
}
|
||||
auto params = session_->GetFeatureMaps();
|
||||
if (params.empty()) {
|
||||
MS_LOG(ERROR) << "No optimizer parameters avelibale.";
|
||||
return empty;
|
||||
}
|
||||
std::vector<MSTensor> res = LiteTensorsToMSTensors(params, false);
|
||||
return res;
|
||||
}
|
||||
|
||||
Status ModelImpl::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
if (new_weights.empty()) {
|
||||
MS_LOG(ERROR) << "gradients is null.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
std::vector<tensor::MSTensor *> inner_weights;
|
||||
inner_weights.resize(new_weights.size());
|
||||
for (size_t i = 0; i < new_weights.size(); i++) {
|
||||
auto new_weight = new_weights[i];
|
||||
if (new_weight.impl_ == nullptr || new_weight.impl_->lite_tensor() == nullptr) {
|
||||
MS_LOG(ERROR) << "gradient tensor " << new_weight.Name() << " is null.";
|
||||
return kLiteInputTensorError;
|
||||
}
|
||||
inner_weights[i] = new_weight.impl_->lite_tensor();
|
||||
}
|
||||
auto ret = session_->UpdateFeatureMaps(inner_weights);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
|
||||
std::vector<MSTensor> empty;
|
||||
if (session_ == nullptr) {
|
||||
|
|
|
@ -77,6 +77,8 @@ class ModelImpl {
|
|||
std::vector<MSTensor> GetOutputs();
|
||||
std::vector<MSTensor> GetGradients() const;
|
||||
Status ApplyGradients(const std::vector<MSTensor> &gradients);
|
||||
std::vector<MSTensor> GetFeatureMaps() const;
|
||||
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
|
||||
std::vector<MSTensor> GetOptimizerParams() const;
|
||||
Status SetOptimizerParams(const std::vector<MSTensor> ¶ms);
|
||||
MSTensor GetInputByTensorName(const std::string &name);
|
||||
|
|
|
@ -1 +1 @@
|
|||
790864
|
||||
839924
|
||||
|
|
Loading…
Reference in New Issue