forked from mindspore-Ecosystem/mindspore
Enable PyNative Lazy Build in Cell automatically
This commit is contained in:
parent
8a51311e57
commit
9357931761
|
@ -854,21 +854,24 @@ KernelGraphPtr AscendSession::CreateKernelGraph(const GraphInfo &graph_info, OpR
|
|||
return graph;
|
||||
}
|
||||
|
||||
bool AscendSession::DisableLazyBuild(const OpRunInfo &op_run_info) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
return !op_run_info.lazy_build || ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode ||
|
||||
op_run_info.is_dynamic_shape || ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
|
||||
}
|
||||
|
||||
void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode || op_run_info->is_dynamic_shape ||
|
||||
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
if (DisableLazyBuild(*op_run_info)) {
|
||||
session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
|
||||
RunOpImplOrigin(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
|
||||
return;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
|
||||
bool cache_miss = run_op_graphs_.find(graph_info) == run_op_graphs_.end();
|
||||
auto graph = CreateKernelGraph(graph_info, op_run_info, input_tensors, tensors_mask, cache_miss);
|
||||
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
|
|
@ -154,6 +154,7 @@ class AscendSession : public SessionBasic {
|
|||
KernelGraphPtr CreateKernelGraph(const GraphInfo &graph_info, OpRunInfo *op_run_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask, bool cache_miss);
|
||||
static bool DisableLazyBuild(const OpRunInfo &op_run_info);
|
||||
// key is final_graph_id,value is child graph execute order of final graph
|
||||
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
|
||||
// key is final_graph_id,value is the graph types of child graphs
|
||||
|
|
|
@ -62,6 +62,7 @@ struct OpRunInfo {
|
|||
AbstractBasePtr abstract;
|
||||
bool is_dynamic_shape = false;
|
||||
bool is_auto_mixed_precision = false;
|
||||
bool lazy_build = false;
|
||||
std::string next_op_name = "";
|
||||
#if defined(__APPLE__)
|
||||
int next_input_index = 0;
|
||||
|
|
|
@ -63,6 +63,7 @@ struct OpExecInfo {
|
|||
py::dict op_attrs;
|
||||
#endif
|
||||
std::vector<int64_t> inputs_mask;
|
||||
bool lazy_build = false;
|
||||
};
|
||||
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
|
||||
|
||||
|
|
|
@ -889,6 +889,7 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
|
|||
}
|
||||
op_exec_info->py_primitive = prim;
|
||||
op_exec_info->op_inputs = args[PY_INPUTS];
|
||||
op_exec_info->lazy_build = lazy_build_;
|
||||
return op_exec_info;
|
||||
}
|
||||
|
||||
|
@ -1117,6 +1118,7 @@ py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type
|
|||
inputs[0] = arg;
|
||||
inputs[1] = dst_type;
|
||||
op_exec_info->op_inputs = inputs;
|
||||
op_exec_info->lazy_build = lazy_build_;
|
||||
py::object ret = py::none();
|
||||
RunOpInner(&ret, op_exec_info);
|
||||
return ret;
|
||||
|
@ -1861,6 +1863,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
|
|||
op_exec_info->abstract,
|
||||
op_exec_info->is_dynamic_shape,
|
||||
op_exec_info->is_mixed_precision_cast,
|
||||
op_exec_info->lazy_build,
|
||||
op_exec_info->next_op_name,
|
||||
static_cast<int>(op_exec_info->next_input_index)};
|
||||
#else
|
||||
|
@ -1869,6 +1872,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
|
|||
op_exec_info->abstract,
|
||||
op_exec_info->is_dynamic_shape,
|
||||
op_exec_info->is_mixed_precision_cast,
|
||||
op_exec_info->lazy_build,
|
||||
op_exec_info->next_op_name,
|
||||
op_exec_info->next_input_index};
|
||||
#endif
|
||||
|
@ -3055,6 +3059,8 @@ void PynativeExecutor::ClearGrad(const py::object &cell, const py::args &args) {
|
|||
void PynativeExecutor::ClearRes() {
|
||||
MS_LOG(DEBUG) << "Clear all res";
|
||||
session::PynativeTaskManager::GetInstance().Reset();
|
||||
SetLazyBuild(false);
|
||||
cell_depth_ = 0;
|
||||
|
||||
// Maybe exit in runop step
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -3125,11 +3131,32 @@ void PynativeExecutor::Sync() {
|
|||
}
|
||||
}
|
||||
|
||||
void PynativeExecutor::SetLazyBuild(bool enable) { forward_executor()->set_lazy_build(enable); }
|
||||
|
||||
void PynativeExecutor::EnterCell() {
|
||||
if (cell_depth_ < UINT32_MAX) {
|
||||
++cell_depth_;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Cell call stack too deep";
|
||||
}
|
||||
}
|
||||
|
||||
void PynativeExecutor::ExitCell() {
|
||||
if (cell_depth_ > 0) {
|
||||
--cell_depth_;
|
||||
}
|
||||
}
|
||||
|
||||
bool PynativeExecutor::IsTopCell() const { return cell_depth_ == 0; }
|
||||
|
||||
void PynativeExecutor::ExecuteAllTask() { session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks(); }
|
||||
|
||||
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
||||
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
|
||||
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
|
||||
.def("enter_cell", &PynativeExecutor::EnterCell, "enter cell.")
|
||||
.def("exit_cell", &PynativeExecutor::ExitCell, "exit cell.")
|
||||
.def("is_top_cell", &PynativeExecutor::IsTopCell, "check top cell.")
|
||||
.def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
|
||||
.def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
|
||||
.def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
|
||||
|
@ -3139,6 +3166,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
|||
.def("clear_cell", &PynativeExecutor::ClearCell, "pynative clear status.")
|
||||
.def("clear_grad", &PynativeExecutor::ClearGrad, "pynative clear grad status.")
|
||||
.def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
|
||||
.def("set_lazy_build", &PynativeExecutor::SetLazyBuild, "pynative build kernel async")
|
||||
.def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task")
|
||||
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
|
||||
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
|
||||
|
|
|
@ -314,6 +314,7 @@ class ForwardExecutor {
|
|||
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
|
||||
void ClearRes();
|
||||
CNodePtr ConstructForwardGraph(const OpExecInfoPtr &op_exec_info);
|
||||
void set_lazy_build(bool lazy_build) { lazy_build_ = lazy_build; }
|
||||
|
||||
private:
|
||||
GradExecutorPtr grad() const;
|
||||
|
@ -346,6 +347,7 @@ class ForwardExecutor {
|
|||
PrimAbsCache prim_abs_list_;
|
||||
ImplicitCastCache implicit_cast_map_;
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
|
||||
bool lazy_build_{false};
|
||||
};
|
||||
|
||||
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
||||
|
@ -388,7 +390,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void ClearRes();
|
||||
// Sync stream
|
||||
void Sync();
|
||||
void SetLazyBuild(bool enable);
|
||||
void ExecuteAllTask();
|
||||
void EnterCell();
|
||||
void ExitCell();
|
||||
bool IsTopCell() const;
|
||||
|
||||
private:
|
||||
PynativeExecutor() = default;
|
||||
|
@ -397,6 +403,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
static std::mutex instance_lock_;
|
||||
static ForwardExecutorPtr forward_executor_;
|
||||
static GradExecutorPtr grad_executor_;
|
||||
uint32_t cell_depth_{0};
|
||||
};
|
||||
|
||||
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
|
||||
|
|
|
@ -394,6 +394,12 @@ class _PynativeExecutor:
|
|||
def sync(self):
|
||||
self._executor.sync()
|
||||
|
||||
def set_lazy_build(self, enable):
|
||||
self._executor.set_lazy_build(enable)
|
||||
|
||||
def execute_all_task(self):
|
||||
self._executor.execute_all_task()
|
||||
|
||||
def grad_ms_function(self, output, *args):
|
||||
return self._executor.grad_ms_function(output, *args)
|
||||
|
||||
|
@ -416,6 +422,15 @@ class _PynativeExecutor:
|
|||
if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
|
||||
_parameter_broadcast(obj, auto_parallel_mode)
|
||||
|
||||
def enter_cell(self):
|
||||
self._executor.enter_cell()
|
||||
|
||||
def exit_cell(self):
|
||||
self._executor.exit_cell()
|
||||
|
||||
def is_top_cell(self):
|
||||
return self._executor.is_top_cell()
|
||||
|
||||
def __call__(self, obj, *args, **kwargs):
|
||||
args = args + tuple(kwargs.values())
|
||||
return self._executor(obj, args)
|
||||
|
|
|
@ -374,6 +374,16 @@ class Cell(Cell_):
|
|||
f"The function construct needs {positional_args} positional argument and {default_args} default "
|
||||
f"argument, but provided {len(inputs)}")
|
||||
|
||||
class CellGuard:
|
||||
def __enter__(self):
|
||||
_pynative_executor.set_lazy_build(True)
|
||||
_pynative_executor.enter_cell()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
_pynative_executor.exit_cell()
|
||||
if _pynative_executor.is_top_cell():
|
||||
_pynative_executor.set_lazy_build(False)
|
||||
|
||||
def __call__(self, *inputs, **kwargs):
|
||||
if self.__class__.construct is Cell.construct:
|
||||
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
|
||||
|
@ -392,7 +402,11 @@ class Cell(Cell_):
|
|||
return out
|
||||
|
||||
# Run in PyNative mode.
|
||||
if _pynative_executor.is_top_cell():
|
||||
_pynative_executor.set_lazy_build(True)
|
||||
# There many Casts in parameter_broadcast. Enable lazy_build and build faster.
|
||||
self._do_parameter_broadcast()
|
||||
|
||||
for item in inputs:
|
||||
if isinstance(item, numpy.ndarray):
|
||||
raise TypeError("The cell inputs should not be numpy arrays.")
|
||||
|
@ -407,7 +421,13 @@ class Cell(Cell_):
|
|||
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
|
||||
if not cast_inputs:
|
||||
cast_inputs = inputs
|
||||
|
||||
with self.CellGuard():
|
||||
output = self.run_construct(cast_inputs, kwargs)
|
||||
|
||||
if _pynative_executor.is_top_cell():
|
||||
_pynative_executor.execute_all_task()
|
||||
|
||||
if isinstance(output, Parameter):
|
||||
output = output.data
|
||||
_pynative_executor.end_graph(self, output, *inputs, **kwargs)
|
||||
|
|
|
@ -309,7 +309,7 @@ def test_tesnsor_augassign_by_number():
|
|||
input_tensor_1d[number_index_1] *= value_tuple_mul_ele
|
||||
with pytest.raises(ValueError):
|
||||
input_tensor_3d[number_index_1] *= value_tuple_much_ele
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(RuntimeError):
|
||||
input_tensor_1d[number_index_1] /= value_tuple_empty
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
|
Loading…
Reference in New Issue