forked from mindspore-Ecosystem/mindspore
[Pynative]fix pre run top cell bug
This commit is contained in:
parent
e70b601c47
commit
27d391ae66
|
@ -872,7 +872,7 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
|
||||||
<< ", input args info ptr " << top_input_args_info_.get();
|
<< ", input args info ptr " << top_input_args_info_.get();
|
||||||
|
|
||||||
SetSensValue(grad, top_input_args_info_, args);
|
SetSensValue(grad, top_input_args_info_, args);
|
||||||
GetPreRunTopCell(top_input_args_info_->cell_id);
|
GetPreRunTopCell(grad, obj, args);
|
||||||
|
|
||||||
// For async, top can not be change when run SetForwardLastNodeInfo; Change top cell after sync
|
// For async, top can not be change when run SetForwardLastNodeInfo; Change top cell after sync
|
||||||
auto already_run_top_cell = already_run_top_cell_.at(top_cell()->already_run_cell_id());
|
auto already_run_top_cell = already_run_top_cell_.at(top_cell()->already_run_cell_id());
|
||||||
|
@ -909,7 +909,7 @@ std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) const
|
||||||
return already_run_cell_id;
|
return already_run_cell_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GradExecutor::GetPreRunTopCell(const std::string &cell_id) {
|
void GradExecutor::GetPreRunTopCell(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args) {
|
||||||
// @wrap_op
|
// @wrap_op
|
||||||
// class A():
|
// class A():
|
||||||
// def construct(self):
|
// def construct(self):
|
||||||
|
@ -928,6 +928,25 @@ void GradExecutor::GetPreRunTopCell(const std::string &cell_id) {
|
||||||
if (top_cell_ != nullptr) {
|
if (top_cell_ != nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(grad);
|
||||||
|
py::args args_without_sens;
|
||||||
|
if (grad->sens_param_) {
|
||||||
|
// If there is a sense, it will not hit the already run cache
|
||||||
|
auto tuple_args_size = args.size() - 1;
|
||||||
|
if (tuple_args_size < 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "args.size:" << args.size() << " tuple_args_size:" << tuple_args_size << " is invalid.";
|
||||||
|
}
|
||||||
|
py::tuple tuple_args(tuple_args_size);
|
||||||
|
for (size_t i = 0; i < tuple_args_size; ++i) {
|
||||||
|
tuple_args[i] = args[i];
|
||||||
|
}
|
||||||
|
args_without_sens = tuple_args;
|
||||||
|
} else {
|
||||||
|
args_without_sens = args;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &cell_id = GetCellId(obj, args_without_sens, nullptr);
|
||||||
MS_LOG(DEBUG) << "Get pre run top cell cell id:" << cell_id;
|
MS_LOG(DEBUG) << "Get pre run top cell cell id:" << cell_id;
|
||||||
const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
|
const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
|
||||||
top_cell_ = GetTopCell(check_already_run_cell_id);
|
top_cell_ = GetTopCell(check_already_run_cell_id);
|
||||||
|
|
|
@ -101,7 +101,7 @@ class GradExecutor {
|
||||||
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
|
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
|
||||||
const py::args &args);
|
const py::args &args);
|
||||||
TopCellInfoPtr GetAlreadyRunTopCell(const std::string &already_run_cell_id) const;
|
TopCellInfoPtr GetAlreadyRunTopCell(const std::string &already_run_cell_id) const;
|
||||||
void GetPreRunTopCell(const std::string &cell_id);
|
void GetPreRunTopCell(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args);
|
||||||
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
|
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
|
||||||
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
|
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
|
||||||
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
|
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
|
||||||
|
|
Loading…
Reference in New Issue