forked from mindspore-Ecosystem/mindspore
!33738 [MS][LITE] set static_allocator for train
Merge pull request !33738 from jianghui58/static_alloc
This commit is contained in:
commit
dc25bd4cf2
|
@ -87,7 +87,6 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
file(GLOB CXX_API_TRAIN_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model_impl.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/train_support.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/metrics/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/callback/*.cc
|
||||
)
|
||||
|
@ -316,6 +315,7 @@ set(TRAIN_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model_build.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/model_build_impl.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/converters.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/train/train_support.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/graph_fusion.cc
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "src/cxx_api/callback/callback_impl.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/train/train_session.h"
|
||||
#include "src/train/static_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::shared_ptr<lite::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
|
||||
|
@ -58,6 +59,12 @@ std::shared_ptr<lite::LiteSession> CreateTrainSession(std::shared_ptr<Graph::Gra
|
|||
}
|
||||
shared_session.reset(session);
|
||||
|
||||
context->allocator = std::make_shared<StaticAllocator>();
|
||||
if (context->allocator == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot convert to static allocation";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
lite::TrainCfg train_cfg;
|
||||
if (cfg != nullptr) {
|
||||
auto status = A2L_ConvertConfig(cfg.get(), &train_cfg);
|
||||
|
|
Loading…
Reference in New Issue