diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
index 68ca0116f08..4221d26cacb 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
@@ -872,7 +872,7 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
                 << ", input args info ptr " << top_input_args_info_.get();
   SetSensValue(grad, top_input_args_info_, args);
 
-  GetPreRunTopCell(top_input_args_info_->cell_id);
+  GetPreRunTopCell(grad, obj, args);
 
   // For async, top can not be change when run SetForwardLastNodeInfo; Change top cell after sync
   auto already_run_top_cell = already_run_top_cell_.at(top_cell()->already_run_cell_id());
@@ -909,7 +909,7 @@ std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) const
   return already_run_cell_id;
 }
 
-void GradExecutor::GetPreRunTopCell(const std::string &cell_id) {
+void GradExecutor::GetPreRunTopCell(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args) {
   // @wrap_op
   // class A():
   //   def construct(self):
@@ -928,6 +928,25 @@
   if (top_cell_ != nullptr) {
     return;
   }
+
+  MS_EXCEPTION_IF_NULL(grad);
+  py::args args_without_sens;
+  if (grad->sens_param_) {
+    // If there is a sens param, it will not hit the already run cache
+    if (args.size() < 1) {
+      MS_LOG(EXCEPTION) << "args.size:" << args.size() << " is invalid when sens_param is set.";
+    }
+    auto tuple_args_size = args.size() - 1;
+    py::tuple tuple_args(tuple_args_size);
+    for (size_t i = 0; i < tuple_args_size; ++i) {
+      tuple_args[i] = args[i];
+    }
+    args_without_sens = tuple_args;
+  } else {
+    args_without_sens = args;
+  }
+
+  const auto &cell_id = GetCellId(obj, args_without_sens, nullptr);
   MS_LOG(DEBUG) << "Get pre run top cell cell id:" << cell_id;
   const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
   top_cell_ = GetTopCell(check_already_run_cell_id);
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.h b/mindspore/ccsrc/pipeline/pynative/grad/grad.h
index 57b2d73a4d8..6ad2785bb1d 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/grad.h
+++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.h
@@ -101,7 +101,7 @@ class GradExecutor {
   py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
                              const py::args &args);
   TopCellInfoPtr GetAlreadyRunTopCell(const std::string &already_run_cell_id) const;
-  void GetPreRunTopCell(const std::string &cell_id);
+  void GetPreRunTopCell(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args);
   void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
   void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
   AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
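
Note (reviewer sketch, not part of the patch): the new GetPreRunTopCell overload rebuilds the cell id from the arguments without the trailing sens value, because when sens_param is set the caller appends the sens value as the last positional argument and a cell id built from the full tuple would never match the already-run cache. Below is a minimal, self-contained pybind11 sketch of that stripping step; StripSensArg, the file name, and the example values are illustrative assumptions, not MindSpore APIs.

// sens_strip_sketch.cc - illustrative only; mirrors the args_without_sens logic in the patch above.
#include <pybind11/embed.h>
#include <iostream>
#include <stdexcept>

namespace py = pybind11;

// Drop the trailing sens argument when sens_param is set; otherwise keep the args unchanged.
py::tuple StripSensArg(const py::tuple &args, bool sens_param) {
  if (!sens_param) {
    return args;
  }
  if (args.size() < 1) {
    throw std::runtime_error("sens_param is set but no positional argument was given");
  }
  py::tuple stripped(args.size() - 1);
  for (size_t i = 0; i + 1 < args.size(); ++i) {
    stripped[i] = args[i];  // copy every argument except the last one (the sens value)
  }
  return stripped;
}

int main() {
  py::scoped_interpreter guard{};  // tuple operations need a live Python interpreter
  py::tuple args = py::make_tuple(1, 2, py::str("sens"));
  py::tuple without_sens = StripSensArg(args, /*sens_param=*/true);
  std::cout << "kept " << without_sens.size() << " of " << args.size() << " args\n";  // kept 2 of 3 args
  return 0;
}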