forked from mindspore-Ecosystem/mindspore
!28003 Add CXX_API to cover more ToD features
Merge pull request !28003 from ehaleva/cxx_api
This commit is contained in:
commit
d300c2912b
|
@ -45,7 +45,7 @@ class MS_API Model {
|
|||
Model(const Model &) = delete;
|
||||
void operator=(const Model &) = delete;
|
||||
|
||||
/// \brief Builds a model so that it can run on a device.
|
||||
/// \brief Builds a model
|
||||
///
|
||||
/// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed
|
||||
/// from Graph, for example, model.Build(GraphCell(graph), context).
|
||||
|
@ -56,6 +56,17 @@ class MS_API Model {
|
|||
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
|
||||
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
|
||||
|
||||
/// \brief Builds a Transfer Learning model where the backbone weights are fixed and the head weights are trainable
|
||||
///
|
||||
/// \param[in] backbone The static, non-learnable part of the graph
|
||||
/// \param[in] head The trainable part of the graph
|
||||
/// \param[in] context A context used to store options during execution
|
||||
/// \param[in] cfg A config used by training
|
||||
///
|
||||
/// \return Status
|
||||
Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
|
||||
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
|
||||
|
||||
/// \brief Resizes the shapes of inputs.
|
||||
///
|
||||
/// \param[in] inputs A vector that includes all input tensors in order.
|
||||
|
@ -173,6 +184,25 @@ class MS_API Model {
|
|||
/// \return Status of operation
|
||||
Status SetOptimizerParams(const std::vector<MSTensor> ¶ms);
|
||||
|
||||
/// \brief Setup training with virtual batches
|
||||
///
|
||||
/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable
|
||||
/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration
|
||||
/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration
|
||||
/// \return Status of operation
|
||||
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);
|
||||
|
||||
/// \brief Sets the Learning Rate of the training
|
||||
///
|
||||
/// \param[in] learning_rate to set
|
||||
/// \return Status of operation
|
||||
Status SetLearningRate(float learning_rate);
|
||||
|
||||
/// \brief Gets the Learning Rate of the optimizer
|
||||
///
|
||||
/// \return learning rate. 0.0 if no optimizer was found
|
||||
float GetLearningRate();
|
||||
|
||||
Status InitMetrics(std::vector<Metrics *> metrics);
|
||||
std::vector<Metrics *> GetMetrics();
|
||||
|
||||
|
|
|
@ -207,6 +207,8 @@ int NetRunner::TrainLoop() {
|
|||
Measurement measure(epochs_);
|
||||
|
||||
if (virtual_batch_ > 0) {
|
||||
auto status = model_->SetupVirtualBatch(virtual_batch_);
|
||||
MS_ASSERT(status == mindspore::kSuccess);
|
||||
model_->Train(epochs_, train_ds_, {&rescale, &lm, &cs, &measure});
|
||||
} else {
|
||||
struct mindspore::StepLRLambda step_lr_lambda(1, kGammaFactor);
|
||||
|
@ -237,7 +239,7 @@ int NetRunner::Main() {
|
|||
|
||||
void NetRunner::Usage() {
|
||||
std::cout << "Usage: net_runner -f <.ms model file> -d <data_dir> [-e <num of training epochs>] "
|
||||
<< "[-v (verbose mode)] [-s <save checkpoint every X iterations>]" << std::endl;
|
||||
<< "[-b <virtual batch size>] [-v (verbose mode)] [-s <save checkpoint every X iterations>]" << std::endl;
|
||||
}
|
||||
|
||||
bool NetRunner::ReadArgs(int argc, char *argv[]) {
|
||||
|
|
|
@ -365,4 +365,28 @@ std::vector<Metrics *> Model::GetMetrics() {
|
|||
return impl_->GetMetrics();
|
||||
}
|
||||
|
||||
Status Model::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
return impl_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
|
||||
}
|
||||
|
||||
Status Model::SetLearningRate(float learning_rate) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
return impl_->SetLearningRate(learning_rate);
|
||||
}
|
||||
|
||||
float Model::GetLearningRate() {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(WARNING) << "Model implement is null.";
|
||||
return 0.0;
|
||||
}
|
||||
return impl_->GetLearningRate();
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -650,6 +650,32 @@ Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
|
|||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
Status ModelImpl::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
auto ret = session_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
Status ModelImpl::SetLearningRate(float learning_rate) {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
auto ret = session_->SetLearningRate(learning_rate);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
float ModelImpl::GetLearningRate() {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(WARNING) << "Session is null.";
|
||||
return 0.0;
|
||||
}
|
||||
return session_->GetLearningRate();
|
||||
}
|
||||
|
||||
lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
|
||||
auto session = new (std::nothrow) lite::LiteSession();
|
||||
if (session == nullptr) {
|
||||
|
@ -669,4 +695,5 @@ lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
|
|||
}
|
||||
return session;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -92,6 +92,10 @@ class ModelImpl {
|
|||
|
||||
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||
bool IsTrainModel();
|
||||
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum);
|
||||
Status SetLearningRate(float learning_rate);
|
||||
float GetLearningRate();
|
||||
Status BuildTransferLearning(const std::shared_ptr<Graph> &backbone, const std::shared_ptr<Graph> &head);
|
||||
|
||||
Status InitMetrics(const std::vector<Metrics *> metrics) {
|
||||
metrics_ = metrics;
|
||||
|
|
|
@ -106,4 +106,36 @@ Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCa
|
|||
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||
}
|
||||
|
||||
Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
|
||||
const std::shared_ptr<TrainCfg> &train_cfg) {
|
||||
std::stringstream err_msg;
|
||||
if (impl_ == nullptr) {
|
||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteFileError;
|
||||
}
|
||||
}
|
||||
|
||||
if (backbone.GetGraph() == nullptr || head.GetGraph() == nullptr) {
|
||||
err_msg << "Invalid null graph.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kLiteNullptr, err_msg.str());
|
||||
}
|
||||
if (context == nullptr) {
|
||||
err_msg << "Invalid null context.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kLiteNullptr, err_msg.str());
|
||||
}
|
||||
impl_->SetContext(context);
|
||||
impl_->SetGraph(head.GetGraph());
|
||||
impl_->SetConfig(train_cfg);
|
||||
|
||||
Status ret = impl_->BuildTransferLearning(backbone.GetGraph(), head.GetGraph());
|
||||
if (ret != kSuccess) {
|
||||
return ret;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,8 +37,48 @@
|
|||
#include "src/cxx_api/callback/callback_impl.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/train/train_session.h"
|
||||
#include "src/train/transfer_session.h"
|
||||
|
||||
namespace mindspore {
|
||||
Status ModelImpl::BuildTransferLearning(const std::shared_ptr<Graph> &backbone, const std::shared_ptr<Graph> &head) {
|
||||
const auto b_graph_data = backbone->graph_data_;
|
||||
const auto h_graph_data = head->graph_data_;
|
||||
if (b_graph_data == nullptr || h_graph_data == nullptr) {
|
||||
MS_LOG(ERROR) << "graph data cannot be nullptr";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
bool is_train_session = h_graph_data->IsTrainModel();
|
||||
if (is_train_session) {
|
||||
const auto b_model = reinterpret_cast<lite::LiteModel *>(b_graph_data->lite_model().get());
|
||||
const auto h_model = reinterpret_cast<lite::LiteModel *>(h_graph_data->lite_model().get());
|
||||
if (h_model == nullptr || h_model->buf == nullptr || b_model == nullptr || b_model->buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Lite model has been freed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
lite::TrainCfg train_cfg;
|
||||
if (cfg_ != nullptr) {
|
||||
auto status = A2L_ConvertConfig(cfg_.get(), &train_cfg);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "Failed to convert Config to Lite Config";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
auto session = std::shared_ptr<lite::LiteSession>(
|
||||
CreateTransferSessionInt(b_model->buf, b_model->buf_size_, h_model->buf, h_model->buf_size_,
|
||||
ContextUtils::Convert(context_.get()), true, &train_cfg));
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "create session failed";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
session_.swap(session);
|
||||
return kSuccess;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Session is not a train session.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
|
||||
std::vector<session::Metrics *> *adapter_ms) {
|
||||
if (out_ms == nullptr || adapter_ms == nullptr) {
|
||||
|
|
|
@ -94,7 +94,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) {
|
|||
int length = sm_params_.input_shape_[sm_params_.axis_];
|
||||
int stride = UP_DIV(outter_size_, threads_);
|
||||
int count = MSMIN(stride, outter_size_ - stride * task_id);
|
||||
if (count <= 0) return RET_ERROR;
|
||||
if (count <= 0) return RET_OK;
|
||||
switch (stage_) {
|
||||
case 0:
|
||||
SoftMaxP1(ins, losses, sum_data, task_id * stride, count, length, inner_size_);
|
||||
|
@ -145,7 +145,8 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
|
|||
}
|
||||
inner_size_ = inner_size;
|
||||
outter_size_ = outter_size;
|
||||
const std::vector<int> threads = {op_parameter_->thread_num_, op_parameter_->thread_num_, 1};
|
||||
int max_num_of_threads = (outter_size_ < op_parameter_->thread_num_) ? outter_size_ : op_parameter_->thread_num_;
|
||||
const std::vector<int> threads = {max_num_of_threads, max_num_of_threads, 1};
|
||||
for (int stage = 0; stage < static_cast<int>(threads.size()); stage++) {
|
||||
stage_ = stage;
|
||||
threads_ = threads.at(stage);
|
||||
|
|
|
@ -234,12 +234,10 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
|
|||
if (orig_train_state) Train();
|
||||
return status;
|
||||
}
|
||||
} // namespace lite
|
||||
|
||||
static session::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
|
||||
const char *model_buf_head, size_t size_head,
|
||||
const lite::Context *context, bool train_mode,
|
||||
const lite::TrainCfg *cfg) {
|
||||
lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
|
||||
const char *model_buf_head, size_t size_head, const lite::Context *context,
|
||||
bool train_mode, const lite::TrainCfg *cfg) {
|
||||
auto ValidModelSize = [](size_t size) -> bool {
|
||||
constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL; // 1G B
|
||||
return size < MaxModelSize && size > 0;
|
||||
|
@ -304,6 +302,8 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back
|
|||
return session;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
|
||||
session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
|
||||
const std::string &filename_head,
|
||||
const lite::Context *ctxt, bool train_mode,
|
||||
|
|
|
@ -77,6 +77,10 @@ class TransferSession : public lite::TrainSession {
|
|||
bool nchw2nhwc_ = false;
|
||||
size_t size_backbone_;
|
||||
};
|
||||
|
||||
lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
|
||||
const char *model_buf_head, size_t size_head, const lite::Context *context,
|
||||
bool train_mode, const lite::TrainCfg *cfg);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_
|
||||
|
|
|
@ -233,4 +233,22 @@ TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) {
|
|||
*MSTensor::CreateTensor("fc3.bias", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0));
|
||||
ASSERT_TRUE(model.UpdateWeights(changes) == kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiLiteModel, set_get_lr_SUCCESS) {
|
||||
Model model;
|
||||
Graph graph;
|
||||
float learn_rate = 0.2;
|
||||
auto context = std::make_shared<Context>();
|
||||
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||
cpu_context->SetEnableFP16(true);
|
||||
context->MutableDeviceInfo().push_back(cpu_context);
|
||||
auto train_cfg = std::make_shared<TrainCfg>();
|
||||
|
||||
ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess);
|
||||
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
|
||||
|
||||
ASSERT_TRUE(model.SetLearningRate(learn_rate) == kSuccess);
|
||||
ASSERT_TRUE(model.GetLearningRate() == learn_rate);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -65,8 +65,8 @@ float MixedBitWeightQuantizer::CalculateMeanError(std::vector<float> norms2, std
|
|||
error_count += 1;
|
||||
mse_error += sqrtf(dnorms2[i] / norms2[i]);
|
||||
}
|
||||
auto meam_error = mse_error / (error_count + soft);
|
||||
return meam_error;
|
||||
auto mean_error = mse_error / (error_count + soft);
|
||||
return mean_error;
|
||||
}
|
||||
|
||||
// the `preferred` dim should point to the output channels dimension.
|
||||
|
@ -109,8 +109,8 @@ float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const in
|
|||
float d = weights[i] - dequant;
|
||||
dnorms2[bucket] += d * d;
|
||||
}
|
||||
auto meam_error = CalculateMeanError(norms2, dnorms2);
|
||||
return meam_error;
|
||||
auto mean_error = CalculateMeanError(norms2, dnorms2);
|
||||
return mean_error;
|
||||
}
|
||||
|
||||
MinMax MixedBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
|
||||
|
|
|
@ -249,14 +249,14 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
|
|||
delete origin_model;
|
||||
return RET_OK;
|
||||
}
|
||||
int babysitting_rounds = 25;
|
||||
step = (min_max.max - min_max.min) / babysitting_rounds;
|
||||
int baby_step_rounds = 25;
|
||||
step = (min_max.max - min_max.min) / baby_step_rounds;
|
||||
|
||||
param.rounds = babysitting_rounds;
|
||||
param.rounds = baby_step_rounds;
|
||||
param.start_scale = start_scale;
|
||||
param.step = step;
|
||||
param.thread_num = flags->commonQuantParam.thread_num;
|
||||
std::cout << "==========Search with babysitting step==============\n";
|
||||
std::cout << "==========Search with baby step==============\n";
|
||||
ret = WeightQuantModelInference(func_graph, flags, origin_session, origin_model_size, param, init_scale,
|
||||
&candidate_scales, true);
|
||||
if (ret != RET_OK) {
|
||||
|
|
Loading…
Reference in New Issue