Enable PyNative Lazy Build in Cell automatically

This commit is contained in:
caifubi 2021-09-11 09:53:06 +08:00
parent 8a51311e57
commit 9357931761
9 changed files with 85 additions and 9 deletions

View File

@ -854,21 +854,24 @@ KernelGraphPtr AscendSession::CreateKernelGraph(const GraphInfo &graph_info, OpR
return graph;
}
bool AscendSession::DisableLazyBuild(const OpRunInfo &op_run_info) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
return !op_run_info.lazy_build || ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode ||
op_run_info.is_dynamic_shape || ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
}
void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
const std::vector<int64_t> &tensors_mask) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode || op_run_info->is_dynamic_shape ||
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
MS_EXCEPTION_IF_NULL(op_run_info);
if (DisableLazyBuild(*op_run_info)) {
session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
RunOpImplOrigin(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
return;
}
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(op_run_info);
bool cache_miss = run_op_graphs_.find(graph_info) == run_op_graphs_.end();
auto graph = CreateKernelGraph(graph_info, op_run_info, input_tensors, tensors_mask, cache_miss);
EraseValueNodeTensor(tensors_mask, input_tensors);

View File

@ -154,6 +154,7 @@ class AscendSession : public SessionBasic {
KernelGraphPtr CreateKernelGraph(const GraphInfo &graph_info, OpRunInfo *op_run_info,
std::vector<tensor::TensorPtr> *input_tensors,
const std::vector<int64_t> &tensors_mask, bool cache_miss);
static bool DisableLazyBuild(const OpRunInfo &op_run_info);
// key is final_graph_id,value is child graph execute order of final graph
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
// key is final_graph_id,value is the graph types of child graphs

View File

@ -62,6 +62,7 @@ struct OpRunInfo {
AbstractBasePtr abstract;
bool is_dynamic_shape = false;
bool is_auto_mixed_precision = false;
bool lazy_build = false;
std::string next_op_name = "";
#if defined(__APPLE__)
int next_input_index = 0;

View File

@ -63,6 +63,7 @@ struct OpExecInfo {
py::dict op_attrs;
#endif
std::vector<int64_t> inputs_mask;
bool lazy_build = false;
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;

View File

@ -889,6 +889,7 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
}
op_exec_info->py_primitive = prim;
op_exec_info->op_inputs = args[PY_INPUTS];
op_exec_info->lazy_build = lazy_build_;
return op_exec_info;
}
@ -1117,6 +1118,7 @@ py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type
inputs[0] = arg;
inputs[1] = dst_type;
op_exec_info->op_inputs = inputs;
op_exec_info->lazy_build = lazy_build_;
py::object ret = py::none();
RunOpInner(&ret, op_exec_info);
return ret;
@ -1861,6 +1863,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
op_exec_info->abstract,
op_exec_info->is_dynamic_shape,
op_exec_info->is_mixed_precision_cast,
op_exec_info->lazy_build,
op_exec_info->next_op_name,
static_cast<int>(op_exec_info->next_input_index)};
#else
@ -1869,6 +1872,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
op_exec_info->abstract,
op_exec_info->is_dynamic_shape,
op_exec_info->is_mixed_precision_cast,
op_exec_info->lazy_build,
op_exec_info->next_op_name,
op_exec_info->next_input_index};
#endif
@ -3055,6 +3059,8 @@ void PynativeExecutor::ClearGrad(const py::object &cell, const py::args &args) {
void PynativeExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear all res";
session::PynativeTaskManager::GetInstance().Reset();
SetLazyBuild(false);
cell_depth_ = 0;
// Maybe exit in runop step
auto ms_context = MsContext::GetInstance();
@ -3125,11 +3131,32 @@ void PynativeExecutor::Sync() {
}
}
void PynativeExecutor::SetLazyBuild(bool enable) { forward_executor()->set_lazy_build(enable); }
void PynativeExecutor::EnterCell() {
if (cell_depth_ < UINT32_MAX) {
++cell_depth_;
} else {
MS_LOG(ERROR) << "Cell call stack too deep";
}
}
void PynativeExecutor::ExitCell() {
if (cell_depth_ > 0) {
--cell_depth_;
}
}
bool PynativeExecutor::IsTopCell() const { return cell_depth_ == 0; }
void PynativeExecutor::ExecuteAllTask() { session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks(); }
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
.def("enter_cell", &PynativeExecutor::EnterCell, "enter cell.")
.def("exit_cell", &PynativeExecutor::ExitCell, "exit cell.")
.def("is_top_cell", &PynativeExecutor::IsTopCell, "check top cell.")
.def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
.def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
.def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
@ -3139,6 +3166,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
.def("clear_cell", &PynativeExecutor::ClearCell, "pynative clear status.")
.def("clear_grad", &PynativeExecutor::ClearGrad, "pynative clear grad status.")
.def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
.def("set_lazy_build", &PynativeExecutor::SetLazyBuild, "pynative build kernel async")
.def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task")
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")

View File

@ -314,6 +314,7 @@ class ForwardExecutor {
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
void ClearRes();
CNodePtr ConstructForwardGraph(const OpExecInfoPtr &op_exec_info);
void set_lazy_build(bool lazy_build) { lazy_build_ = lazy_build; }
private:
GradExecutorPtr grad() const;
@ -346,6 +347,7 @@ class ForwardExecutor {
PrimAbsCache prim_abs_list_;
ImplicitCastCache implicit_cast_map_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
bool lazy_build_{false};
};
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
@ -388,7 +390,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void ClearRes();
// Sync stream
void Sync();
void SetLazyBuild(bool enable);
void ExecuteAllTask();
void EnterCell();
void ExitCell();
bool IsTopCell() const;
private:
PynativeExecutor() = default;
@ -397,6 +403,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::mutex instance_lock_;
static ForwardExecutorPtr forward_executor_;
static GradExecutorPtr grad_executor_;
uint32_t cell_depth_{0};
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;

View File

@ -394,6 +394,12 @@ class _PynativeExecutor:
def sync(self):
self._executor.sync()
def set_lazy_build(self, enable):
self._executor.set_lazy_build(enable)
def execute_all_task(self):
self._executor.execute_all_task()
def grad_ms_function(self, output, *args):
return self._executor.grad_ms_function(output, *args)
@ -416,6 +422,15 @@ class _PynativeExecutor:
if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
_parameter_broadcast(obj, auto_parallel_mode)
def enter_cell(self):
self._executor.enter_cell()
def exit_cell(self):
self._executor.exit_cell()
def is_top_cell(self):
return self._executor.is_top_cell()
def __call__(self, obj, *args, **kwargs):
args = args + tuple(kwargs.values())
return self._executor(obj, args)

View File

@ -374,6 +374,16 @@ class Cell(Cell_):
f"The function construct needs {positional_args} positional argument and {default_args} default "
f"argument, but provided {len(inputs)}")
class CellGuard:
def __enter__(self):
_pynative_executor.set_lazy_build(True)
_pynative_executor.enter_cell()
def __exit__(self, exc_type, exc_val, exc_tb):
_pynative_executor.exit_cell()
if _pynative_executor.is_top_cell():
_pynative_executor.set_lazy_build(False)
def __call__(self, *inputs, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
@ -392,7 +402,11 @@ class Cell(Cell_):
return out
# Run in PyNative mode.
self._do_parameter_broadcast()
if _pynative_executor.is_top_cell():
_pynative_executor.set_lazy_build(True)
# There many Casts in parameter_broadcast. Enable lazy_build and build faster.
self._do_parameter_broadcast()
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("The cell inputs should not be numpy arrays.")
@ -407,7 +421,13 @@ class Cell(Cell_):
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
if not cast_inputs:
cast_inputs = inputs
output = self.run_construct(cast_inputs, kwargs)
with self.CellGuard():
output = self.run_construct(cast_inputs, kwargs)
if _pynative_executor.is_top_cell():
_pynative_executor.execute_all_task()
if isinstance(output, Parameter):
output = output.data
_pynative_executor.end_graph(self, output, *inputs, **kwargs)

View File

@ -309,7 +309,7 @@ def test_tesnsor_augassign_by_number():
input_tensor_1d[number_index_1] *= value_tuple_mul_ele
with pytest.raises(ValueError):
input_tensor_3d[number_index_1] *= value_tuple_much_ele
with pytest.raises(ValueError):
with pytest.raises(RuntimeError):
input_tensor_1d[number_index_1] /= value_tuple_empty
with pytest.raises(ValueError):