forked from mindspore-Ecosystem/mindspore
fixed issue I3HZIK and removed some more SUPPORT_TRAIN ifdefs
This commit is contained in:
parent
84a67654a4
commit
8974c7f0bc
|
@ -395,9 +395,9 @@ int LiteSession::CompileGraph(Model *model) {
|
|||
}
|
||||
// scheduler kernels
|
||||
#if SUPPORT_NPU
|
||||
Scheduler scheduler(context_, model, &tensors_, npu_manager_, npu_pass_manager_);
|
||||
Scheduler scheduler(context_, model, &tensors_, is_train_session_, npu_manager_, npu_pass_manager_);
|
||||
#else
|
||||
Scheduler scheduler(context_, model, &tensors_);
|
||||
Scheduler scheduler(context_, model, &tensors_, is_train_session_);
|
||||
#endif
|
||||
ret = scheduler.Schedule(&kernels_);
|
||||
if (ret != RET_OK) {
|
||||
|
@ -599,7 +599,7 @@ LiteSession::~LiteSession() {
|
|||
npu_manager_->Reset();
|
||||
delete npu_manager_;
|
||||
#endif
|
||||
#if GPU_OPENCL && !SUPPORT_TRAIN
|
||||
#if GPU_OPENCL
|
||||
delete opencl_runtime_wrapper_;
|
||||
#endif
|
||||
delete (model_);
|
||||
|
@ -737,7 +737,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
|
|||
}
|
||||
|
||||
int LiteSession::InitGPURuntime() {
|
||||
#if GPU_OPENCL && !SUPPORT_TRAIN
|
||||
#if GPU_OPENCL
|
||||
if (this->context_->IsGpuEnabled()) {
|
||||
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
|
||||
if (opencl_runtime_wrapper_ == nullptr) {
|
||||
|
@ -754,7 +754,7 @@ int LiteSession::InitGPURuntime() {
|
|||
MS_LOG(INFO) << "Init OpenCL runtime success.";
|
||||
}
|
||||
}
|
||||
#elif GPU_VULKAN && !SUPPORT_TRAIN
|
||||
#elif GPU_VULKAN
|
||||
if (this->context_->IsGpuEnabled()) {
|
||||
auto gpu_device_info = this->context_->GetGpuInfo();
|
||||
vk_runtime_wrap_ = new (std::nothrow) gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime>;
|
||||
|
|
|
@ -134,13 +134,14 @@ class LiteSession : public session::LiteSession {
|
|||
Executor *executor_ = nullptr;
|
||||
Model *model_ = nullptr;
|
||||
std::atomic<bool> is_running_ = false;
|
||||
bool is_train_session_ = false;
|
||||
#if SUPPORT_NPU
|
||||
NPUManager *npu_manager_ = nullptr;
|
||||
NPUPassManager *npu_pass_manager_ = nullptr;
|
||||
#endif
|
||||
#if GPU_OPENCL && !SUPPORT_TRAIN
|
||||
#if GPU_OPENCL
|
||||
opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
|
||||
#elif GPU_VULKAN && !SUPPORT_TRAIN
|
||||
#elif GPU_VULKAN
|
||||
gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime> *vk_runtime_wrap_{nullptr};
|
||||
#endif
|
||||
};
|
||||
|
|
|
@ -215,7 +215,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
|
|||
}
|
||||
|
||||
namespace {
|
||||
#ifndef SUPPORT_TRAIN
|
||||
int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
|
||||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
|
@ -319,7 +318,6 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
|
||||
MS_ASSERT(restored_origin_tensors != nullptr);
|
||||
|
@ -368,19 +366,20 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
|
|||
return nullptr;
|
||||
}
|
||||
std::map<Tensor *, Tensor *> restored_origin_tensors;
|
||||
#ifndef SUPPORT_TRAIN
|
||||
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
|
||||
return nullptr;
|
||||
|
||||
if (!is_train_session_) {
|
||||
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
|
||||
return nullptr;
|
||||
}
|
||||
// we don't need to restore tensor for copy data
|
||||
ret = CopyConstTensorData(in_tensors, op_type);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
// we don't need to restore tensor for copy data
|
||||
ret = CopyConstTensorData(in_tensors, op_type);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
|
||||
if (kernel != nullptr) {
|
||||
MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
|
||||
|
|
|
@ -30,16 +30,17 @@
|
|||
namespace mindspore::lite {
|
||||
class Scheduler {
|
||||
public:
|
||||
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
|
||||
: context_(ctx), src_model_(src_model), src_tensors_(src_tensors) {}
|
||||
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session)
|
||||
: context_(ctx), src_model_(src_model), src_tensors_(src_tensors), is_train_session_(is_train_session) {}
|
||||
#if SUPPORT_NPU
|
||||
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
|
||||
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
|
||||
NPUManager *npu_manager = nullptr, NPUPassManager *npu_pass_manager = nullptr)
|
||||
: context_(ctx),
|
||||
src_model_(src_model),
|
||||
src_tensors_(src_tensors),
|
||||
npu_manager_(npu_manager),
|
||||
npu_pass_manager_(npu_pass_manager) {}
|
||||
npu_pass_manager_(npu_pass_manager),
|
||||
is_train_session_(is_train_session) {}
|
||||
#endif
|
||||
~Scheduler() = default;
|
||||
|
||||
|
@ -113,6 +114,7 @@ class Scheduler {
|
|||
#endif
|
||||
std::vector<size_t> graph_output_node_indexes_;
|
||||
std::map<int, OpParameter *> op_parameters_;
|
||||
bool is_train_session_ = false;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -166,11 +166,6 @@ int CpuSubGraph::Run(const KernelCallBack &before, const KernelCallBack &after)
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
for (auto out_tensor : out_tensors_) { // increase RefCount of output tensors, such that Run will not free them
|
||||
out_tensor->set_ref_count(out_tensor->ref_count() + 1);
|
||||
}
|
||||
#endif
|
||||
#ifdef SUPPORT_GPU
|
||||
// In heterogeneous scenarios of CPU and GPU, call MutableData to MapBuffer(synchronize data).
|
||||
if (static_cast<const lite::InnerContext *>(context_)->IsGpuEnabled()) {
|
||||
|
|
|
@ -57,6 +57,7 @@ static kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *>
|
|||
return *it;
|
||||
}
|
||||
TrainSession::TrainSession() {
|
||||
is_train_session_ = true;
|
||||
#ifdef ENABLE_V0
|
||||
if (VersionManager::GetInstance()->CheckV0Schema()) {
|
||||
kernel::PopulateTrainV0Parameters();
|
||||
|
|
Loading…
Reference in New Issue