!49442 Unify ms function and common op for forward output replace algorithm
Merge pull request !49442 from zjun/unify_ms_replace_tensor
Commit: 913ecc3c6a
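
This commit makes GradExecutor::UpdateForwardTensorInfoInBpropGraph take the op info string and the output value directly instead of a FrontendOpRunInfoPtr, so the common-op path and the ms_function path share one forward-output replacement routine; the added forward outputs of an ms_function graph go through the same routine under the key op_info + kAddedValue. The ms_function-only UpdateMsFunctionForwardTensors and the op_info_with_ms_func_forward_tensors_ map on TopCellInfo are removed, and ReplaceWithRealTensorsInGradGraph is renamed to ReplaceAddedCnodeActualOutput.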
@@ -1617,7 +1617,8 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) co
   cnode->set_abstract(op_run_info->base_op_run_info.abstract);
   SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
   DoOpGrad(op_run_info, cnode, op_run_info->out_value);
-  UpdateForwardTensorInfoInBpropGraph(op_run_info);
+  top_cell()->GetOpInfo(op_run_info);
+  UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->out_value);
   CheckGraphDynamic(cnode);
 }

@@ -1739,16 +1740,13 @@ void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
   async_executor_->Push(task);
 }
 
-void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const {
-  MS_EXCEPTION_IF_NULL(op_run_info);
-  top_cell()->GetOpInfo(op_run_info);
-  MS_LOG(DEBUG) << "Current op info: " << op_run_info->op_info;
+void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v) const {
+  MS_LOG(DEBUG) << "Current op info: " << op_info;
   std::vector<tensor::TensorPtr> op_output_tensors;
   // Get output tensors
-  TensorValueToTensor(op_run_info->out_value, &op_output_tensors);
+  TensorValueToTensor(v, &op_output_tensors);
   // Save all tensors info of current op
-  top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info, op_output_tensors);
+  top_cell()->set_opinfo_with_tensor_id(op_info, op_output_tensors);
 
   // First run top cell
   const auto &pre_top_cell = GetAlreadyRunTopCell(top_cell()->already_run_cell_id());

@@ -1757,15 +1755,17 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPt
     return;
   }
   // Non-first run
-  if (pre_top_cell->op_info_with_tensor_id().find(op_run_info->op_info) ==
-      pre_top_cell->op_info_with_tensor_id().end()) {
-    MS_LOG(DEBUG) << "Can not find op info " << op_run_info->op_info << " in op info with tensor id map. Top cell "
+  if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) {
+    MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell "
                   << top_cell_->already_run_cell_id();
     return;
   }
 
   // Update new output tensor info in bprop graph
-  const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_run_info->op_info);
+  if (top_cell()->use_dynamic_shape_process()) {
+    return;
+  }
+  const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_info);
   if (pre_op_tensor_id.size() != op_output_tensors.size()) {
     MS_LOG(EXCEPTION) << "The size of op pre output tensor size: " << pre_op_tensor_id.size()
                       << " is not equal to current " << op_output_tensors.size();

@@ -1879,12 +1879,13 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
       continue;
     }
     tensor->set_is_forward_output(true);
-    if (!top_cell()->use_dynamic_shape_process()) {
-      top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
-      MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
-                    << " device address: " << tensor->device_address() << " shape and dtype "
-                    << tensor->GetShapeAndDataTypeInfo();
+    if (top_cell()->use_dynamic_shape_process()) {
+      continue;
     }
+    top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
+    MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
+                  << " device address: " << tensor->device_address() << " shape and dtype "
+                  << tensor->GetShapeAndDataTypeInfo();
   }
 }

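For orientation, the grad.cc hunks above implement this bookkeeping: on the first run of a top cell, each op registers the ids of its output tensors under its op info; on a re-run, the new outputs are matched positionally against the ids saved by the previous run, and every tensor object the bprop graph captured under one of those ids takes over the fresh output's device info. Below is a minimal standalone C++ sketch of that scheme; Tensor, SaveOpOutputs and UpdateForwardTensorInfo are toy stand-ins for illustration, not the MindSpore API.

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for tensor::TensorPtr: a stable id plus a device handle.
struct Tensor {
  std::string id;
  int device_address;
};
using TensorPtr = std::shared_ptr<Tensor>;

// op_info -> ids of the tensors that op produced (cf. OpInfoWithTensorId).
using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
// tensor id -> tensor objects captured in the bprop graph (cf. TensorIdWithTensorObject).
using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<TensorPtr>>;

// First run: remember which tensor ids each op produced.
void SaveOpOutputs(const std::string &op_info, const std::vector<TensorPtr> &outs,
                   OpInfoWithTensorId *op_map) {
  for (const auto &t : outs) {
    (*op_map)[op_info].push_back(t->id);
  }
}

// Re-run: match new outputs positionally against the previous run's ids and
// give the captured tensors the fresh device info (cf. UpdatePreTensorInfo).
void UpdateForwardTensorInfo(const std::string &op_info, const std::vector<TensorPtr> &new_outs,
                             const OpInfoWithTensorId &pre_op_map, TensorIdWithTensorObject *id_map) {
  auto it = pre_op_map.find(op_info);
  if (it == pre_op_map.end()) {
    return;  // op was not seen in the previous run; nothing to replace
  }
  if (it->second.size() != new_outs.size()) {
    throw std::runtime_error("pre/current output size mismatch");
  }
  for (size_t i = 0; i < new_outs.size(); ++i) {
    for (const auto &pre : (*id_map)[it->second[i]]) {
      pre->device_address = new_outs[i]->device_address;
    }
  }
}

int main() {
  OpInfoWithTensorId op_map;
  TensorIdWithTensorObject id_map;

  // First run: record the op's output and the tensor the bprop graph captured.
  auto t0 = std::make_shared<Tensor>(Tensor{"T0", 100});
  SaveOpOutputs("MatMul-0", {t0}, &op_map);
  id_map["T0"].push_back(t0);

  // Second run produces a new tensor for the same op; replace the device info.
  auto t0_new = std::make_shared<Tensor>(Tensor{"T0", 200});
  UpdateForwardTensorInfo("MatMul-0", {t0_new}, op_map, &id_map);
  std::cout << t0->device_address << "\n";  // prints 200: bprop graph sees the new output
}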
@@ -107,7 +107,7 @@ class GradExecutor {
   void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
   AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
   AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
-  void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
+  void UpdateForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v) const;
   void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
                            const std::vector<tensor::TensorPtr> &pre_tensors) const;
   void ClearRes();

@@ -195,9 +195,9 @@ void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
   }
 }
 
-void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_executor, const ValuePtr &added_out,
-                                                   const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
-                                                   const FrontendOpRunInfoPtr &op_run_info) const {
+void MsFunction::ReplaceAddedCnodeActualOutput(const GradExecutor *grad_executor, const ValuePtr &added_out,
+                                               const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
+                                               const FrontendOpRunInfoPtr &op_run_info) const {
   MS_EXCEPTION_IF_NULL(ms_func_graph);
   // Get added forward nodes.
   auto merge_node = ms_func_graph->output();

@@ -221,9 +221,6 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
   MS_EXCEPTION_IF_NULL(added_forward_node);
   if (added_forward_node->isa<ValueNode>()) {
     MS_LOG(DEBUG) << "The added forward output node is value node: " << added_forward_node->DebugString();
-    std::vector<tensor::TensorPtr> total_output_tensors;
-    TensorValueToTensor(added_out, &total_output_tensors);
-    grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
     return;
   }
   // Replace new output tensors for forward nodes, it will also work in grad graph with same value node.

@@ -236,33 +233,6 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
   // placeholder(mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc).After running ms_function, need to update
   // to real value.
   RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
-  grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
-  grad_executor->top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info + kAddedValue, total_output_tensors);
 }
 
-void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const TopCellInfoPtr &top_cell,
-                                                const string &op_info, const ValuePtr &new_forward_value) const {
-  MS_EXCEPTION_IF_NULL(new_forward_value);
-  MS_LOG(DEBUG) << "Ms func graph has already ran before. The graph phase is: " << graph_phase_;
-  std::vector<tensor::TensorPtr> new_tensors;
-  TensorValueToTensor(new_forward_value, &new_tensors);
-  if (new_tensors.empty()) {
-    MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
-    return;
-  }
-  MS_EXCEPTION_IF_NULL(top_cell);
-  const auto &old_tensors = top_cell->op_info_with_ms_func_forward_tensors().at(op_info);
-  if (old_tensors.size() != new_tensors.size()) {
-    MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
-                      << ", but the size of new tensors is: " << new_tensors.size()
-                      << ", the current op info is: " << op_info;
-  }
-  MS_EXCEPTION_IF_NULL(grad_executor);
-  for (size_t i = 0; i < new_tensors.size(); ++i) {
-    grad_executor->UpdatePreTensorInfo(new_tensors[i], {old_tensors[i]});
-    MS_EXCEPTION_IF_NULL(old_tensors[i]);
-    old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
-  }
-}
-
 void MsFunction::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,

@@ -404,23 +374,15 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
   MS_EXCEPTION_IF_NULL(grad_executor);
 
   // Step 1: Update actual output tensors used in grad graph.
   MS_EXCEPTION_IF_NULL(op_run_info->out_value);
   MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
   // The output of ms_function may be used in subsequent PyNative process
-  grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info);
+  grad_executor->top_cell()->GetOpInfo(op_run_info);
+  grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->out_value);
 
   // Step 2: Update output tensors of added forward nodes, which are added to return node of ms_function func graph.
-  if (grad_executor->use_dynamic_shape_process()) {
-    MS_LOG(DEBUG) << "Get dynamic shape process";
-  } else {
-    const auto &pre_top_cell = grad_executor->GetAlreadyRunTopCell(grad_executor->top_cell()->already_run_cell_id());
-    if (pre_top_cell != nullptr && pre_top_cell->op_info_with_ms_func_forward_tensors().find(op_run_info->op_info) !=
-                                       pre_top_cell->op_info_with_ms_func_forward_tensors().end()) {
-      UpdateMsFunctionForwardTensors(grad_executor, pre_top_cell, op_run_info->op_info, added_out_v);
-    }
-  }
-
-  ReplaceWithRealTensorsInGradGraph(grad_executor, added_out_v, ms_func_graph, grad_graph, op_run_info);
+  grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, added_out_v);
+
+  // Step 3: Replace added cnode forward with actual output
+  ReplaceAddedCnodeActualOutput(grad_executor, added_out_v, ms_func_graph, grad_graph, op_run_info);
 
   // Change ms function graph to real output
   auto clone_ms_func_graph = BasicClone(ms_func_graph);

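The unified routine also serves the extra forward values that ms_function appends to its output for bprop: they are registered under the actual op info plus the kAddedValue suffix, so one map distinguishes the two tensor groups by key alone. A toy C++ sketch of that keying scheme follows; the real kAddedValue constant is defined in the framework, and the string "Added" below is an assumed placeholder.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Assumed placeholder; the real kAddedValue constant lives in the framework.
const std::string kAddedValue = "Added";

int main() {
  // One map now serves both the graph's actual outputs and the added forward
  // values appended to the ms_function output for bprop.
  std::unordered_map<std::string, std::vector<std::string>> op_info_with_tensor_id;
  const std::string op_info = "ms_function_phase0";            // illustrative op info key
  op_info_with_tensor_id[op_info] = {"T0", "T1"};              // actual outputs
  op_info_with_tensor_id[op_info + kAddedValue] = {"T2"};      // added forward outputs
  for (const auto &kv : op_info_with_tensor_id) {
    std::cout << kv.first << " -> " << kv.second.size() << " tensor id(s)\n";
  }
}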
@@ -52,11 +52,9 @@ class MsFunction {
   // Update device address of value node in grad graph by forward tensors.
   void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
                   const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const;
-  void ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_executor, const ValuePtr &added_out,
-                                         const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
-                                         const FrontendOpRunInfoPtr &op_run_info) const;
-  void UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const TopCellInfoPtr &top_cell,
-                                      const string &op_info, const ValuePtr &new_forward_value) const;
+  void ReplaceAddedCnodeActualOutput(const GradExecutor *grad_executor, const ValuePtr &added_out,
+                                     const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
+                                     const FrontendOpRunInfoPtr &op_run_info) const;
   // Make CNode for ms_function forward graph.
   void GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,
                         const GradExecutor *grad_executor) const;

@@ -162,7 +162,6 @@ void TopCellInfo::Clear() {
   graph_info_map_.clear();
   op_info_with_tensor_id_.clear();
   tensor_id_with_tensor_object_.clear();
-  op_info_with_ms_func_forward_tensors_.clear();
   cnode_hash_with_op_index_.clear();
 }

@@ -44,7 +44,6 @@ namespace py = pybind11;
 class GradExecutor;
 using OpInfoWithTensorId = mindspore::HashMap<std::string, std::vector<std::string>>;
 using TensorIdWithTensorObject = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
-using OpInfoWithMsFuncForwardTensors = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
 using CellIdWithBackwardHookOp = mindspore::HashMap<std::string, std::vector<AnfNodePtr>>;
 
 struct GraphInfo {

@@ -124,13 +123,6 @@ class TopCellInfo {
   inline void set_tensor_id_with_tensor_object(const std::string &id, const tensor::TensorPtr &tensor) {
     (void)tensor_id_with_tensor_object_[id].emplace_back(tensor);
   }
-  inline const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const {
-    return op_info_with_ms_func_forward_tensors_;
-  }
-  inline void set_op_info_with_ms_func_forward_tensors(const std::string &op_info,
-                                                       const std::vector<tensor::TensorPtr> &forward_tensors) {
-    op_info_with_ms_func_forward_tensors_[op_info] = forward_tensors;
-  }
   inline size_t op_index() const { return op_index_; }
   inline void IncreaseOpIndex() { ++op_index_; }

@@ -194,7 +186,6 @@ class TopCellInfo {
   CellIdWithBackwardHookOp cell_backward_hook_op_;
   OpInfoWithTensorId op_info_with_tensor_id_;
   TensorIdWithTensorObject tensor_id_with_tensor_object_;
-  OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
   mindspore::HashMap<size_t, size_t> cnode_hash_with_op_index_;
   bool use_dynamic_shape_process_{false};
 };