!49442 Unify ms function and common op for forward output replace algorithm

Merge pull request !49442 from zjun/unify_ms_replace_tensor
i-robot 2023-02-28 13:26:39 +00:00 committed by Gitee
commit 913ecc3c6a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 30 additions and 79 deletions
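
In short, this change narrows GradExecutor::UpdateForwardTensorInfoInBpropGraph from taking the whole FrontendOpRunInfoPtr to taking the op info string and the output value, so the common-op path and the ms_function added-output path share one forward-output replace algorithm; the ms_function-only bookkeeping (op_info_with_ms_func_forward_tensors_, UpdateMsFunctionForwardTensors) is removed. A minimal sketch of the signature change, taken from the header hunk below:

// Old: the executor derived op info and output value from op_run_info internally.
void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
// New: op info and output value are passed in, so common ops and ms_function added
// outputs can be routed through the same replace path.
void UpdateForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v) const;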

View File

@@ -1617,7 +1617,8 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) co
cnode->set_abstract(op_run_info->base_op_run_info.abstract);
SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
DoOpGrad(op_run_info, cnode, op_run_info->out_value);
UpdateForwardTensorInfoInBpropGraph(op_run_info);
top_cell()->GetOpInfo(op_run_info);
UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->out_value);
CheckGraphDynamic(cnode);
}
@@ -1739,16 +1740,13 @@ void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
async_executor_->Push(task);
}
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const {
MS_EXCEPTION_IF_NULL(op_run_info);
top_cell()->GetOpInfo(op_run_info);
MS_LOG(DEBUG) << "Current op info: " << op_run_info->op_info;
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v) const {
MS_LOG(DEBUG) << "Current op info: " << op_info;
std::vector<tensor::TensorPtr> op_output_tensors;
// Get output tensors
TensorValueToTensor(op_run_info->out_value, &op_output_tensors);
TensorValueToTensor(v, &op_output_tensors);
// Save all tensors info of current op
top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info, op_output_tensors);
top_cell()->set_opinfo_with_tensor_id(op_info, op_output_tensors);
// First run top cell
const auto &pre_top_cell = GetAlreadyRunTopCell(top_cell()->already_run_cell_id());
@@ -1757,15 +1755,17 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPt
return;
}
// Non-first run
if (pre_top_cell->op_info_with_tensor_id().find(op_run_info->op_info) ==
pre_top_cell->op_info_with_tensor_id().end()) {
MS_LOG(DEBUG) << "Can not find op info " << op_run_info->op_info << " in op info with tensor id map. Top cell "
if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) {
MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell "
<< top_cell_->already_run_cell_id();
return;
}
// Update new output tensor info in bprop graph
const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_run_info->op_info);
if (top_cell()->use_dynamic_shape_process()) {
return;
}
const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_info);
if (pre_op_tensor_id.size() != op_output_tensors.size()) {
MS_LOG(EXCEPTION) << "The size of op pre output tensor size: " << pre_op_tensor_id.size()
<< " is not equal to current " << op_output_tensors.size();
@@ -1879,12 +1879,13 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
continue;
}
tensor->set_is_forward_output(true);
if (!top_cell()->use_dynamic_shape_process()) {
top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
<< " device address: " << tensor->device_address() << " shape and dtype "
<< tensor->GetShapeAndDataTypeInfo();
if (top_cell()->use_dynamic_shape_process()) {
continue;
}
top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
<< " device address: " << tensor->device_address() << " shape and dtype "
<< tensor->GetShapeAndDataTypeInfo();
}
}

View File

@@ -107,7 +107,7 @@ class GradExecutor {
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
void UpdateForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v) const;
void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
const std::vector<tensor::TensorPtr> &pre_tensors) const;
void ClearRes();

View File

@@ -195,9 +195,9 @@ void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
}
}
void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_executor, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
const FrontendOpRunInfoPtr &op_run_info) const {
void MsFunction::ReplaceAddedCnodeActualOutput(const GradExecutor *grad_executor, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
const FrontendOpRunInfoPtr &op_run_info) const {
MS_EXCEPTION_IF_NULL(ms_func_graph);
// Get added forward nodes.
auto merge_node = ms_func_graph->output();
@@ -221,9 +221,6 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
MS_EXCEPTION_IF_NULL(added_forward_node);
if (added_forward_node->isa<ValueNode>()) {
MS_LOG(DEBUG) << "The added forward output node is value node: " << added_forward_node->DebugString();
std::vector<tensor::TensorPtr> total_output_tensors;
TensorValueToTensor(added_out, &total_output_tensors);
grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
return;
}
// Replace new output tensors for forward nodes, it will also work in grad graph with same value node.
@@ -236,33 +233,6 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
// placeholder(mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc).After running ms_function, need to update
// to real value.
RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
grad_executor->top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info + kAddedValue, total_output_tensors);
}
void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const TopCellInfoPtr &top_cell,
const string &op_info, const ValuePtr &new_forward_value) const {
MS_EXCEPTION_IF_NULL(new_forward_value);
MS_LOG(DEBUG) << "Ms func graph has already ran before. The graph phase is: " << graph_phase_;
std::vector<tensor::TensorPtr> new_tensors;
TensorValueToTensor(new_forward_value, &new_tensors);
if (new_tensors.empty()) {
MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
return;
}
MS_EXCEPTION_IF_NULL(top_cell);
const auto &old_tensors = top_cell->op_info_with_ms_func_forward_tensors().at(op_info);
if (old_tensors.size() != new_tensors.size()) {
MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
<< ", but the size of new tensors is: " << new_tensors.size()
<< ", the current op info is: " << op_info;
}
MS_EXCEPTION_IF_NULL(grad_executor);
for (size_t i = 0; i < new_tensors.size(); ++i) {
grad_executor->UpdatePreTensorInfo(new_tensors[i], {old_tensors[i]});
MS_EXCEPTION_IF_NULL(old_tensors[i]);
old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
}
}
void MsFunction::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,
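
The value-node placeholders mentioned above are still rewritten by RunReplace; what changes is the bookkeeping that follows it. A rough before/after sketch, using only names that appear in this diff:

// Before: the added outputs were recorded in a ms_function-only map plus the common map.
grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
grad_executor->top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info + kAddedValue, total_output_tensors);
// After: the caller passes the added outputs through the common entry point, which saves
// them with set_opinfo_with_tensor_id and reuses the common replace algorithm.
grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, added_out_v);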
@@ -404,23 +374,15 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
MS_EXCEPTION_IF_NULL(grad_executor);
// Step 1: Update actual output tensors used in grad graph.
MS_EXCEPTION_IF_NULL(op_run_info->out_value);
MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
// The output of ms_function may be used in subsequent PyNative process
grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info);
grad_executor->top_cell()->GetOpInfo(op_run_info);
grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->out_value);
// Step 2: Update output tensors of added forward nodes, which are added to return node of ms_function func graph.
if (grad_executor->use_dynamic_shape_process()) {
MS_LOG(DEBUG) << "Get dynamic shape process";
} else {
const auto &pre_top_cell = grad_executor->GetAlreadyRunTopCell(grad_executor->top_cell()->already_run_cell_id());
if (pre_top_cell != nullptr && pre_top_cell->op_info_with_ms_func_forward_tensors().find(op_run_info->op_info) !=
pre_top_cell->op_info_with_ms_func_forward_tensors().end()) {
UpdateMsFunctionForwardTensors(grad_executor, pre_top_cell, op_run_info->op_info, added_out_v);
}
}
grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, added_out_v);
ReplaceWithRealTensorsInGradGraph(grad_executor, added_out_v, ms_func_graph, grad_graph, op_run_info);
// Step 3: Replace added cnode forward with actual output
ReplaceAddedCnodeActualOutput(grad_executor, added_out_v, ms_func_graph, grad_graph, op_run_info);
// Change ms function graph to real output
auto clone_ms_func_graph = BasicClone(ms_func_graph);
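
Put together, the updated steps in GradMsFunctionInner read as follows (a condensed restatement of the hunk above, not additional code):

// Step 1: register op info and replace the actual output tensors used in the grad graph.
grad_executor->top_cell()->GetOpInfo(op_run_info);
grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->out_value);
// Step 2: added forward outputs take the same path, keyed by op_info + kAddedValue, which
// replaces the old pre_top_cell lookup and UpdateMsFunctionForwardTensors call.
grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, added_out_v);
// Step 3: replace the added cnode forward output with the actual values.
ReplaceAddedCnodeActualOutput(grad_executor, added_out_v, ms_func_graph, grad_graph, op_run_info);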

View File

@@ -52,11 +52,9 @@ class MsFunction {
// Update device address of value node in grad graph by forward tensors.
void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const;
void ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_executor, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
const FrontendOpRunInfoPtr &op_run_info) const;
void UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const TopCellInfoPtr &top_cell,
const string &op_info, const ValuePtr &new_forward_value) const;
void ReplaceAddedCnodeActualOutput(const GradExecutor *grad_executor, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
const FrontendOpRunInfoPtr &op_run_info) const;
// Make CNode for ms_function forward graph.
void GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,
const GradExecutor *grad_executor) const;

View File

@@ -162,7 +162,6 @@ void TopCellInfo::Clear() {
graph_info_map_.clear();
op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear();
cnode_hash_with_op_index_.clear();
}

View File

@@ -44,7 +44,6 @@ namespace py = pybind11;
class GradExecutor;
using OpInfoWithTensorId = mindspore::HashMap<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithMsFuncForwardTensors = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
using CellIdWithBackwardHookOp = mindspore::HashMap<std::string, std::vector<AnfNodePtr>>;
struct GraphInfo {
@@ -124,13 +123,6 @@ class TopCellInfo {
inline void set_tensor_id_with_tensor_object(const std::string &id, const tensor::TensorPtr &tensor) {
(void)tensor_id_with_tensor_object_[id].emplace_back(tensor);
}
inline const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const {
return op_info_with_ms_func_forward_tensors_;
}
inline void set_op_info_with_ms_func_forward_tensors(const std::string &op_info,
const std::vector<tensor::TensorPtr> &forward_tensors) {
op_info_with_ms_func_forward_tensors_[op_info] = forward_tensors;
}
inline size_t op_index() const { return op_index_; }
inline void IncreaseOpIndex() { ++op_index_; }
@@ -194,7 +186,6 @@ class TopCellInfo {
CellIdWithBackwardHookOp cell_backward_hook_op_;
OpInfoWithTensorId op_info_with_tensor_id_;
TensorIdWithTensorObject tensor_id_with_tensor_object_;
OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
mindspore::HashMap<size_t, size_t> cnode_hash_with_op_index_;
bool use_dynamic_shape_process_{false};
};