fix_empty_output_error

luochao 2023-02-08 15:28:14 +08:00
parent 654c9dba8d
commit f6ed01dda0
3 changed files with 18 additions and 20 deletions


@@ -179,9 +179,6 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
   if (op_run_info->output_get_by_infer_value) {
     return;
   }
-  // Set the forward-output flag used when releasing memory.
-  // Because the tensor address may change, the flag should be set in the main thread to ensure consistency.
-  PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
   // 4. Do op grad and record op info
   // If ms_function is compiled, op info will not be found in the second training step

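The tagging removed here is not lost: per the later hunks, SaveForwardTensorInfoInBpropGraph no longer returns early for dynamic shape and keeps its tensor->set_is_forward_output(true) call, so the flag is now applied once when forward tensor info is saved into the bprop graph instead of on every op launch. A minimal sketch of what such a flag-setting helper amounts to; only set_is_forward_output(true) is taken from the diff, and the TensorsInValue traversal helper is hypothetical:

// Illustrative only: mark every tensor held by a forward output value so its
// device memory is kept for the bprop graph rather than released early.
void MarkForwardOutputs(const ValuePtr &out_value) {
  // TensorsInValue is a hypothetical helper that flattens nested tuples/lists
  // in the output value into the tensors they contain.
  for (const tensor::TensorPtr &tensor : TensorsInValue(out_value)) {
    tensor->set_is_forward_output(true);
  }
}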

@@ -477,6 +477,13 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
   } else if (input_args_info->is_high_order_top_cell) {
     MS_LOG(DEBUG) << "Nested grad graph existed in construct";
     MakeNewTopGraph(input_args_info);
+    // Wait for the bprop-construction tasks of the outer top cell to finish. If the main thread runs ahead,
+    // executes the grad net and clears the async_executor queue before those tasks complete, the missing
+    // CNodes cause a "node not found" error.
+    {
+      py::gil_scoped_release gil_release;
+      async_executor_->Wait();
+    }
   }
 }
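The added block releases the Python GIL before blocking on the async executor, presumably so that the worker building the bprop graph (and any other Python thread) is not starved or deadlocked while the main thread waits; the same wrapping is applied to async_executor_->Clear() in the hunk at line 889 below. A minimal sketch of the pattern, assuming a project-specific AsyncQueue type with a blocking Wait(); only py::gil_scoped_release and the Wait() call are taken from the diff:

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Sketch: block until the queued bprop-construction tasks have drained,
// without holding the GIL while we wait. The GIL is re-acquired automatically
// when gil_release goes out of scope.
void WaitForPendingBpropTasks(AsyncQueue *async_executor) {  // AsyncQueue is an assumed type
  py::gil_scoped_release gil_release;
  async_executor->Wait();
}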
@@ -882,8 +889,11 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
   auto already_run_top_cell = already_run_top_cell_.at(top_cell()->already_run_cell_id());
   if (!already_run_top_cell->need_compile_graph()) {
     MS_LOG(DEBUG) << "No need compile graph";
-    // If no compilation is needed, we can finish constructing the remaining bprop
-    async_executor_->Clear();
+    // If no compilation is needed, the queued bprop-construction tasks can be cleared.
+    {
+      py::gil_scoped_release gil_release;
+      async_executor_->Clear();
+    }
     set_top_cell(already_run_top_cell);
     top_cell()->UpdateTopCellInfo(false, false, false);
     return;
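The Clear() call gets the same py::gil_scoped_release treatment as the Wait() above (see the sketch after the previous hunk): if clearing the pending bprop-construction tasks has to synchronize with the worker thread, doing so while holding the GIL could block in the same way.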
@@ -1675,10 +1685,6 @@ void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
 void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const {
   MS_EXCEPTION_IF_NULL(op_run_info);
-  if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
-    MS_LOG(DEBUG) << "Get dynamic shape process";
-    return;
-  }
   top_cell()->GetOpInfo(op_run_info);
   MS_LOG(DEBUG) << "Current op info: " << op_run_info->op_info;
@@ -1790,9 +1796,6 @@ void GradExecutor::UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
 }
 void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
-  if (top_cell()->use_dynamic_shape_process()) {
-    return;
-  }
   // Get all tensor ids of forward ops
   mindspore::HashSet<std::string> forward_op_tensor_id;
   const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
@@ -1820,10 +1823,12 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
       continue;
     }
     tensor->set_is_forward_output(true);
-    top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
-    MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
-                  << " device address: " << tensor->device_address() << " shape and dtype "
-                  << tensor->GetShapeAndDataTypeInfo();
+    if (!top_cell()->use_dynamic_shape_process()) {
+      top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
+      MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
+                    << " device address: " << tensor->device_address() << " shape and dtype "
+                    << tensor->GetShapeAndDataTypeInfo();
+    }
   }
 }
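This is where the flag-setting removed from the forward path now lands: every forward tensor referenced by the bprop graph is tagged unconditionally, while the tensor-id-to-object map and its debug log remain limited to the static-shape path, matching the old behavior. A condensed sketch of the resulting per-tensor logic; the surrounding lookup loop is elided and all names are taken from the diff:

// Inside the loop over forward tensors referenced by the bprop graph:
tensor->set_is_forward_output(true);             // always applied, even for dynamic shape
if (!top_cell()->use_dynamic_shape_process()) {  // static shape only: remember the object
  top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
}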


@@ -55,13 +55,11 @@ FrontendOpRunInfoPtr GetOpRunInfo(const py::object &out, const py::args &args, c
   // Forward output of op in ms_function graph
   *added_out_v = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[1]);
   MS_LOG(DEBUG) << "Added output value is: " << (*added_out_v)->ToString();
-  PyNativeAlgo::Common::SetForwardOutputFlag(*added_out_v);
   auto op_run_info = std::make_shared<FrontendOpRunInfo>();
   PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args);
   op_run_info->base_op_run_info.op_name = graph_phase;
   // Output of ms_function
   op_run_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[0]);
-  PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
   op_run_info->base_op_run_info.abstract =
     PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_run_info->out_value->ToAbstract());
   op_run_info->grad_flag = true;
@@ -237,8 +235,6 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
   // The forward node in the ms_function graph is created during compilation as a placeholder
   // (mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc). After running ms_function, it needs to be
   // updated to the real value.
-  (void)std::for_each(total_output_tensors.begin(), total_output_tensors.end(),
-                      [](const tensor::TensorPtr &tensor) { tensor->set_is_forward_output(true); });
   RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
   grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);