!48741 fix_high_order_error
Merge pull request !48741 from luochao60/fix_empty_input_error
commit 5d61d3ff3d
@@ -180,9 +180,6 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
   if (op_run_info->output_get_by_infer_value) {
     return;
   }
-  // Set forward output flag for releasing memory,
-  // because the tensor address may change; it should be set in the main thread to ensure consistency.
-  PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
 
   // 4. Do op grad and record op info
   // If ms function is compiled, op info will not be found in the second training step
@@ -500,6 +500,13 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
   } else if (input_args_info->is_high_order_top_cell) {
     MS_LOG(DEBUG) << "Nested grad graph existed in construct";
     MakeNewTopGraph(input_args_info);
+    // We need to wait for the bprop construction task of the outer top cell to finish. If the main thread runs
+    // quickly, it may execute the grad net and clear the async_executor queue before that task is done, which
+    // causes a "cnode not found" error.
+    {
+      py::gil_scoped_release gil_release;
+      async_executor_->Wait();
+    }
   }
 }
 
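Why the added block releases the GIL before waiting: if the main thread blocked on the worker while still holding Python's GIL, a bprop task that needs the GIL to touch Python objects could never finish, and the two threads would deadlock. A minimal sketch of the pattern, assuming pybind11 and a hypothetical std::future standing in for MindSpore's real async executor:

#include <pybind11/pybind11.h>
#include <future>

namespace py = pybind11;

// Hypothetical stand-in for async_executor_->Wait(): block until the bprop
// worker finishes, but drop the GIL first so the worker can acquire it if its
// task touches Python objects. Blocking while holding the GIL could deadlock.
void WaitForBpropTask(std::future<void> *task) {
  py::gil_scoped_release gil_release;  // GIL is re-acquired when this scope ends
  task->wait();                        // safe to block: we no longer hold the GIL
}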
@@ -905,8 +912,11 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
   auto already_run_top_cell = already_run_top_cell_.at(top_cell()->already_run_cell_id());
   if (!already_run_top_cell->need_compile_graph()) {
     MS_LOG(DEBUG) << "No need compile graph";
-    // If no need compile, we can finish construct left bprop
-    async_executor_->Clear();
+    // If no compilation is needed, we can clear the bprop construction queue.
+    {
+      py::gil_scoped_release gil_release;
+      async_executor_->Clear();
+    }
     set_top_cell(already_run_top_cell);
     top_cell()->UpdateTopCellInfo(false, false, false);
     return;
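This is the same pattern as the Wait() fix above: Clear() presumably has to drain or join queued bprop tasks, so the calling thread releases the GIL before the call to avoid deadlocking against a worker task that still needs it.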
@@ -1726,10 +1736,6 @@ void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
 
 void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const {
   MS_EXCEPTION_IF_NULL(op_run_info);
-  if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
-    MS_LOG(DEBUG) << "Get dynamic shape process";
-    return;
-  }
   top_cell()->GetOpInfo(op_run_info);
   MS_LOG(DEBUG) << "Current op info: " << op_run_info->op_info;
 
@@ -1841,9 +1847,6 @@ void GradExecutor::UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
 }
 
 void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
-  if (top_cell()->use_dynamic_shape_process()) {
-    return;
-  }
   // Get the tensor ids of all forward ops
   mindspore::HashSet<std::string> forward_op_tensor_id;
   const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
@@ -1871,10 +1874,12 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
       continue;
     }
     tensor->set_is_forward_output(true);
-    top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
-    MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
-                  << " device address: " << tensor->device_address() << " shape and dtype "
-                  << tensor->GetShapeAndDataTypeInfo();
+    if (!top_cell()->use_dynamic_shape_process()) {
+      top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
+      MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
+                    << " device address: " << tensor->device_address() << " shape and dtype "
+                    << tensor->GetShapeAndDataTypeInfo();
+    }
   }
 }
 
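Read together with the two early returns removed above (-1726 and -1736), this hunk narrows the dynamic-shape guard: instead of skipping the whole function, only the tensor-object caching is skipped, so tensor->set_is_forward_output(true) and the op-info bookkeeping still run in dynamic-shape mode. A minimal sketch of the shape of that change, with hypothetical names (Recorder, Process, Record):

#include <string>
#include <unordered_map>

// Hypothetical sketch of narrowing a guard. Before the fix the whole function
// returned early in dynamic-shape mode, so nothing below the guard ran; after
// the fix only the caching step is conditional and the bookkeeping always runs.
struct Recorder {
  std::unordered_map<std::string, int> cache;
  bool is_dynamic = false;

  void Process(const std::string &id, int value) {
    Record(id);  // must run in every mode (analogue of GetOpInfo above)
    if (!is_dynamic) {
      cache[id] = value;  // analogue of set_tensor_id_with_tensor_object
    }
  }

  void Record(const std::string & /*id*/) { /* record op info for this id */ }
};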
@@ -55,13 +55,11 @@ FrontendOpRunInfoPtr GetOpRunInfo(const py::object &out, const py::args &args, c
   // Forward output of op in ms_function graph
   *added_out_v = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[1]);
   MS_LOG(DEBUG) << "Added output value is: " << (*added_out_v)->ToString();
-  PyNativeAlgo::Common::SetForwardOutputFlag(*added_out_v);
   auto op_run_info = std::make_shared<FrontendOpRunInfo>();
   PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args);
   op_run_info->base_op_run_info.op_name = graph_phase;
   // Output of ms_function
   op_run_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[0]);
-  PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
   op_run_info->base_op_run_info.abstract =
     PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_run_info->out_value->ToAbstract());
   op_run_info->grad_flag = true;
@@ -237,8 +235,6 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
   // The forward node in the ms_function graph is created during compilation and is a
   // placeholder (mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc). After running ms_function, it needs
   // to be updated to the real value.
-  (void)std::for_each(total_output_tensors.begin(), total_output_tensors.end(),
-                      [](const tensor::TensorPtr &tensor) { tensor->set_is_forward_output(true); });
   RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
   grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
 }
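Note that the removals in RunOpForward, GetOpRunInfo, and here all drop per-site forward-output marking (the SetForwardOutputFlag calls and the std::for_each over total_output_tensors); after this commit the flag appears to be set in one place only, SaveForwardTensorInfoInBpropGraph, where tensor->set_is_forward_output(true) is kept (see the hunk at -1871 above).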