[Pynative]fix pre run top cell bug

This commit is contained in:
wangchangheng 2023-02-13 10:20:09 +08:00
parent e70b601c47
commit 27d391ae66
2 changed files with 22 additions and 3 deletions

View File

@ -872,7 +872,7 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
<< ", input args info ptr " << top_input_args_info_.get(); << ", input args info ptr " << top_input_args_info_.get();
SetSensValue(grad, top_input_args_info_, args); SetSensValue(grad, top_input_args_info_, args);
GetPreRunTopCell(top_input_args_info_->cell_id); GetPreRunTopCell(grad, obj, args);
// For async, top can not be change when run SetForwardLastNodeInfo; Change top cell after sync // For async, top can not be change when run SetForwardLastNodeInfo; Change top cell after sync
auto already_run_top_cell = already_run_top_cell_.at(top_cell()->already_run_cell_id()); auto already_run_top_cell = already_run_top_cell_.at(top_cell()->already_run_cell_id());
@ -909,7 +909,7 @@ std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) const
return already_run_cell_id; return already_run_cell_id;
} }
void GradExecutor::GetPreRunTopCell(const std::string &cell_id) { void GradExecutor::GetPreRunTopCell(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args) {
// @wrap_op // @wrap_op
// class A(): // class A():
// def construct(self): // def construct(self):
@ -928,6 +928,25 @@ void GradExecutor::GetPreRunTopCell(const std::string &cell_id) {
if (top_cell_ != nullptr) { if (top_cell_ != nullptr) {
return; return;
} }
MS_EXCEPTION_IF_NULL(grad);
py::args args_without_sens;
if (grad->sens_param_) {
// If there is a sense, it will not hit the already run cache
auto tuple_args_size = args.size() - 1;
if (tuple_args_size < 0) {
MS_LOG(EXCEPTION) << "args.size:" << args.size() << " tuple_args_size:" << tuple_args_size << " is invalid.";
}
py::tuple tuple_args(tuple_args_size);
for (size_t i = 0; i < tuple_args_size; ++i) {
tuple_args[i] = args[i];
}
args_without_sens = tuple_args;
} else {
args_without_sens = args;
}
const auto &cell_id = GetCellId(obj, args_without_sens, nullptr);
MS_LOG(DEBUG) << "Get pre run top cell cell id:" << cell_id; MS_LOG(DEBUG) << "Get pre run top cell cell id:" << cell_id;
const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id); const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
top_cell_ = GetTopCell(check_already_run_cell_id); top_cell_ = GetTopCell(check_already_run_cell_id);

View File

@ -101,7 +101,7 @@ class GradExecutor {
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id, py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
const py::args &args); const py::args &args);
TopCellInfoPtr GetAlreadyRunTopCell(const std::string &already_run_cell_id) const; TopCellInfoPtr GetAlreadyRunTopCell(const std::string &already_run_cell_id) const;
void GetPreRunTopCell(const std::string &cell_id); void GetPreRunTopCell(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args);
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const; void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const; void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const; AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;