diff --git a/graphengine b/graphengine index 0c33e9d1256..43f5d24337b 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 0c33e9d12562953ca4bd6c03cb77da2c2da74acd +Subproject commit 43f5d24337bf785251eefae2d810c7d5684194d6 diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 868255a3598..f5cacc7ed5c 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -97,7 +97,7 @@ PYBIND11_MODULE(_c_expression, m) { py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"), py::arg("phase") = py::str("dataset"), "Init and exec dataset."); (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); - (void)m.def("init_ge", &mindspore::pipeline::InitGe, "Init GE"); + (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index 251a0c2d84f..0c1c0a924b2 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -236,7 +236,7 @@ py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { void ExecutorPy::DelNetRes(const std::string &id) { #ifdef ENABLE_GE - FinalizeGe(); + FinalizeBackend(); #endif if (executor_ != nullptr) { bool flag = false; @@ -680,6 +680,13 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, const std::string &phase) { std::string name = MsContext::GetInstance()->backend_policy(); +#ifndef NO_DLIB + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->IsTsdOpened() || !ms_context->IsGeInited()) { + (void)InitBackend(); + } +#endif if (name == kMsConvert || name == kMsVm) { return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes); } @@ -758,7 +765,7 @@ void ResetOpId() { mindspore::id_generator::reset_id(); } void InitHccl() { #ifdef ENABLE_GE - (void)InitGe(); + (void)InitBackend(); #else mindspore::parse::python_adapter::set_python_env_flag(true); auto ms_context = MsContext::GetInstance(); @@ -780,7 +787,7 @@ void InitHccl() { void FinalizeHccl() { #ifdef ENABLE_GE - (void)FinalizeGe(); + (void)FinalizeBackend(); #else device::KernelRuntimeManager::Instance().ClearRuntimeResource(); #endif @@ -801,7 +808,7 @@ void ReleaseGeTsd() { } } -void InitGe() { +void InitBackend() { // set python env flag mindspore::parse::python_adapter::set_python_env_flag(true); // open tsd before ge initialize @@ -813,7 +820,7 @@ void InitGe() { (void)ms_context->InitGe(); } -void FinalizeGe() { +void FinalizeBackend() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); (void)context_ptr->FinalizeGe(); diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h index 38d4f1937fa..6a99d4dbcd4 100644 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ b/mindspore/ccsrc/pipeline/pipeline.h @@ -116,8 +116,8 @@ bool InitDistribute(const std::map &options); void ResetOpId(); void InitHccl(); void FinalizeHccl(); -void InitGe(); -void FinalizeGe(); +void InitBackend(); +void FinalizeBackend(); void ClearResAtexit(); void ReleaseGeTsd(); diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 5e8fc48216d..1e2a5d6f09c 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -439,4 +439,18 @@ bool MsContext::PynativeInitGe() { is_pynative_ge_init_ = true; return true; } + +bool MsContext::IsTsdOpened() { + if (tsd_ref_ > 0) { + return true; + } + return false; +} + +bool MsContext::IsGeInited() { + if (ge_ref_ > 0) { + return true; + } + return false; +} } // namespace mindspore diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 1d84061a8a1..b2d594d10e7 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -82,8 +82,10 @@ class MsContext { bool OpenTsd(); bool CloseTsd(bool force = false); + bool IsTsdOpened(); bool InitGe(); bool FinalizeGe(bool force = false); + bool IsGeInited(); void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; } bool enable_hccl() const { return enable_hccl_; } bool PynativeInitGe(); diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 5016dd58bf5..455e7a7f4fb 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -22,7 +22,7 @@ from mindspore import context from mindspore import log as logger from mindspore.parallel._utils import _get_parallel_mode from .._c_expression import generate_key, Executor_, Tensor, MetaTensor -from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_ge +from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend from .tensor import Tensor as MsTensor # store ms_function class compiled pipeline cache @@ -184,7 +184,7 @@ class _MindSporeFunction: @_wrap_func def __call__(self, *args): - init_ge() + init_backend() converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args) if not converted: raise RuntimeError('Process function parameter is failure') diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 4980e90f3fa..9cea668471d 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -22,7 +22,7 @@ from ..common import dtype as mstype from ..common.api import _executor from .._checkparam import _check_str_by_regular from ..common.parameter import Parameter, ParameterTuple -from .._c_expression import init_ge +from .._c_expression import init_backend from ..ops.primitive import Primitive from ..parallel._tensor import _load_tensor_by_layout from ..parallel._utils import _get_parallel_mode @@ -66,7 +66,7 @@ class Cell: self._phase = 'train' self._parameter_layout_dict = {} self._create_time = int(time.time() * 1e9) - init_ge() + init_backend() # call gc to release GE session resources used by non-used cell objects gc.collect() self._construct_inputs_num = 0