fixed issue I3HZIK and removed some more SUPPORT_TRAIN ifdefs

This commit is contained in:
Emir Haleva 2021-04-20 08:09:28 +03:00
parent 84a67654a4
commit 8974c7f0bc
6 changed files with 28 additions and 30 deletions
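In short: logic that was previously compiled in or out with SUPPORT_TRAIN ifdefs now branches at runtime on a new is_train_session_ member. LiteSession threads the flag into the Scheduler, and TrainSession's constructor sets it to true. A minimal sketch of the before/after pattern, using simplified stand-in classes (only is_train_session_ itself comes from the real sources):

  #include <iostream>

  // Before: behavior was fixed at build time, e.g.
  //   #ifndef SUPPORT_TRAIN
  //     ... inference-only setup ...
  //   #endif
  // so inference and training required separately compiled binaries.

  // After: one binary; the session kind is a runtime property.
  class Session {
   public:
    void Compile() {
      if (!is_train_session_) {
        std::cout << "inference-only setup\n";  // e.g. casting const tensors
      }
    }

   protected:
    bool is_train_session_ = false;  // mirrors the new LiteSession member
  };

  class TrainSession : public Session {
   public:
    TrainSession() { is_train_session_ = true; }  // same move as in the last hunk below
  };

  int main() {
    Session inference;
    TrainSession train;
    inference.Compile();  // prints "inference-only setup"
    train.Compile();      // prints nothing
    return 0;
  }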

View File

@@ -395,9 +395,9 @@ int LiteSession::CompileGraph(Model *model) {
   }
   // scheduler kernels
 #if SUPPORT_NPU
-  Scheduler scheduler(context_, model, &tensors_, npu_manager_, npu_pass_manager_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_, npu_manager_, npu_pass_manager_);
 #else
-  Scheduler scheduler(context_, model, &tensors_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_);
 #endif
   ret = scheduler.Schedule(&kernels_);
   if (ret != RET_OK) {
@@ -599,7 +599,7 @@ LiteSession::~LiteSession() {
   npu_manager_->Reset();
   delete npu_manager_;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   delete opencl_runtime_wrapper_;
 #endif
   delete (model_);
@@ -737,7 +737,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
 }
 int LiteSession::InitGPURuntime() {
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   if (this->context_->IsGpuEnabled()) {
     opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
     if (opencl_runtime_wrapper_ == nullptr) {
@@ -754,7 +754,7 @@ int LiteSession::InitGPURuntime() {
       MS_LOG(INFO) << "Init OpenCL runtime success.";
     }
   }
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   if (this->context_->IsGpuEnabled()) {
     auto gpu_device_info = this->context_->GetGpuInfo();
     vk_runtime_wrap_ = new (std::nothrow) gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime>;

View File

@@ -134,13 +134,14 @@ class LiteSession : public session::LiteSession {
   Executor *executor_ = nullptr;
   Model *model_ = nullptr;
   std::atomic<bool> is_running_ = false;
+  bool is_train_session_ = false;
 #if SUPPORT_NPU
   NPUManager *npu_manager_ = nullptr;
   NPUPassManager *npu_pass_manager_ = nullptr;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime> *vk_runtime_wrap_{nullptr};
 #endif
 };

View File

@@ -215,7 +215,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
 }
 namespace {
-#ifndef SUPPORT_TRAIN
 int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
   MS_ASSERT(tensor != nullptr);
@@ -319,7 +318,6 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
   }
   return RET_OK;
 }
-#endif
 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
   MS_ASSERT(restored_origin_tensors != nullptr);
@@ -368,7 +366,8 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     return nullptr;
   }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
-#ifndef SUPPORT_TRAIN
+  if (!is_train_session_) {
     ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
     if (ret != RET_OK) {
       MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
@@ -380,7 +379,7 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
       MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
       return nullptr;
     }
-#endif
+  }
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);

View File

@@ -30,16 +30,17 @@
 namespace mindspore::lite {
 class Scheduler {
  public:
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
-      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors) {}
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session)
+      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors), is_train_session_(is_train_session) {}
 #if SUPPORT_NPU
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
             NPUManager *npu_manager = nullptr, NPUPassManager *npu_pass_manager = nullptr)
       : context_(ctx),
         src_model_(src_model),
         src_tensors_(src_tensors),
         npu_manager_(npu_manager),
-        npu_pass_manager_(npu_pass_manager) {}
+        npu_pass_manager_(npu_pass_manager),
+        is_train_session_(is_train_session) {}
 #endif
   ~Scheduler() = default;
@@ -113,6 +114,7 @@ class Scheduler {
 #endif
   std::vector<size_t> graph_output_node_indexes_;
   std::map<int, OpParameter *> op_parameters_;
+  bool is_train_session_ = false;
 };
 }  // namespace mindspore::lite
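For reference, how the new four-argument constructor is used, condensed into a self-contained sketch (the forward-declared types are stand-ins; only the Scheduler signature and members mirror the hunk above):

  #include <vector>

  class InnerContext;  // stand-ins for the real mindspore::lite types
  class Model;
  class Tensor;

  class Scheduler {
   public:
    Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session)
        : context_(ctx), src_model_(src_model), src_tensors_(src_tensors), is_train_session_(is_train_session) {}

   private:
    const InnerContext *context_;
    Model *src_model_;
    std::vector<Tensor *> *src_tensors_;
    bool is_train_session_ = false;  // new member, defaults to inference
  };

  int main() {
    std::vector<Tensor *> tensors;
    // An inference session passes false; TrainSession passes true via LiteSession.
    Scheduler scheduler(nullptr, nullptr, &tensors, /*is_train_session=*/false);
    return 0;
  }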

View File

@@ -166,11 +166,6 @@ int CpuSubGraph::Run(const KernelCallBack &before, const KernelCallBack &after)
     }
   }
-#ifdef SUPPORT_TRAIN
-  for (auto out_tensor : out_tensors_) {  // increase RefCount of output tensors, such that Run will not free them
-    out_tensor->set_ref_count(out_tensor->ref_count() + 1);
-  }
-#endif
 #ifdef SUPPORT_GPU
   // In heterogeneous scenarios of CPU and GPU, call MutableData to MapBuffer(synchronize data).
   if (static_cast<const lite::InnerContext *>(context_)->IsGpuEnabled()) {
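The block deleted above raised each output tensor's reference count so that the run loop would not free training outputs. A toy illustration of that ref-count idiom, with a stand-in Tensor class (only ref_count()/set_ref_count() appear in the diff; DecRefCount is a hypothetical helper standing in for whatever decrements counts during a run):

  #include <iostream>
  #include <vector>

  class Tensor {  // stand-in, not the real mindspore::lite::Tensor
   public:
    int ref_count() const { return ref_count_; }
    void set_ref_count(int count) { ref_count_ = count; }
    void DecRefCount() {  // hypothetical: models the run loop releasing a use
      if (--ref_count_ == 0) std::cout << "tensor data freed\n";
    }

   private:
    int ref_count_ = 1;
  };

  int main() {
    std::vector<Tensor *> out_tensors = {new Tensor};
    // The removed SUPPORT_TRAIN block performed exactly this extra bump:
    for (auto out_tensor : out_tensors) {
      out_tensor->set_ref_count(out_tensor->ref_count() + 1);
    }
    out_tensors[0]->DecRefCount();  // count 2 -> 1: data survives the run
    delete out_tensors[0];
    return 0;
  }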

View File

@@ -57,6 +57,7 @@ static kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *>
   return *it;
 }
 TrainSession::TrainSession() {
+  is_train_session_ = true;
 #ifdef ENABLE_V0
   if (VersionManager::GetInstance()->CheckV0Schema()) {
     kernel::PopulateTrainV0Parameters();