mindspore/include/api/model.h

335 lines
14 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_MODEL_H
#define MINDSPORE_INCLUDE_API_MODEL_H
#include <string>
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/context.h"
#include "include/api/callback/callback.h"
#include "include/api/cell.h"
#include "include/api/cfg.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
class ModelImpl;
class Metrics;
namespace dataset {
class Dataset;
} // namespace dataset
/// \brief The Model class is used to define a MindSpore model, facilitating computational graph management.
class MS_API Model {
public:
Model();
~Model();
Model(const Model &) = delete;
void operator=(const Model &) = delete;
/// \brief Builds a model
///
/// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed
/// from Graph, for example, model.Build(GraphCell(graph), context).
/// \param[in] model_context A context used to store options during execution.
/// \param[in] train_cfg A config used by training.
///
/// \return Status.
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
/// \brief Builds a Transfer Learning model where the backbone weights are fixed and the head weights are trainable
///
/// \param[in] backbone The static, non-learnable part of the graph
/// \param[in] head The trainable part of the graph
/// \param[in] context A context used to store options during execution
/// \param[in] cfg A config used by training
///
/// \return Status
Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
/// \brief Resizes the shapes of inputs.
///
/// \param[in] inputs A vector that includes all input tensors in order.
/// \param[in] dims Defines the new shapes of inputs, should be consistent with inputs.
///
/// \return Status.
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
/// \brief Change the size and or content of weight tensors
///
/// \param[in] new_weights a vector of tensors with new shapes and data to use in the model
/// If data pointer is null, the data of the original tensors will be copied to the new ones
///
/// \return Status.
Status UpdateWeights(const std::vector<MSTensor> &new_weights);
/// \brief Inference model.
///
/// \param[in] inputs A vector where model inputs are arranged in sequence.
/// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
///
/// \return Status.
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Train model by step.
///
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
///
/// \return Status.
Status RunStep(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Inference model with preprocess in model.
///
/// \param[in] inputs A vector where model inputs are arranged in sequence.
/// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
/// \param[in] whether to use data preprocess in model.
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
///
/// \return Status.
Status PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Apply data preprocess if it exits in model.
///
/// \param[in] inputs A vector where model inputs are arranged in sequence.
/// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
///
/// \return Status.
Status Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs);
/// \brief Check if data preprocess exists in model.
/// \return true if data preprocess exists.
bool HasPreprocess();
/// \brief Load config file.
///
/// \param[in] config_path config file path.
///
/// \return Status.
inline Status LoadConfig(const std::string &config_path);
/// \brief Update config.
///
/// \param[in] section define the config section.
/// \param[in] config define the config will be updated.
///
/// \return Status.
inline Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config);
/// \brief Obtains all input tensors of the model.
///
/// \return The vector that includes all input tensors.
std::vector<MSTensor> GetInputs();
/// \brief Obtains the input tensor of the model by name.
///
/// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned.
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
/// \brief Obtains all gradient tensors of the model.
///
/// \return The vector that includes all gradient tensors.
std::vector<MSTensor> GetGradients() const;
/// \brief update gradient tensors of the model.
///
/// \param[in] inputs A vector new gradients.
/// \return Status of operation
Status ApplyGradients(const std::vector<MSTensor> &gradients);
/// \brief Obtains all weights tensors of the model.
///
/// \return The vector that includes all gradient tensors.
std::vector<MSTensor> GetFeatureMaps() const;
/// \brief update weights tensors of the model.
///
/// \param[in] inputs A vector new weights.
/// \return Status of operation
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
/// \brief Obtains optimizer params tensors of the model.
///
/// \return The vector that includes all params tensors.
std::vector<MSTensor> GetOptimizerParams() const;
/// \brief update the optimizer parameters
///
/// \param[in] inputs A vector new optimizer params.
/// \return Status of operation
Status SetOptimizerParams(const std::vector<MSTensor> &params);
/// \brief Setup training with virtual batches
///
/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable
/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration
/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration
/// \return Status of operation
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);
/// \brief Sets the Learning Rate of the training
///
/// \param[in] learning_rate to set
/// \return Status of operation
Status SetLearningRate(float learning_rate);
/// \brief Gets the Learning Rate of the optimizer
///
/// \return learning rate. 0.0 if no optimizer was found
float GetLearningRate();
Status InitMetrics(std::vector<Metrics *> metrics);
std::vector<Metrics *> GetMetrics();
/// \brief Obtains all output tensors of the model.
///
/// \return The vector that includes all output tensors.
std::vector<MSTensor> GetOutputs();
/// \brief Obtains names of all output tensors of the model.
///
/// \return A vector that includes names of all output tensors.
inline std::vector<std::string> GetOutputTensorNames();
/// \brief Obtains the output tensor of the model by name.
///
/// \return The output tensor with the given name, if the name is not found, an invalid tensor is returned.
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
/// \brief Get output MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \note Deprecated, replace with GetOutputByTensorName
///
/// \return The vector of output MSTensor.
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &node_name);
/// \brief Bind GLTexture2D object to cl Memory.
///
/// \param[in] inputGlTexture The input GLTexture id for Model.
/// \param[in] outputGLTexture The output GLTexture id for Model.
///
/// \return Status of operation.
Status BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
std::map<std::string, unsigned int> *outputGLTexture);
/// \brief Inference model.
///
/// \param[in] device_type Device typeoptions are kGPU, kAscend etc.
/// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
///
/// \return Is supported or not.
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
Status SetTrainMode(bool train);
bool GetTrainMode() const;
Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_data Define the buffer read from a model file.
/// \param[in] size Define bytes number of model buffer.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
inline Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
private:
friend class Serialization;
// api without std::string
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
std::vector<std::vector<char>> GetOutputTensorNamesChar();
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status LoadConfig(const std::vector<char> &config_path);
Status UpdateConfig(const std::vector<char> &section, const std::pair<std::vector<char>, std::vector<char>> &config);
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::vector<char> &dec_mode);
std::shared_ptr<ModelImpl> impl_;
};
MSTensor Model::GetInputByTensorName(const std::string &tensor_name) {
return GetInputByTensorName(StringToChar(tensor_name));
}
std::vector<std::string> Model::GetOutputTensorNames() { return VectorCharToString(GetOutputTensorNamesChar()); }
MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
return GetOutputByTensorName(StringToChar(tensor_name));
}
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
return GetOutputsByNodeName(StringToChar(node_name));
}
Status Model::LoadConfig(const std::string &config_path) { return LoadConfig(StringToChar(config_path)); }
Status Model::UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) {
std::pair<std::vector<char>, std::vector<char>> config_pair = {StringToChar(config.first),
StringToChar(config.second)};
return UpdateConfig(StringToChar(section), config_pair);
}
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
}
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H