From a69939a52fb93b790d7ab56a750963a73fc635d8 Mon Sep 17 00:00:00 2001
From: wangchangheng
Date: Thu, 13 Oct 2022 16:08:14 +0800
Subject: [PATCH] front dynamic detect

---
 .../ccsrc/backend/graph_compiler/backend.cc   |   9 +-
 .../backend/graph_compiler/backend_base.h     |   2 +-
 .../ccsrc/frontend/optimizer/ad/auto_grad.h   |   6 +-
 mindspore/ccsrc/include/common/utils/utils.h  |   1 +
 mindspore/ccsrc/pipeline/pynative/base.h      |   3 +
 .../pipeline/pynative/forward/do_cast.cc      |   1 +
 .../pipeline/pynative/forward/forward.cc      |  22 +-
 .../ccsrc/pipeline/pynative/forward/forward.h |   7 +
 .../ccsrc/pipeline/pynative/grad/grad.cc      | 656 ++++++++++++++++--
 mindspore/ccsrc/pipeline/pynative/grad/grad.h |  71 +-
 .../pynative/grad/ms_function_grad.cc         |  85 ++-
 .../pipeline/pynative/grad/ms_function_grad.h |   7 +
 .../ccsrc/pipeline/pynative/grad/top_cell.cc  |  63 ++
 .../ccsrc/pipeline/pynative/grad/top_cell.h   |  54 +-
 .../pipeline/pynative/pynative_execute.cc     |  24 +-
 .../pipeline/pynative/pynative_execute.h      |   6 +-
 .../runtime/graph_scheduler/graph_compiler.cc |   3 +
 .../runtime/graph_scheduler/graph_compiler.h  |   4 +-
 mindspore/python/mindspore/common/api.py      |  31 +-
 mindspore/python/mindspore/nn/cell.py         |   2 +
 .../python/mindspore/ops/composite/base.py    |  10 +-
 21 files changed, 978 insertions(+), 89 deletions(-)

diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc
index 7a8ba7c9706..9a4079841ce 100644
--- a/mindspore/ccsrc/backend/graph_compiler/backend.cc
+++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc
@@ -688,6 +688,7 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
     graph_compiler_->CalculateForwardOpOutputCount(graph, inputs[graph_index], &forward_op_output_tensor_id_);
   }
 
+  bool use_dynamic_shape_process = root_graph_->has_flag(kFlagUseDynamicShapeProcess);
   py::gil_scoped_release release;
   for (const auto &kernel : graph->execution_order()) {
     InputTensorInfo input_tensor_info;
@@ -714,9 +715,8 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
       GraphInfo graph_info;
       graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
                                                &input_tensor_info);
-      graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info,
-                                                      &graph_output_info);
-      bool use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
+      graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, use_dynamic_shape_process,
+                                                      &op_run_info, &graph_info, &graph_output_info);
       if (use_dynamic_shape_process) {
         RunOpDynamic(op_run_info, &op_outputs);
       } else {
@@ -751,7 +751,8 @@ void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const Graph
   }
 
   if (contain_cut_graph || root_graph_->has_flag(kFlagIsDynamicStructure) ||
-      (enable_backend_dynamic_detect_ && root_graph_->has_flag(kFlagIsPynativeBpropGraph) && is_dynamic)) {
+      (enable_backend_dynamic_detect_ && root_graph_->has_flag(kFlagIsPynativeBpropGraph) && is_dynamic) ||
+      root_graph_->has_flag(kFlagUseDynamicShapeProcess)) {
     RunGraphBySingleOp(graph_compiler_info, args, outputs);
   } else {
     RunGraphByActors(actor_info, graph_compiler_info, args, outputs);
diff --git a/mindspore/ccsrc/backend/graph_compiler/backend_base.h b/mindspore/ccsrc/backend/graph_compiler/backend_base.h
index 7554c4da26c..7f7c7889ee7 100644
--- a/mindspore/ccsrc/backend/graph_compiler/backend_base.h
+++ b/mindspore/ccsrc/backend/graph_compiler/backend_base.h
@@ -149,7 +149,7 @@ class BACKEND_EXPORT MindRTBackendBase : public Backend {
   // Save the mapping between cell id and actor info.
   mindspore::HashMap graph_actor_infos_;
-  bool enable_backend_dynamic_detect_{true};
+  bool enable_backend_dynamic_detect_{false};
   FuncGraphPtr root_graph_;
   GraphPartitionPtr graph_partition_;
   std::shared_ptr graph_compiler_;
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h b/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h
index 7166a999c68..9ecf8cbb14b 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h
+++ b/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h
@@ -46,11 +46,11 @@ struct GradParam {
       : cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
 
   // Primal CNode create by op forward process
-  const CNodePtr &cnode;
+  const CNodePtr cnode;
   // Input value for cnode
-  const ValuePtrList &op_args;
+  const ValuePtrList op_args;
   // Output of op
-  const ValuePtr &out;
+  const ValuePtr out;
   // Bprop func graph
   const FuncGraphPtr fprop_fg;
   // High order used this, which
diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h
index cc85e5c3827..ab1f2a01171 100644
--- a/mindspore/ccsrc/include/common/utils/utils.h
+++ b/mindspore/ccsrc/include/common/utils/utils.h
@@ -907,6 +907,7 @@ constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
 constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
 constexpr auto kFlagNeedRenormalize = "need_renormalize";
 constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
+constexpr auto kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";
 
 // TODO(dsj): for ms_function running in graph_mode. should be delete later
 constexpr auto kAttrMSFunction = "ms_function_graph";
diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h
index 3f156719d69..e20aeb17aaf 100644
--- a/mindspore/ccsrc/pipeline/pynative/base.h
+++ b/mindspore/ccsrc/pipeline/pynative/base.h
@@ -59,6 +59,7 @@ struct FrontendOpRunInfo {
   bool grad_flag = false;
   bool output_get_by_infer_value = false;
   int mix_type{0};
+  size_t op_index = 0;
   size_t input_size = 0;
   size_t custom_bprop_cell_count = 0;
   PrimitivePyPtr op_prim{nullptr};
@@ -88,6 +89,8 @@ struct InputArgsInfo {
   size_t input_size;
   std::string obj_id;
   bool has_sens{false};
+  bool is_run_cell{false};
+  bool use_dynamic_shape_process = false;
   PrimitivePyPtr custom_bprp_prim{nullptr};
   ValuePtr out_value{nullptr};
   std::string cell_id;
diff --git a/mindspore/ccsrc/pipeline/pynative/forward/do_cast.cc b/mindspore/ccsrc/pipeline/pynative/forward/do_cast.cc
index 60203b17b92..87c2cbe0251 100644
--- a/mindspore/ccsrc/pipeline/pynative/forward/do_cast.cc
+++ b/mindspore/ccsrc/pipeline/pynative/forward/do_cast.cc
@@ -274,6 +274,7 @@ ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, cons
     cast_run_info->base_op_run_info.next_op_name = op_name;
     cast_run_info->base_op_run_info.next_input_index = index;
     cast_run_info->base_op_run_info.lazy_build = op_run_info->base_op_run_info.lazy_build;
+    cast_run_info->base_op_run_info.use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
     (void)cast_run_info->input_value.emplace_back(v);
     (void)cast_run_info->input_value.emplace_back(GetDstType(type_id));
     cast_run_info->input_size = input_size;
diff --git a/mindspore/ccsrc/pipeline/pynative/forward/forward.cc b/mindspore/ccsrc/pipeline/pynative/forward/forward.cc
index ba9aea4673c..1515f734b1c 100644
--- a/mindspore/ccsrc/pipeline/pynative/forward/forward.cc
+++ b/mindspore/ccsrc/pipeline/pynative/forward/forward.cc
@@ -183,10 +183,21 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
   if (!op_run_info->output_get_by_infer_value) {
     GetOutput(op_run_info);
   }
+  if (!op_run_info->grad_flag) {
+    MS_LOG(DEBUG) << "Grad flag is false";
+    return;
+  }
+
+  // Set forward output flag for release memory,
+  // Because tensor address may change, it should set in main thread to ensure consistency.
+  PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
+
+  // Const value no need do op grad
+  if (op_run_info->output_get_by_infer_value) {
+    return;
+  }
   // 4. Do op grad and record op info
-  if (enable_async_) {
-    grad()->AsyncProcessOpGradInfo(op_run_info);
-  } else {
+  if (!is_ms_function_compiling_) {
     grad()->ProcessOpGradInfo(op_run_info);
   }
 }
@@ -199,10 +210,13 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
   // Used for async run
   op_run_info->grad_flag = grad()->grad_flag();
   op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
+  op_run_info->base_op_run_info.use_dynamic_shape_process =
+    (device_target_ == kAscendDevice ? false : grad()->use_dynamic_shape_process());
   op_run_info->base_op_run_info.op_name = args[static_cast(RunOpArgsEnum::PY_NAME)].cast();
   op_run_info->base_op_run_info.lazy_build = lazy_build_;
   PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast(RunOpArgsEnum::PY_PRIM)]);
   PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args[static_cast(RunOpArgsEnum::PY_INPUTS)]);
+  (void)op_run_prim_py_list_.emplace_back(op_run_info->op_prim);
   return op_run_info;
 }
@@ -412,6 +426,7 @@ void ForwardExecutor::Sync() {
     MS_EXCEPTION_IF_NULL(item.second);
     item.second->SyncStream();
   }
+  op_run_prim_py_list_.clear();
 }
 
 ValuePtr ForwardExecutor::RunOpInMs(const FrontendOpRunInfoPtr &op_run_info) {
@@ -466,6 +481,7 @@ void ForwardExecutor::ClearRes() {
   infer_operation()->ClearConstFlagPrimCache();
   std::stack().swap(forward_cell_stack_);
   mindrt_backends_.clear();
+  op_run_prim_py_list_.clear();
 }
 }  // namespace pynative
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/pynative/forward/forward.h b/mindspore/ccsrc/pipeline/pynative/forward/forward.h
index 7da887790f9..6b24e93d960 100644
--- a/mindspore/ccsrc/pipeline/pynative/forward/forward.h
+++ b/mindspore/ccsrc/pipeline/pynative/forward/forward.h
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include
 #include "pipeline/pynative/forward/do_cast.h"
 #include "pipeline/pynative/forward/do_infer.h"
 #include "backend/graph_compiler/backend.h"
@@ -71,6 +72,10 @@ class ForwardExecutor {
     MS_EXCEPTION_IF_NULL(infer_operation_);
     return infer_operation_;
   }
+  inline void set_is_ms_function_compiling(bool is_ms_function_compiling) {
+    is_ms_function_compiling_ = is_ms_function_compiling;
+  }
+  inline std::string device_target() { return device_target_; }
 
  private:
  GradExecutorPtr grad() const;
@@ -94,6 +99,7 @@ class ForwardExecutor {
  private:
   bool init_{false};
   bool lazy_build_{false};
+  bool is_ms_function_compiling_{false};
   uint32_t device_id_{0};
   std::string last_target_{"Unknown"};
   std::string device_target_;
@@ -103,6 +109,7 @@ class ForwardExecutor {
   InferOperationPtr infer_operation_;
   MindrtBackendMap mindrt_backends_;
   bool enable_async_ = false;
+  mutable std::vector op_run_prim_py_list_;
 };
 }  // namespace pynative
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc index 8516d2b363b..4f05695eaf0 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc @@ -63,8 +63,8 @@ std::string GetCellId(const py::object &obj, const py::args &args, const InputAr return cell_id; } -InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell, - bool is_high_order_top_cell) { +InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell, + bool is_high_order_top_cell) { bool has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME); const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj); const auto &input_args_info = @@ -82,6 +82,7 @@ InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, b } pipeline::CheckArgsValid(obj, args); } + input_args_info->is_run_cell = py::isinstance(obj); input_args_info->cell_id = GetCellId(obj, args, input_args_info); MS_LOG(DEBUG) << "cell_id is " << obj_id << ", is grad top cell " << (is_grad_top_cell || is_high_order_top_cell); return input_args_info; @@ -200,10 +201,10 @@ ForwardExecutorPtr GradExecutor::forward() const { } std::string GradExecutor::GetCurCellOrder() const { - if (input_args_info_stack_.empty()) { - MS_LOG(EXCEPTION) << "The input_args_info_stack_ is empty!"; + if (cur_cell_id_.empty()) { + MS_LOG(EXCEPTION) << "The cur_cell_id_ is empty!"; } - return input_args_info_stack_.top()->cell_id + "_" + std::to_string(cell_order_); + return cur_cell_id_ + "_" + std::to_string(cell_order_); } TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() { @@ -300,12 +301,18 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i auto graph_info_cg = std::make_shared(); top_cell()->SetGraphInfoMap(curr_g(), graph_info_cg); HandleInputArgsForTopCell(input_args_info, false); + top_cell()->set_need_compile_graph(true); top_cell()->set_init_kpynative(true); } } +void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const { + top_cell()->set_need_compile_graph(need_compile_graph); + top_cell()->set_forward_already_run(forward_already_run); +} + void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) { - const auto &input_args_info = GetInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell()); + const auto input_args_info = GetInputArgsInfo(obj, args); PushInputArgsInfoStack(input_args_info); if (input_args_info->has_custom_bprop) { @@ -317,17 +324,21 @@ void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) { } input_args_info->grad_order = grad_order_; // May be can async here - if (enable_async_) { - AsyncNewGraphImpl(input_args_info); - } else { - NewGraphImpl(input_args_info); - } + NewGraphImpl(input_args_info); +} + +InputArgsInfoPtr GradExecutor::GetInputArgsInfo(const py::object &obj, const py::args &args) { + auto input_args_info = + ParsePyArgsToInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell()); + input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_; + return input_args_info; } void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) { MS_EXCEPTION_IF_NULL(input_args_info); ++cell_order_; const auto &cell_id = input_args_info->cell_id; + cur_cell_id_ = cell_id; MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id << ", input args info ptr " << 
input_args_info.get(); // Make top graph and init resource @@ -357,12 +368,18 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) { auto fg = std::make_shared(); fg->debug_info()->set_name("pynative_forward_graph"); auto resource = std::make_shared(); - const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(input_args_info->grad_order); - top_cell_ = std::make_shared(input_args_info->grad_order, input_args_info->cell_id, already_run_cell_id, - resource, fg); + const auto &already_run_cell_id = GetAlreadyRunCellId(input_args_info->cell_id); + top_cell_ = std::make_shared(input_args_info->grad_order, input_args_info->obj_id, + input_args_info->cell_id, already_run_cell_id, resource, fg); top_cell_->set_forward_already_run(true); + top_cell_->set_is_run_cell(input_args_info->is_run_cell); top_cell_->set_input_args_id(input_args_info->input_args_id); PushHighOrderGraphStack(top_cell_); + (void)top_cell_list_.emplace_back(top_cell_); + + const auto &cell_id = input_args_info->obj_id.append("_").append(std::to_string(grad_order_)); + is_cell_id_in_dynamic_detect_nodes_map_ = + (cell_id_with_dynamic_detect_nodes_.find(cell_id) != cell_id_with_dynamic_detect_nodes_.end()); MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get(); } @@ -387,7 +404,11 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string & MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); auto sens_v = ConvertOutputValueToTensor(v); auto cloned_value = ShallowCopyTensorValue(sens_v); - auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value); + if (!MsContext::GetInstance()->get_param(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) { + AsyncUpdateOutputNodeOfTopCell(output_node, cloned_value); + } else { + auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value); + } } void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) { @@ -400,16 +421,13 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c GetCustomBpropPrim(obj, args, out, input_args_info); } input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out); + input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_; PopInputArgsInfoStack(); if (input_args_info->is_grad_topest_cell) { set_grad_flag(false); } // May be can async here - if (enable_async_) { - AsyncEndGraphImpl(input_args_info); - } else { - EndGraphImpl(input_args_info); - } + EndGraphImpl(input_args_info); } void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) { @@ -453,6 +471,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) { SetForwardLastNodeInfo(out_value, out_id); } top_cell()->CheckSubCellHookChanged(); + CheckNeedCompileGraph(input_args_info); top_input_args_info_ = input_args_info; } } @@ -478,7 +497,15 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, op_run_info->input_size = input_args_info->input_arg_value_vec.size(); op_run_info->input_value_id = input_args_info->input_arg_id_vec; auto cnode = ConstructForwardGraph(op_run_info); + + if (grad_is_running_ && !bprop_grad_stack_.top().second) { + MS_LOG(DEBUG) << "Custom bprop, no need do op grad"; + return; + } DoOpGrad(op_run_info, cnode, input_args_info->out_value); + CheckGraphDynamic(cnode, top_cell()->op_index()); + top_cell()->IncreaseOpIndex(); + SaveOutputNodeMap(out_id, op_run_info, cnode); } @@ -535,6 +562,56 @@ void 
GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &arg input_args_info->custom_bprp_prim = fake_prim; } +void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info) { + const auto &new_top_cell = top_cell(); + const auto &already_top_cell_id = new_top_cell->already_run_cell_id(); + // Update top cell by current cell op info + if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) { + MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " has never been ran, need compile graph"; + already_run_top_cell_[already_top_cell_id] = new_top_cell; + pre_top_cell_ = top_cell(); + return; + } + + MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been ran"; + auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id); + MS_EXCEPTION_IF_NULL(pre_top_cell); + + if (input_args_info->use_dynamic_shape_process || !input_args_info->is_run_cell) { + // Function need compile every time. + MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again"; + EraseTopCellFromTopCellList(pre_top_cell); + { + py::gil_scoped_acquire acquire; + pre_top_cell->Clear(); + } + already_run_top_cell_[already_top_cell_id] = new_top_cell; + pre_top_cell_ = nullptr; + } else { + MS_LOG(DEBUG) << "no need to compile graph again"; + pre_top_cell->set_input_args_id(new_top_cell->input_args_id()); + // In high order situations, the internal top cell remains unchanged, but the external top cell has changed. Then + // the graph info of the internal top cell needs to be updated so that the external top cell can perceive it. + if (!input_args_info->is_grad_topest_cell) { + pre_top_cell->SetGraphInfoMap(pre_top_cell->fg(), new_top_cell->graph_info_map().at(new_top_cell->fg())); + } + pre_top_cell_ = pre_top_cell; + pre_top_cell->set_forward_already_run(true); + } +} + +void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) { + MS_EXCEPTION_IF_NULL(top_cell); + auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), + [&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); }); + if (iter == top_cell_list_.end()) { + MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id() + << " from top cell list"; + } else { + (void)top_cell_list_.erase(iter); + } +} + void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights, const py::object &grad_position, const py::args &args) { { @@ -558,7 +635,19 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob (void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_v)); top_input_args_info_->has_sens = true; } + if (pre_top_cell_ != nullptr) { + set_top_cell(pre_top_cell_); + } + if (!top_cell()->need_compile_graph()) { + MS_LOG(DEBUG) << "No need compile graph"; + top_cell_list_.pop_back(); + + UpdateTopCellInfo(false, false); + return; + } + MS_LOG(DEBUG) << "Need compile graph"; + top_cell()->set_grad_operation(grad_operation_); SetBpropGraphJitLevel(obj); bool weight_param_is_tuple = true; auto w_args = GetWeightsArgs(weights, &weight_param_is_tuple); @@ -568,12 +657,22 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob GetGradGraph(grad_attr, w_args, p_args); } +std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) const { + std::string already_run_cell_id(cell_id); + already_run_cell_id += std::to_string(grad_order_ == 0 ? 
1 : grad_order_); + already_run_cell_id += "_" + grad_operation_; + MS_LOG(DEBUG) << "Get already run top cell id " << already_run_cell_id; + return already_run_cell_id; +} + void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector &w_args, const std::vector &p_args) { // Get bprop graph of top cell auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args); MS_EXCEPTION_IF_NULL(bprop_graph); bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true); + bool use_dynamic_shape_process = (forward()->device_target() == kAscendDevice ? false : use_dynamic_shape_process_); + bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process); MS_EXCEPTION_IF_NULL(top_input_args_info_); bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id)); auto resource = top_cell()->resource(); @@ -583,14 +682,13 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(bprop_graph, true); PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph); - if (backends_.find(top_input_args_info_->obj_id) == backends_.end()) { - backends_[top_input_args_info_->obj_id] = compile::CreateBackend(); - } - resource->SetBackendAsync([&]() { return backends_[top_input_args_info_->obj_id]; }); + SaveForwardTensorInfoInBpropGraph(resource); + resource->SetBackendAsync([]() { return compile::CreateBackend(); }); MS_LOG(DEBUG) << "Start task emit action"; (void)TaskEmitAction(resource); MS_LOG(DEBUG) << "Start execute action"; (void)ExecuteAction(resource); + UpdateTopCellInfo(false, false); resource->Clean(); } @@ -761,10 +859,18 @@ void GradExecutor::SetGradOrder(const std::string &cell_id) { } py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, - const py::args &args) { + const py::object &grad_hash_id, const py::args &args) { auto cell_id = GetCellId(obj, args, nullptr); + // Check current cell grad order and erase it if in current top cell list SetGradOrder(cell_id); + // Include weight param size and required grad flag + std::string grad_hash_id_str; + if (!py::isinstance(grad_hash_id)) { + grad_hash_id_str = std::string(py::str(grad_hash_id)); + } + grad_operation_ = std::to_string(static_cast(grad->get_all_)) + + std::to_string(static_cast(grad->get_by_list_)) + grad_hash_id_str; std::string input_args_id; for (size_t i = 0; i < args.size(); ++i) { @@ -774,8 +880,9 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con // check whether need to run forward process bool forward_run = false; if (input_args_info_stack_.empty() && top_cell_ != nullptr) { - cell_id += std::to_string(grad_order_ == 0 ? 
1 : grad_order_); - if (CanGetTopCell(cell_id)) { + const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id); + auto find_top_cell = GetTopCell(check_already_run_cell_id); + if (find_top_cell != nullptr) { MS_LOG(DEBUG) << "Find already run top cell"; forward_run = top_cell()->forward_already_run(); bool input_args_changed = !top_cell()->input_args_id().empty() && top_cell()->input_args_id() != input_args_id; @@ -948,7 +1055,12 @@ void GradExecutor::ClearGradRes() { if (top_cell_ != nullptr) { top_cell_->ClearDeviceMemory(); } - top_cell_ = nullptr; + + if (use_dynamic_shape_process_ || + already_run_top_cell_.find(top_cell_->already_run_cell_id()) != already_run_top_cell_.end()) { + top_cell_ = nullptr; + } + DecreaseGradOrder(); ClearGlobalRes(); } @@ -959,13 +1071,21 @@ void GradExecutor::ClearRes() { grad_is_running_ = false; need_renormalize_ = false; eliminate_forward_ = true; + use_dynamic_shape_process_ = false; + is_cell_id_in_dynamic_detect_nodes_map_ = false; custom_bprop_cell_count_ = 0; grad_order_ = 0; top_cell_ = nullptr; top_input_args_info_ = nullptr; bprop_cell_list_.clear(); - backends_.clear(); async_executor_->Reset(); + for (const auto &cell_ptr : top_cell_list_) { + MS_EXCEPTION_IF_NULL(cell_ptr); + cell_ptr->Clear(); + } + top_cell_list_.clear(); + already_run_top_cell_.clear(); + cell_id_with_dynamic_detect_nodes_.clear(); std::stack().swap(input_args_info_stack_); std::stack>().swap(bprop_grad_stack_); std::stack().swap(high_order_stack_); @@ -1072,6 +1192,8 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str auto cnode = curr_g()->NewCNode(inputs); MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString(); top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false); + CheckGraphDynamic(cnode, top_cell()->op_index()); + top_cell()->IncreaseOpIndex(); return cnode; } @@ -1099,10 +1221,42 @@ AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id, c_node->set_abstract(prim_abs); } } + CheckGraphDynamic(c_node, top_cell()->op_index()); + top_cell()->IncreaseOpIndex(); MS_LOG(DEBUG) << "Get input node " << c_node->ToString() << ", id " << obj_id; return c_node; } +TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id) { + TopCellInfoPtr find_top_cell = nullptr; + for (const auto &top_cell : top_cell_list_) { + MS_EXCEPTION_IF_NULL(top_cell); + // Complete match, means run grad operation first + if (top_cell->already_run_cell_id() == already_run_cell_id) { + return top_cell; + } + // Partial match, means run forward first + if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos && + top_cell->already_run_cell_id().back() == '_') { + find_top_cell = top_cell; + break; + } + } + // Same topcell info, but grad operation is not the same, construct backward graph again + if (find_top_cell != nullptr) { + if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) { + MS_LOG(DEBUG) << "Already exist grad operation " << find_top_cell->grad_operation() << " is different with new " + << grad_operation_; + EraseTopCellFromTopCellList(find_top_cell); + (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id()); + return nullptr; + } else { + return find_top_cell; + } + } + return nullptr; +} + void GradExecutor::SetHookChanged(const py::object &cell) const { if (top_cell_ == nullptr) { return; @@ -1118,24 +1272,19 @@ void GradExecutor::SetHookChanged(const py::object &cell) const { void 
GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const { MS_EXCEPTION_IF_NULL(op_run_info); - if (!op_run_info->grad_flag) { - MS_LOG(DEBUG) << "Grad flag is false"; - return; - } - // Set forward output flag for release memory - PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value); - - // Const value no need do op grad - if (op_run_info->output_get_by_infer_value) { - return; - } // Do op grad and save node info. If cell have custom bprop, no need do op grad. Otherwise, need do. if (op_run_info->custom_bprop_cell_count <= 0) { const auto &cnode = ConstructForwardGraph(op_run_info); MS_EXCEPTION_IF_NULL(cnode); cnode->set_abstract(op_run_info->base_op_run_info.abstract); SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode); + if (grad_is_running_ && !bprop_grad_stack_.top().second) { + MS_LOG(DEBUG) << "Custom bprop, no need do op grad"; + return; + } DoOpGrad(op_run_info, cnode, op_run_info->out_value); + CheckGraphDynamic(cnode, top_cell()->op_index()); + UpdateForwardTensorInfoInBpropGraph(op_run_info); } } @@ -1163,10 +1312,6 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOp // Run ad grad for curr op and connect grad graph with previous op void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const { - if (grad_is_running_ && !bprop_grad_stack_.top().second) { - MS_LOG(DEBUG) << "Custom bprop, no need do op grad"; - return; - } MS_EXCEPTION_IF_NULL(op_run_info); // to avoid out exist in tape bprop, avoid out be modified. @@ -1175,14 +1320,196 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode std::back_inserter(cloned_op_args), [](const ValuePtr &value) { return ShallowCopyTensorValue(value); }); ValuePtr cloned_out = ShallowCopyTensorValue(op_out); - std::vector tensors; - TensorValueToTensor(cloned_out, &tensors); - for (auto tensor : tensors) { - tensor->set_is_forward_output(true); - } if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) { MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name; } + auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr(); + if (!MsContext::GetInstance()->get_param(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) { + AsyncGradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out); + } else { + GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out); + } +} + +void GradExecutor::GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, + const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const { + if (!ad::GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out)) { + MS_LOG(EXCEPTION) << "Failed to run ad grad for op "; + } +} + +void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, + const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const { + const auto fn = [this, auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out]() { + this->GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out); + }; + auto task = std::make_shared(fn); + async_executor_->Push(task); +} + +void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const { + auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr(); + MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); + const auto fn = [auto_grad_cell_ptr, output_node, 
cloned_value]() { + auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value); + }; + auto task = std::make_shared(fn); + async_executor_->Push(task); +} + +void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const { + MS_EXCEPTION_IF_NULL(op_run_info); + if (op_run_info->base_op_run_info.use_dynamic_shape_process) { + MS_LOG(DEBUG) << "Get dynamic shape process"; + return; + } + top_cell()->GetOpInfo(op_run_info); + MS_LOG(DEBUG) << "Current op info: " << op_run_info->op_info; + + std::vector op_output_tensors; + // Get output tensors + TensorValueToTensor(op_run_info->out_value, &op_output_tensors); + // Save all tensors info of current op + top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info, op_output_tensors); + + // First run top cell + if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) { + MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " run firstly"; + return; + } + // Non-first run + const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id()); + MS_EXCEPTION_IF_NULL(pre_top_cell); + if (pre_top_cell->op_info_with_tensor_id().find(op_run_info->op_info) == + pre_top_cell->op_info_with_tensor_id().end()) { + MS_LOG(DEBUG) << "Can not find op info " << op_run_info->op_info << " in op info with tensor id map. Top cell " + << top_cell_->cell_id(); + return; + } + + // Update new output tensor info in bprop graph + const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_run_info->op_info); + if (pre_op_tensor_id.size() != op_output_tensors.size()) { + MS_LOG(EXCEPTION) << "The size of op pre output tensor size: " << pre_op_tensor_id.size() + << " is not equal to current " << op_output_tensors.size(); + } + // For value node tensor in the bprop graph, take its id for tensor, and save in tensor_id_with_tensor_object; + // And then take the output of the op and find out if the output used by tensor_id_with_tensor_object, + // if there is a tensor need to replace it. + const auto &pre_tensor_id_with_tensor_object = pre_top_cell->tensor_id_with_tensor_object(); + for (size_t i = 0; i < pre_op_tensor_id.size(); ++i) { + auto pre_id = pre_op_tensor_id[i]; + if (pre_tensor_id_with_tensor_object.find(pre_id) == pre_tensor_id_with_tensor_object.end()) { + continue; + } + // Based on the output size of the op is fixed, so can use index. 
+ const auto &new_tensor = op_output_tensors[i]; + const auto &pre_tensor_object = pre_tensor_id_with_tensor_object.at(pre_id); + UpdatePreTensorInfo(new_tensor, pre_tensor_object); + } +} + +void GradExecutor::UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor, + const std::vector &pre_tensors) const { + MS_EXCEPTION_IF_NULL(new_tensor); + if (pre_tensors.empty() || new_tensor->device_address() == nullptr) { + MS_LOG(DEBUG) << "The number of pre tensors is zero or the device address of new tensor is nullptr."; + return; + } + const auto &device_target = MsContext::GetInstance()->get_param(MS_CTX_DEVICE_TARGET); + for (auto &pre_tensor : pre_tensors) { + MS_EXCEPTION_IF_NULL(pre_tensor); + MS_LOG(DEBUG) << "Replace Old tensor id " << pre_tensor->id() << " device_address: " << pre_tensor->device_address() + << " shape and type " << pre_tensor->GetShapeAndDataTypeInfo() << " with New tensor id " + << new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype " + << new_tensor->GetShapeAndDataTypeInfo(); + (void)pre_tensor->set_shape(new_tensor->shape()); + (void)pre_tensor->set_data_type(new_tensor->data_type()); + auto device_address = std::dynamic_pointer_cast(new_tensor->device_address()); + MS_EXCEPTION_IF_NULL(device_address); + if (device_target != kCPUDevice && device_address->GetDeviceType() != device::DeviceType::kCPU) { + pre_tensor->set_device_address(new_tensor->device_address()); + continue; + } + for (const auto &item : PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->mindrt_backend()) { + MS_EXCEPTION_IF_NULL(item.second); + item.second->WaitTaskFinish(); + } + // Replace data in device address when run in CPU device. + if (pre_tensor->device_address() != nullptr) { + // If tensor is dynamic shape, Just replace device address. + if (PyNativeAlgo::Common::ValueHasDynamicShape(pre_tensor)) { + pre_tensor->set_device_address(new_tensor->device_address()); + continue; + } + auto old_device_address = std::dynamic_pointer_cast(pre_tensor->device_address()); + MS_EXCEPTION_IF_NULL(old_device_address); + auto new_device_address = std::dynamic_pointer_cast(new_tensor->device_address()); + MS_EXCEPTION_IF_NULL(new_device_address); + + // CPU host tensor data_c is different from device address if the address is from mem_pool. 
+ if (new_device_address->from_mem_pool()) { + pre_tensor->set_device_address(new_device_address); + continue; + } + + auto old_ptr = old_device_address->GetMutablePtr(); + MS_EXCEPTION_IF_NULL(old_ptr); + auto new_ptr = new_device_address->GetPtr(); + MS_EXCEPTION_IF_NULL(new_ptr); + MS_EXCEPTION_IF_CHECK_FAIL(old_device_address->GetSize() == new_device_address->GetSize(), "Size not equal"); + if (old_device_address->GetSize() < SECUREC_MEM_MAX_LEN) { + auto ret_code = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize()); + MS_EXCEPTION_IF_CHECK_FAIL(ret_code == EOK, "Memory copy failed, ret code: " + std::to_string(ret_code)); + } else { + auto ret_code = std::memcpy(old_ptr, new_ptr, old_device_address->GetSize()); + MS_EXCEPTION_IF_CHECK_FAIL(ret_code == old_ptr, "Memory copy failed"); + } + } else { + pre_tensor->set_device_address(device_address); + pre_tensor->data_sync(); + pre_tensor->set_device_address(nullptr); + pre_tensor->set_sync_status(kNeedSyncHostToDevice); + } + } +} + +void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const { + if (use_dynamic_shape_process_) { + return; + } + // Get all tensors id of forward op + mindspore::HashSet forward_op_tensor_id; + const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id(); + for (const auto &record : op_info_with_tensor_id) { + (void)std::for_each( + record.second.begin(), record.second.end(), + [&forward_op_tensor_id](const std::string &tensor_id) { (void)forward_op_tensor_id.emplace(tensor_id); }); + } + // Get all tensors obj in value node of bprop graph + MS_EXCEPTION_IF_NULL(resource); + const auto &bprop_graph = resource->func_graph(); + MS_EXCEPTION_IF_NULL(bprop_graph); + const auto &value_node_list = bprop_graph->value_nodes(); + std::vector tensors_in_bprop_graph; + for (const auto &elem : value_node_list) { + auto value_node = elem.first->cast(); + MS_EXCEPTION_IF_NULL(value_node); + TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph); + } + + // Save tensor in value node of bprop graph + for (const auto &tensor : tensors_in_bprop_graph) { + MS_EXCEPTION_IF_NULL(tensor); + if (forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end() || tensor->device_address() == nullptr) { + continue; + } + tensor->set_is_forward_output(true); + top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor); + MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id() + << " device address: " << tensor->device_address() << " shape and dtype " + << tensor->GetShapeAndDataTypeInfo(); + } } AnfNodePtr GradExecutor::GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const { @@ -1241,6 +1568,7 @@ CNodePtr GradExecutor::ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_ if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) { top_cell()->RecordCellBackwardHookOp(GetCurCellOrder(), cnode); } + MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is " << cnode->DebugString(); return cnode; @@ -1260,5 +1588,235 @@ void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const { MS_EXCEPTION_IF_NULL(graph_executor); graph_executor->SetJitConfig(jit_config_dict); } + +void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, + bool is_ms_function_node, + const std::string &graph_phase) const { + MS_EXCEPTION_IF_NULL(cnode); + auto node_info = std::make_shared(); + if (!is_ms_function_node) { + 
node_info->prim = GetCNodePrimitive(cnode); + for (size_t i = 1; i < cnode->inputs().size(); i++) { + const auto &input_node = cnode->input(i); + MS_EXCEPTION_IF_NULL(input_node); + + if (input_node->isa()) { + node_info->input_values[i] = GetValueNode(input_node); + } else if (input_node->isa()) { + const auto &node_abs = input_node->abstract(); + auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash()); + node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs); + } else { + if (!input_node->isa()) { + MS_LOG(EXCEPTION) << "input_node:" << input_node->fullname_with_scope() + << " is none of value node, cnode and parameter."; + } + const auto ¶m = input_node->cast(); + MS_EXCEPTION_IF_NULL(param); + node_info->input_param_infos[i] = param->param_info(); + } + } + node_info->output_abs = cnode->abstract(); + } else { + node_info->is_graph_node = true; + node_info->graph_phase = graph_phase; + } + top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx); + const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order()); + (void)cell_id_with_dynamic_detect_nodes_[cell_id].emplace_back(node_info); +} + +bool IsAbsDifferent(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) { + if (old_abs == new_abs) { + return false; + } + + if (old_abs == nullptr || new_abs == nullptr) { + MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs"; + return true; + } + + if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) || + !common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) { + MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs, old abs:" << old_abs->ToString() + << " new abs:" << new_abs->ToString(); + return true; + } + return false; +} + +bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) { + if (v1 == v2) { + return true; + } + if (v1 == nullptr || v2 == nullptr) { + return false; + } + + if (v1->isa() && v2->isa()) { + return v1->cast()->ValueEqual(*(v2->cast())); + } + return *v1 == *v2; +} + +bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) { + if (p1 == p2) { + return true; + } + if (p1 == nullptr || p2 == nullptr) { + return false; + } + + return p1->key() == p2->key(); +} + +bool GradExecutor::IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, + const std::vector &new_anf_inputs) const { + MS_EXCEPTION_IF_NULL(old_node_info); + + auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() + + old_node_info->input_param_infos.size(); + if (old_input_size != new_anf_inputs.size() - 1) { + MS_LOG(DEBUG) << "graph is dynamic, old input size:" << old_input_size + << " new input_infos:" << (new_anf_inputs.size() - 1); + return true; + } + + for (size_t i = 1; i < new_anf_inputs.size(); i++) { + const auto &new_anf_input = new_anf_inputs[i]; + MS_EXCEPTION_IF_NULL(new_anf_input); + if (new_anf_input->isa()) { + const auto &value_iter = old_node_info->input_values.find(i); + if (value_iter == old_node_info->input_values.end()) { + MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value."; + return true; + } + + if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) { + MS_LOG(DEBUG) << "The " << i << "th input, value is different."; + return true; + } + } else if (new_anf_input->isa()) { + // Compare cnode abstract. 
+ const auto &node_iter = old_node_info->input_cnode_info.find(i); + if (node_iter == old_node_info->input_cnode_info.end()) { + MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode."; + return true; + } + + size_t old_op_index = 0; + AbstractBasePtr old_abs = nullptr; + std::tie(old_op_index, old_abs) = node_iter->second; + if (IsAbsDifferent(old_abs, new_anf_input->abstract())) { + MS_LOG(DEBUG) << "The " << i << "th input, abs is different."; + return true; + } + + // Compare cnode edge. + if (old_op_index != top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash())) { + MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index:" << old_op_index + << " new op_index:" << top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash()); + return true; + } + } else { + // Compare parameter. + if (!new_anf_input->isa()) { + MS_LOG(EXCEPTION) << "new_anf_input:" << new_anf_input->fullname_with_scope() + << " is none of value node, cnode and parameter."; + } + + const auto &node_iter = old_node_info->input_param_infos.find(i); + if (node_iter == old_node_info->input_param_infos.end()) { + MS_LOG(DEBUG) << "The " << i + << "th input is different, cur input is a parameter, old input is not a parameter."; + return true; + } + + const auto ¶m = new_anf_input->cast(); + MS_EXCEPTION_IF_NULL(param); + if (!IsParamInfoEqual(node_iter->second, param->param_info())) { + MS_LOG(DEBUG) << "The " << i << "th input, param info is different."; + return true; + } + } + } + + return false; +} + +bool GradExecutor::IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, + const CNodePtr &new_cnode, bool is_ms_function_node, + const std::string &graph_phase) const { + MS_EXCEPTION_IF_NULL(new_cnode); + MS_EXCEPTION_IF_NULL(old_node_info); + + // 1.Detect ms_function phase + if (is_ms_function_node != old_node_info->is_graph_node || + (is_ms_function_node && graph_phase != old_node_info->graph_phase)) { + MS_LOG(DEBUG) << "graph is dynamic, old is_graph_node:" << old_node_info->is_graph_node + << " new is_graph_node:" << is_ms_function_node << " old graph_phase" << old_node_info->graph_phase + << " new graph_phase:" << graph_phase; + return true; + } + + // 2.Detect cnode prim + auto new_prim = GetCNodePrimitive(new_cnode); + if (!common::IsEqual(new_prim, old_node_info->prim)) { + MS_LOG(DEBUG) << "graph is dynamic, old prim:" << (old_node_info->prim == nullptr ? 0 : old_node_info->prim->name()) + << " new prim:" << (new_prim == nullptr ? 0 : new_prim->name()); + return true; + } + + // 3.Detect output abs + if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) { + MS_LOG(DEBUG) << "graph is dynamic, output_abs is different"; + return true; + } + + // 4.Detect inputs + return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs()); +} + +bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node, + const std::string &graph_phase) const { + MS_EXCEPTION_IF_NULL(cnode); + if (!is_cell_id_in_dynamic_detect_nodes_map_) { + SaveDynamicDetectNodeInfoInFirstTime(cnode, node_idx, is_ms_function_node, graph_phase); + // The net is regarded as a static net by default in the first time. 
+ return false; + } + + const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order()); + const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[cell_id]; + if (node_idx >= dynamic_nodes.size()) { + MS_LOG(DEBUG) << "old dynamic_nodes size:" << dynamic_nodes.size() << " cur node_idx is:" << node_idx + << ", graph is dynamic."; + return true; + } + + if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase)) { + MS_LOG(DEBUG) << "graph is dynamic, node_idx:" << node_idx + << " is different, cnode:" << cnode->fullname_with_scope(); + return true; + } + top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx); + + return false; +} + +void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node, + const std::string &graph_phase) const { + if (!top_cell()->is_run_cell() || use_dynamic_shape_process_) { + return; + } + + use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase); + if (use_dynamic_shape_process_) { + MS_LOG(DEBUG) << "cnode:" << cnode->fullname_with_scope() << ",node_idx:" << node_idx + << ",is_ms_function_node:" << is_ms_function_node << ",graph_phase:" << graph_phase + << ",use_dynamic_shape_process_:" << use_dynamic_shape_process_; + cell_id_with_dynamic_detect_nodes_.clear(); + } +} } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.h b/mindspore/ccsrc/pipeline/pynative/grad/grad.h index f5b2af1dfaf..5dd56ade1cd 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/grad.h +++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.h @@ -36,6 +36,17 @@ class ForwardExecutor; using ForwardExecutorPtr = std::shared_ptr; using ForwardExecutorWeakPtr = std::weak_ptr; +struct DynamicDetectNodeInfo { + PrimitivePtr prim{nullptr}; + AbstractBasePtr output_abs{nullptr}; + bool is_graph_node{false}; + std::string graph_phase; + mindspore::HashMap> input_cnode_info; + mindspore::HashMap input_values; + mindspore::HashMap input_param_infos; +}; +using DynamicDetectNodeInfoPtr = std::shared_ptr; + class GradExecutor { public: GradExecutor() = default; @@ -43,8 +54,7 @@ class GradExecutor { explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr) : forward_executor_(ForwardExecutorWeakPtr(forward_executor)), ms_function_(std::make_shared()), - async_executor_(std::make_unique()), - enable_async_(std::getenv("ENABLE_ASYNC")) {} + async_executor_(std::make_shared()) {} std::function InitGraph = [this](auto &&PH1, auto &&PH2) { NewGraphInner(std::forward(PH1), std::forward(PH2)); @@ -69,6 +79,10 @@ class GradExecutor { MS_EXCEPTION_IF_NULL(ms_function_); return ms_function_; } + inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) { + use_dynamic_shape_process_ = use_dynamic_shape_process; + } + inline bool need_renormalize() const { return need_renormalize_; } inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); } inline bool grad_flag() const { return grad_flag_; } @@ -77,12 +91,16 @@ class GradExecutor { inline bool eliminate_forward() const { return eliminate_forward_; } inline void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; } inline size_t custom_bprop_cell_count() const { return custom_bprop_cell_count_; } + inline bool use_dynamic_shape_process() const { return use_dynamic_shape_process_; } + inline std::shared_ptr async_executor() const { return 
async_executor_; } void SetHookChanged(const py::object &cell) const; void GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights, const py::object &grad_position, const py::args &args); py::object RunGradGraph(); CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const; - py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args); + py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id, + const py::args &args); + TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id); void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const; void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const; void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args); @@ -90,25 +108,34 @@ class GradExecutor { AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const; void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info); AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const; + void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const; + void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor, + const std::vector &pre_tensors) const; void ClearRes(); void WorkerJoin() { async_executor_->WorkerJoin(); } + void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false, + const std::string &graph_phase = "") const; + private: ForwardExecutorPtr forward() const; inline FuncGraphPtr curr_g() const { return top_cell()->fg(); } inline void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); } - inline bool CanGetTopCell(const string &already_run_cell_id) { - return already_run_cell_id.find(top_cell()->already_run_cell_id()) != std::string::npos; - } std::string GetCurCellOrder() const; void SetGradOrder(const std::string &cell_id); void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode) const; void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const; + void GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, + const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const; + void AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, + const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const; + void AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const; AnfNodePtr GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const; void SetBpropGraphJitLevel(const py::object &obj) const; void ClearGlobalRes(); void ClearGradRes(); + std::string GetAlreadyRunCellId(const std::string &cell_id) const; // Higher derivative inline bool IsNestedGrad() const { return grad_order_ > 1; } @@ -121,6 +148,7 @@ class GradExecutor { inline bool is_high_order_top_cell() const { return !input_args_info_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_; } + void SwitchTopCell(); void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const std::vector &forward_args, std::vector *inputs, ValuePtrList *weights_args); @@ -132,15 +160,20 @@ class GradExecutor { void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info, bool is_bprop_top) const; void 
InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info); void MakeNewTopGraph(const InputArgsInfoPtr &input_args_info); + void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const; + // Manage resource when run grad process. bool IsBpropGraph(const std::string &cell_id) const; void NewGraphInner(const py::object &obj, const py::args &args); + InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args); void NewGraphImpl(const InputArgsInfoPtr &input_args_info); void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info); void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const; void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out, const InputArgsInfoPtr &input_args_info); void DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id); + void CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info); + void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell); void GetGradGraph(const ad::GradAttr &grad_attr, const std::vector &w_args, const std::vector &p_args); FuncGraphPtr GetBpropGraph(const ad::GradAttr &grad_attr, const vector &w_args, @@ -151,22 +184,38 @@ class GradExecutor { const abstract::AbstractBasePtr ¶m_tensor_abs, const std::string &input_shape); void UpdateParamAbsByArgs(const std::vector &input_args, const FuncGraphPtr &bprop_graph, bool has_sens); std::vector GetGradPositionArgs(const py::object &grad_position, bool get_by_position) const; + void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const; // Manage resource for construct forward graph. AnfNodePtr GetOutputNodeAsInput(const std::string &obj_id) const; AnfNodePtr GetValueSequenceInput(const ValuePtr &v, const std::string &obj_id) const; AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id, const std::pair> &out) const; + void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node, + const std::string &graph_phase) const; + bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node, + const std::string &graph_phase) const; + bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, + const std::vector &new_anf_inputs) const; + bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode, + bool is_ms_function_node, const std::string &graph_phase) const; bool grad_flag_{false}; bool grad_is_running_{false}; bool need_renormalize_{false}; bool eliminate_forward_{true}; + mutable bool use_dynamic_shape_process_{false}; + mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false}; int custom_bprop_cell_count_{0}; + + // Used in sub thread size_t cell_order_{0}; + std::string cur_cell_id_{""}; + // If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ... 
size_t grad_order_{0}; - + std::string grad_operation_; TopCellInfoPtr top_cell_{nullptr}; + TopCellInfoPtr pre_top_cell_{nullptr}; InputArgsInfoPtr top_input_args_info_{nullptr}; // Records every cell info for share, regardless of whether need construct grad graph std::stack input_args_info_stack_; @@ -175,11 +224,13 @@ class GradExecutor { std::vector bprop_cell_list_; // For high grad order std::stack high_order_stack_; + std::vector top_cell_list_; + // Record all top cell which has been ran + mindspore::HashMap already_run_top_cell_; ForwardExecutorWeakPtr forward_executor_; MsFunctionPtr ms_function_; - std::unique_ptr async_executor_; - std::map backends_; - bool enable_async_ = false; + std::shared_ptr async_executor_; + mutable mindspore::HashMap> cell_id_with_dynamic_detect_nodes_; }; } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc index 7b91e488ed9..0b879759892 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc @@ -19,6 +19,8 @@ #include "include/common/utils/anfalgo.h" #include "include/common/utils/parallel_context.h" #include "ir/func_graph_cloner.h" +#include "runtime/pynative/async/async_queue.h" +#include "pipeline/pynative/grad/bprop_task.h" namespace mindspore { namespace pynative { @@ -151,6 +153,35 @@ void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, co RunReplace(added_make_tuple, total_output_tensors, grad_graph); } +void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info, + const ValuePtr &new_forward_value) const { + if (grad_executor->use_dynamic_shape_process()) { + MS_LOG(DEBUG) << "Get dynamic shape process"; + return; + } + MS_EXCEPTION_IF_NULL(new_forward_value); + MS_LOG(DEBUG) << "Ms func graph has already ran before. 
The graph phase is: " << graph_phase_; + MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString(); + std::vector new_tensors; + TensorValueToTensor(new_forward_value, &new_tensors); + if (new_tensors.empty()) { + MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update."; + return; + } + MS_EXCEPTION_IF_NULL(grad_executor); + const auto &top_cell = grad_executor->top_cell(); + const auto &old_tensors = top_cell->op_info_with_ms_func_forward_tensors().at(op_info); + if (old_tensors.size() != new_tensors.size()) { + MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size() + << ", but the size of new tensors is: " << new_tensors.size() + << ", the current op info is: " << op_info; + } + for (size_t i = 0; i < new_tensors.size(); ++i) { + grad_executor->UpdatePreTensorInfo(new_tensors[i], {old_tensors[i]}); + old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost); + } +} + void MsFunction::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes, const GradExecutor *grad_executor) const { MS_EXCEPTION_IF_NULL(op_run_info); @@ -213,6 +244,7 @@ void MsFunction::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const G void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, const FuncGraphPtr &ms_func_graph, CNodePtr *ms_function_cnode) const { + MS_EXCEPTION_IF_NULL(op_run_info); // Get input node info of ms_function std::vector input_nodes{NewValueNode(ms_func_graph)}; MS_EXCEPTION_IF_NULL(grad_executor); @@ -222,6 +254,7 @@ void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info, // Make a CNode which includes ms_function fprop graph and inputs node MS_EXCEPTION_IF_NULL(ms_function_cnode); *ms_function_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes); + MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString(); } @@ -242,6 +275,10 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); auto grad_param = std::make_shared(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph); + { + py::gil_scoped_release gil_release; + grad_executor->async_executor()->Wait(); + } if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) { MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: " << ms_function_cnode->DebugString(); @@ -250,21 +287,55 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run return ms_function_cnode; } +void MsFunction::AsyncKPynativeWithFProp(const GradExecutor *grad_executor, + const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, + const ad::GradParamPtr &grad_param) const { + MS_EXCEPTION_IF_NULL(grad_executor); + + const auto fn = [this, grad_param, auto_grad_cell_ptr]() { + MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); + if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) { + MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode"; + } + }; + auto task = std::make_shared(fn); + grad_executor->async_executor()->Push(task); +} + +void MsFunction::AsyncGradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, + const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph, + const FuncGraphPtr &grad_graph) const { + const auto fn = [this, op_run_info, grad_executor, added_out_v, ms_func_graph, grad_graph]() { + 
this->GradMsFunctionInner(op_run_info, grad_executor, added_out_v, ms_func_graph, grad_graph); + }; + auto task = std::make_shared(fn); + grad_executor->async_executor()->Push(task); +} + void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const { MS_EXCEPTION_IF_NULL(op_run_info); MS_EXCEPTION_IF_NULL(grad_executor); MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString(); - if (!grad_executor->grad_flag()) { - MS_LOG(EXCEPTION) << "The flag of need construct graph is False."; + // Step 1: Update actual output tensors used in grad graph. + MS_EXCEPTION_IF_NULL(op_run_info->out_value); + MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString(); + // The output of ms_function may be used in subsequent PyNative process + grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info); + + // Step 2: Update output tensors of added forward nodes, which are added to return node of ms_function func graph. + if (grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().find(op_run_info->op_info) != + grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().end()) { + UpdateMsFunctionForwardTensors(grad_executor, op_run_info->op_info, added_out_v); } - // Update actual output tensors used in grad graph. + ReplaceNewTensorsInGradGraph(grad_executor->top_cell(), added_out_v, ms_func_graph, grad_graph); // Clone new ms_function func graph and grad graph. auto new_ms_func_graph = BasicClone(ms_func_graph); auto new_grad_graph = BasicClone(grad_graph, true); + auto new_make_tuple = new_ms_func_graph->output()->cast(); MS_EXCEPTION_IF_NULL(new_make_tuple); new_ms_func_graph->set_output(new_make_tuple->input(1)); @@ -273,6 +344,11 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co const auto &ms_function_cnode = MakeAdjointForMsFunction(op_run_info, grad_executor, new_ms_func_graph, new_grad_graph); ms_function_cnode->set_abstract(new_ms_func_graph->output()->abstract()->Broaden()); + + auto grad_exec_ptr = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor(); + MS_EXCEPTION_IF_NULL(grad_exec_ptr); + grad_exec_ptr->CheckGraphDynamic(ms_function_cnode, op_run_info->op_index, true, + op_run_info->base_op_run_info.op_name); } void MsFunction::SetMsFuncGraphParameters(const FuncGraphPtr &ms_func_graph) { @@ -316,6 +392,9 @@ py::object MsFunction::GradMsFunction(const py::object &out, const py::args &arg const auto &op_run_info = GetOpRunInfo(out, args, graph_phase_, &added_out_v); FuncGraphPtr grad_graph = executor->GetGradGraph(graph_phase_); PyNativeAlgo::Common::DumpGraphIR("ms_func_forward_graph.ir", ms_func_graph); + if (!grad_executor->grad_flag()) { + MS_LOG(EXCEPTION) << "The flag of need construct graph is False."; + } GradMsFunctionInner(op_run_info, grad_executor.get(), added_out_v, ms_func_graph, grad_graph); SetMsFuncGraphParameters(ms_func_graph); graph_phase_.clear(); diff --git a/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.h b/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.h index 8612bcc2148..9e14541e21e 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.h +++ b/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.h @@ -42,11 +42,18 @@ class MsFunction { void GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, const ValuePtr 
&added_out_v, const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const; + void AsyncGradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, + const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph, + const FuncGraphPtr &grad_graph) const; + void AsyncKPynativeWithFProp(const GradExecutor *grad_executor, const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, + const ad::GradParamPtr &grad_param) const; // Update device address of value node in grad graph by forward tensors. void RunReplace(const CNodePtr &added_make_tuple, const std::vector &total_output_tensors, const FuncGraphPtr &grad_graph) const; void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const; + void UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info, + const ValuePtr &new_forward_value) const; // Make CNode for ms_function forward graph. void GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes, const GradExecutor *grad_executor) const; diff --git a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc index 5947398f49e..bf7616d7de0 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc @@ -57,6 +57,40 @@ void TopCellInfo::RecordCellBackwardHookOp(const std::string &cell_order, const } } +void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) { + MS_EXCEPTION_IF_NULL(op_run_info); + std::string input_args_info; + // Record input args info (weight or data) + // self.p = Parameter(); + // def construct(x, y) + // if y: + // x = x + x + // else: + // x = x + self.p + // return x + for (size_t i = 0; i < op_run_info->base_op_run_info.input_tensor.size(); i++) { + const auto &t = op_run_info->base_op_run_info.input_tensor[i]; + MS_EXCEPTION_IF_NULL(t); + if (t->is_parameter() && t->param_info() != nullptr && t->param_info()->requires_grad()) { + input_args_info += "w"; + } else { + input_args_info += "d"; + } + } + // Record op name and index + op_run_info->op_info.clear(); + op_run_info->op_info += + op_run_info->base_op_run_info.op_name + "-" + std::to_string(op_index_) + "-" + input_args_info; + const auto &out_abs = op_run_info->base_op_run_info.abstract; + auto shape = out_abs->BuildShape(); + MS_EXCEPTION_IF_NULL(shape); + if (!shape->isa() && !shape->IsDimZero()) { + op_run_info->op_info += "-" + shape->ToString(); + } + op_run_info->op_index = op_index_; + ++op_index_; +} + void TopCellInfo::ClearDeviceMemory() const { MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_; auto ms_context = MsContext::GetInstance(); @@ -154,11 +188,40 @@ void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const } } +void TopCellInfo::Clear() { + MS_LOG(DEBUG) << "Clear top cell info. 
Cell id " << cell_id_; + hook_changed_ = false; + ms_function_flag_ = false; + is_init_kpynative_ = false; + need_compile_graph_ = false; + forward_already_run_ = false; + op_index_ = 0; + resource_ = nullptr; + fg_ = nullptr; + graph_info_map_.clear(); + op_info_with_tensor_id_.clear(); + tensor_id_with_tensor_object_.clear(); + op_info_with_ms_func_forward_tensors_.clear(); + cnode_hash_with_op_index_.clear(); +} + void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node, const std::vector &index) const { auto &graph_info = graph_info_map().at(fg()); MS_EXCEPTION_IF_NULL(graph_info); graph_info->node_map[id] = std::make_pair(node, index); } + +void TopCellInfo::set_opinfo_with_tensor_id(const std::string &op_info, + const std::vector &op_out_tensors) { + if (op_info_with_tensor_id_.find(op_info) != op_info_with_tensor_id_.end()) { + MS_LOG(EXCEPTION) << "Top cell: " << cell_id_ << " records op info with tensor id, but get op info " << op_info + << " in op_info_with_tensor_id map"; + } + // Record the relationship between the forward op and its output tensor id + (void)std::for_each(op_out_tensors.begin(), op_out_tensors.end(), [this, &op_info](const tensor::TensorPtr &tensor) { + (void)op_info_with_tensor_id_[op_info].emplace_back(tensor->id()); + }); +} } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h index 8112637d494..2ef75a8505a 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h +++ b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h @@ -42,6 +42,9 @@ namespace mindspore { namespace pynative { namespace py = pybind11; class GradExecutor; +using OpInfoWithTensorId = mindspore::HashMap>; +using TensorIdWithTensorObject = mindspore::HashMap>; +using OpInfoWithMsFuncForwardTensors = mindspore::HashMap>; using CellIdWithBackwardHookOp = mindspore::HashMap>; struct GraphInfo { @@ -55,9 +58,10 @@ using GraphInfoPtr = std::shared_ptr; class TopCellInfo { public: ~TopCellInfo() = default; - TopCellInfo(size_t grad_order, std::string cellid, std::string already_run_cell_id, pipeline::ResourcePtr r, - FuncGraphPtr fg) + TopCellInfo(size_t grad_order, std::string c_cell_id, std::string cellid, std::string already_run_cell_id, + pipeline::ResourcePtr r, FuncGraphPtr fg) : grad_order_(grad_order), + c_cell_id_(std::move(c_cell_id)), cell_id_(std::move(cellid)), already_run_cell_id_(std::move(already_run_cell_id)), resource_(std::move(r)), @@ -70,11 +74,14 @@ class TopCellInfo { inline void set_sub_cell_hook_changed(const std::string &sub_cell) { (void)sub_cell_hook_changed_.emplace(sub_cell); } inline const CellIdWithBackwardHookOp &cell_backward_hook_op() const { return cell_backward_hook_op_; } void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op); + void GetOpInfo(const FrontendOpRunInfoPtr &op_run_info); inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); } inline bool ms_function_flag() const { return ms_function_flag_; } inline void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; } inline bool forward_already_run() const { return forward_already_run_; } inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; } + inline bool need_compile_graph() const { return need_compile_graph_; } + inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; } 
inline pipeline::ResourcePtr resource() const { return resource_; } inline FuncGraphPtr fg() const { MS_EXCEPTION_IF_NULL(fg_); @@ -82,18 +89,51 @@ class TopCellInfo { } inline void set_fg(const FuncGraphPtr &fg) { fg_ = fg; } inline const std::string &cell_id() const { return cell_id_; } + inline const std::string &c_cell_id() const { return c_cell_id_; } inline const std::string &already_run_cell_id() const { return already_run_cell_id_; } inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; } inline const std::string &input_args_id() const { return input_args_id_; } + const std::string &grad_operation() const { return grad_operation_; } + void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; } inline void CheckSubCellHookChanged() { sub_cell_hook_changed_.clear(); } inline void SetGraphInfoMap(const FuncGraphPtr &fg, const GraphInfoPtr &graph_info) { graph_info_map_[fg] = graph_info; } + inline void set_is_run_cell(bool is_run_cell) { is_run_cell_ = is_run_cell; } + inline bool is_run_cell() { return is_run_cell_; } inline const OrderedMap &graph_info_map() const { return graph_info_map_; } - inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const { return auto_grad_cell_ptr_; } + inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const { + MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_); + return auto_grad_cell_ptr_; + } void set_auto_grad_cell_ptr(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr) { auto_grad_cell_ptr_ = auto_grad_cell_ptr; } + inline const OpInfoWithTensorId &op_info_with_tensor_id() const { return op_info_with_tensor_id_; } + void set_opinfo_with_tensor_id(const std::string &op_info, const std::vector &op_out_tensors); + inline const TensorIdWithTensorObject &tensor_id_with_tensor_object() const { return tensor_id_with_tensor_object_; } + inline void set_tensor_id_with_tensor_object(const std::string &id, const tensor::TensorPtr &tensor) { + (void)tensor_id_with_tensor_object_[id].emplace_back(tensor); + } + inline const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const { + return op_info_with_ms_func_forward_tensors_; + } + inline size_t op_index() const { return op_index_; } + inline void IncreaseOpIndex() { op_index_++; } + + inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) { + cnode_hash_with_op_index_[node_hash] = op_index; + } + inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) { + auto iter = cnode_hash_with_op_index_.find(node_hash); + if (iter == cnode_hash_with_op_index_.end()) { + MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_"; + } + return iter->second; + } + + void Clear(); + void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id); void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr ¶m, bool is_weight = false) const; void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1, @@ -111,7 +151,11 @@ class TopCellInfo { bool ms_function_flag_{false}; bool is_init_kpynative_{false}; bool forward_already_run_{false}; + bool need_compile_graph_{false}; + bool is_run_cell_{false}; + size_t op_index_{0}; size_t grad_order_{0}; + std::string c_cell_id_; std::string cell_id_; std::string already_run_cell_id_; std::string input_args_id_; @@ -126,6 +170,10 @@ class TopCellInfo { // Record backward hook ops for each cell object. // Each cell object has two backward hook ops. 
CellIdWithBackwardHookOp cell_backward_hook_op_; + OpInfoWithTensorId op_info_with_tensor_id_; + TensorIdWithTensorObject tensor_id_with_tensor_object_; + OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_; + mindspore::HashMap cnode_hash_with_op_index_; }; using TopCellInfoPtr = std::shared_ptr; } // namespace pynative diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 52b4b788a82..93b0d27ec0a 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -139,6 +139,7 @@ void PyNativeExecutor::ClearRes() const { void PyNativeExecutor::Init() { MS_LOG(DEBUG) << "Init PyNativeExecutor"; forward_executor_ = std::make_shared(); + forward_executor_->Init(); grad_executor_ = std::make_shared(forward_executor_); forward_executor_->set_grad_executor(grad_executor_); } @@ -161,8 +162,8 @@ bool PyNativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); void PyNativeExecutor::set_grad_flag(bool flag) const { grad_executor()->set_grad_flag(flag); } py::object PyNativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, - const py::args &args) const { - return grad_executor()->CheckAlreadyRun(grad, obj, args); + const py::object &grad_hash_id, const py::args &args) const { + return grad_executor()->CheckAlreadyRun(grad, obj, grad_hash_id, args); } void PyNativeExecutor::NewGraph(const py::object &obj, const py::args &args) const { @@ -187,7 +188,10 @@ void PyNativeExecutor::EndGraph(const py::object &obj, const py::object &out, co forward_executor()->ProcessAfterEndGraph(obj, is_cell); } -py::object PyNativeExecutor::Run() const { return PyNativeExecutorTry(grad_executor()->RunGraph); } +py::object PyNativeExecutor::Run() const { + const auto &ret = PyNativeExecutorTry(grad_executor()->RunGraph); + return ret; +} void PyNativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::object &grad_position, const py::args &args) const { @@ -195,13 +199,22 @@ void PyNativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::obj } py::object PyNativeExecutor::GradMsFunction(const py::object &out, const py::args &args) const { - return grad_executor()->ms_function()->GradMsFunction(out, args); + const auto &ret = grad_executor()->ms_function()->GradMsFunction(out, args); + return ret; } void PyNativeExecutor::SetLazyBuild(bool enable) const { forward_executor()->set_lazy_build(enable); } bool PyNativeExecutor::IsFirstCell() const { return forward_executor()->IsFirstCell(); } +void PyNativeExecutor::SetMsFunctionCompileStatus(bool is_compiling) const { + forward_executor()->set_is_ms_function_compiling(is_compiling); +} + +void PyNativeExecutor::SetDynamicInput(const py::object &cell, const py::args &args) const { + grad_executor()->set_use_dynamic_shape_process(true); +} + void RegPyNativeExecutor(const py::module *m) { (void)py::class_>(*m, "PyNativeExecutor_") .def_static("get_instance", &PyNativeExecutor::GetInstance, "PyNativeExecutor get_instance.") @@ -220,10 +233,13 @@ void RegPyNativeExecutor(const py::module *m) { .def("set_hook_changed", &PyNativeExecutor::SetHookChanged, "set pynative hook changed") .def("set_grad_flag", &PyNativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), "Executor set grad flag.") + .def("set_dynamic_input", &PyNativeExecutor::SetDynamicInput, "set dynamic input") 
.def("set_py_exe_path", &PyNativeExecutor::set_py_exe_path, py::arg("py_exe_path") = py::str(""), "set python executable path.") .def("set_kernel_build_server_dir", &PyNativeExecutor::set_kernel_build_server_dir, py::arg("kernel_build_server_dir") = py::str(""), "set kernel build server directory path.") + .def("set_ms_function_compile_status", &PyNativeExecutor::SetMsFunctionCompileStatus, + "set ms_funciton compile status.") .def("real_run_op", &PyNativeExecutor::RealRunOp, "Run op pynatively.") .def("constant_folding", &PyNativeExecutor::CallConstantFolding, "Call Constant Folding Primitive"); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 54fe51cd00b..b6f66759920 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -65,13 +65,17 @@ class PyNativeExecutor : public std::enable_shared_from_this { void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::object &grad_position, const py::args &args) const; py::object GradMsFunction(const py::object &out, const py::args &args) const; - py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args) const; + void SetDynamicInput(const py::object &cell, const py::args &args) const; + + py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id, + const py::args &args) const; void ClearRes() const; // Sync stream void Sync() const; void SetLazyBuild(bool enable) const; bool IsFirstCell() const; void WorkerJoin() { grad_executor_->WorkerJoin(); } + void SetMsFunctionCompileStatus(bool is_compiling) const; private: PyNativeExecutor() = default; diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc index fb2ff4c6783..55afdb0178a 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc @@ -602,13 +602,16 @@ TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel, } void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info, + bool use_dynamic_shape_process, session::BackendOpRunInfoPtr *op_run_info, GraphInfo *graph_info, const GraphOutputInfo *const graph_output_info) { MS_EXCEPTION_IF_NULL(session_); MS_EXCEPTION_IF_NULL(graph_info); *op_run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info, graph_output_info); session_->GetSingleOpGraphInfo(kernel, tensor_info, graph_info, *op_run_info); + MS_EXCEPTION_IF_NULL(*op_run_info); (*op_run_info)->base_op_run_info.graph_info = *graph_info; + (*op_run_info)->base_op_run_info.use_dynamic_shape_process = use_dynamic_shape_process; } void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map *ref_count) const { diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h index 87000442a3e..c6aba6c1dab 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h @@ -130,8 +130,8 @@ class GraphCompiler { // Get OpRunInfo and GraphInfo for single op compile and run. 
void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info, - session::BackendOpRunInfoPtr *op_run_info, GraphInfo *graph_info, - const GraphOutputInfo *const graph_output_info); + bool use_dynamic_shape_process, session::BackendOpRunInfoPtr *op_run_info, + GraphInfo *graph_info, const GraphOutputInfo *const graph_output_info); // Calculate ref count of PyNative back propagation operators. void CalculateRefCount(const KernelGraphPtr &graph, std::map *ref_count) const; diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index 6c276913c67..6ed75c40f5a 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -296,7 +296,9 @@ class _MindsporeFunctionExecutor: args_list = args if self.obj is not None: args_list = args_list[1:] + _pynative_executor.set_ms_function_compile_status(True) phase = self.compile(args_list, self.fn.__name__) + _pynative_executor.set_ms_function_compile_status(False) if context.get_context("precompile_only"): return None new_inputs = self._generate_run_args(args_list) @@ -428,6 +430,7 @@ class _MindsporeFunctionExecutor: self.input_signature.append(args_list[-1]) Validator.check_dynamic_shape(self.input_signature, args_list) compile_args = tuple(self.input_signature) + _pynative_executor.set_dynamic_input(self.obj, *compile_args) return compile_args def _generate_run_args(self, args_list): @@ -1012,7 +1015,7 @@ class _PyNativeExecutor: """ self._executor.end_graph(obj, output, *args, *(kwargs.values())) - def check_run(self, grad, obj, *args, **kwargs): + def check_run(self, grad, obj, grad_hash_id, *args, **kwargs): """ Whether the forward graph need to construct. @@ -1026,7 +1029,7 @@ class _PyNativeExecutor: Return: bool, specifies whether the forward graph need to construct. """ - return self._executor.check_run(grad, obj, grad_hash_id, *args, *(kwargs.values())) def grad(self, obj, grad, weights, grad_position, *args, **kwargs): """ @@ -1122,6 +1125,30 @@ class _PyNativeExecutor: """ self._executor.set_grad_flag(flag) + def set_ms_function_compile_status(self, status): + """ + Set whether ms_function is compiling. + + Args: + status (bool): The compile status of ms_function. + Return: + None. + """ + self._executor.set_ms_function_compile_status(status) + + def set_dynamic_input(self, obj, *args): + """ + Set dynamic shape tensor of input arguments. + + Args: + obj (Function/Cell): The function or cell instance. + args (tuple): Function or cell dynamic input arguments. + + Return: + None. + """ + self._executor.set_dynamic_input(obj, *args) + def is_first_cell(self): """ The flag of first cell instance. 
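For context on the api.py hunk above: check_run now forwards a grad hash id to CheckAlreadyRun, and the ops/composite/base.py hunk further below derives that hash from (grad_position, weights_id). The snippet below is a minimal sketch of the intended effect, not part of the patch; it assumes the functional mindspore.ops.grad interface, and Net, x, y are made-up names. Asking for gradients with a different grad_position (or different weights) produces a different grad_hash_id, so the previous forward run is presumably no longer treated as "already run" and the forward graph is traced again.

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):  # hypothetical example cell, not from this patch
    def construct(self, x, y):
        return x * y

net = Net()
x = Tensor(np.ones((2, 3), np.float32))
y = Tensor(np.ones((2, 3), np.float32))

# First grad call: check_run(grad, net, grad_hash_id, ...) finds no prior run for
# this (cell, grad_position, weights) key, so the forward graph is traced.
dx = ops.grad(net, grad_position=0)(x, y)

# A different grad_position gives a different grad_hash_id, so the cached forward
# run should not be reused and the forward pass is executed again.
dx_dy = ops.grad(net, grad_position=(0, 1))(x, y)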
diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index ad5a05d9486..4b7d09fa4eb 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -891,6 +891,8 @@ class Cell(Cell_): self._check_construct_args(*inputs) if self._dynamic_shape_inputs: ds.config.set_dynamic_shape(True) + if context._get_mode() == context.PYNATIVE_MODE: + _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs) def get_inputs(self): """ diff --git a/mindspore/python/mindspore/ops/composite/base.py b/mindspore/python/mindspore/ops/composite/base.py index fbd99c14ca3..bd90c46e0b4 100644 --- a/mindspore/python/mindspore/ops/composite/base.py +++ b/mindspore/python/mindspore/ops/composite/base.py @@ -392,14 +392,14 @@ class GradOperation(GradOperation_): new_kwargs = kwargs.copy() new_kwargs.pop('sens') if isinstance(fn, (FunctionType, MethodType)): - if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs): + if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs): _pynative_executor.set_grad_flag(True) _pynative_executor.new_graph(fn, *args, **new_kwargs) output = fn(*args, **new_kwargs) _pynative_executor.end_graph(fn, output, *args, **new_kwargs) else: # Check if fn have run already - if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs): + if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs): fn.set_grad() fn(*args, **new_kwargs) fn.set_grad(False) @@ -465,6 +465,7 @@ class _Grad(GradOperation_): self.pynative_ = False self.grad_position = None self.weights_id = None + self.grad_hash_id = None def __call__(self, fn, weights=None, grad_position=0): weights_id = _get_grad_weights_id(weights) @@ -537,6 +538,7 @@ class _Grad(GradOperation_): self.fn = fn self.grad_position = grad_position self.weights_id = weights_id + self.grad_hash_id = (grad_position, weights_id) return self.grad_fn def _pynative_forward_run(self, fn, grad, args, kwargs): @@ -550,7 +552,7 @@ class _Grad(GradOperation_): else: args = args[:-1] if isinstance(fn, (FunctionType, MethodType)): - if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs): + if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs): _pynative_executor.set_grad_flag(True) _pynative_executor.new_graph(fn, *args, **new_kwargs) outputs = fn(*args, **new_kwargs) @@ -558,7 +560,7 @@ class _Grad(GradOperation_): return outputs else: # Check if fn has run already. - if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs): + if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs): fn.set_grad() outputs = fn(*args, **new_kwargs) fn.set_grad(False)
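To close the loop on the Python-side entry points: with the cell.py change above, registering dynamic-shape inputs on a cell in PYNATIVE_MODE also notifies the PyNative executor through set_dynamic_input, which is what enables the use_dynamic_shape_process path added on the C++ side of this patch. The sketch below is a hedged usage illustration under those assumptions; the network, shapes, and variable names are made up and are not part of the patch.

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

class ReluNet(nn.Cell):  # illustrative cell, not from this patch
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()

    def construct(self, x):
        return self.relu(x)

net = ReluNet()
# Placeholder tensor with a dynamic first axis: shape=[None, 3], dtype only, no data.
dyn_x = Tensor(shape=[None, 3], dtype=ms.float32)
# In PYNATIVE_MODE this now also calls _pynative_executor.set_dynamic_input(net, dyn_x).
net.set_inputs(dyn_x)

# Later calls with different batch sizes are expected to go through the
# dynamic-shape process instead of re-detecting dynamism for the top cell.
out_a = net(Tensor(np.ones((2, 3), np.float32)))
out_b = net(Tensor(np.ones((8, 3), np.float32)))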