!33738 [MS][LITE] set static_allocator for train

Merge pull request !33738 from jianghui58/static_alloc
This commit is contained in:
i-robot 2022-04-29 01:04:14 +00:00 committed by Gitee
commit dc25bd4cf2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 8 additions and 1 deletions

View File

@ -87,7 +87,6 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
file(GLOB CXX_API_TRAIN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/train_support.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/metrics/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/callback/*.cc
)
@ -316,6 +315,7 @@ set(TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model_build.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model_build_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/converters.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/train_support.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/graph_fusion.cc

View File

@ -38,6 +38,7 @@
#include "src/cxx_api/callback/callback_impl.h"
#include "src/common/log_adapter.h"
#include "src/train/train_session.h"
#include "src/train/static_allocator.h"
namespace mindspore {
std::shared_ptr<lite::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
@ -58,6 +59,12 @@ std::shared_ptr<lite::LiteSession> CreateTrainSession(std::shared_ptr<Graph::Gra
}
shared_session.reset(session);
context->allocator = std::make_shared<StaticAllocator>();
if (context->allocator == nullptr) {
MS_LOG(ERROR) << " cannot convert to static allocation";
return nullptr;
}
lite::TrainCfg train_cfg;
if (cfg != nullptr) {
auto status = A2L_ConvertConfig(cfg.get(), &train_cfg);