!19868 Add dynamic shape info to opinfo

Merge pull request !19868 from zjun/fix_r1.3
This commit is contained in:
i-robot 2021-07-13 01:21:54 +00:00 committed by Gitee
commit 508f015652
2 changed files with 26 additions and 21 deletions

View File

@ -781,9 +781,7 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i
// 4.Get output abstract
bool prim_cache_hit = false;
GetOpOutputAbstract(op_exec_info, args_spec_list, &prim_cache_hit);
// 5.Record op info for dynamic graph judge
grad()->RecordGradOpInfo(op_exec_info);
// 6.Get output
// 5.Get output
GetOpOutput(op_exec_info, args_spec_list, cnode, prim_cache_hit, ret);
}
@ -905,11 +903,7 @@ AnfNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_i
if (grad()->need_construct_graph()) {
auto id = GetId(obj);
AnfNodePtr input_node = nullptr;
bool required_grad = true;
if (op_mask) {
required_grad = meta_tensor->param_info()->requires_grad();
}
input_node = grad()->GetInput(obj, op_mask & required_grad);
input_node = grad()->GetInput(obj, op_mask);
// update abstract
if (input_node != nullptr) {
if (input_node->abstract() != nullptr) {
@ -924,7 +918,7 @@ AnfNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_i
CNodePtr cnode = nullptr;
if (grad()->need_construct_graph()) {
cnode = grad()->curr_g()->NewCNodeInOrder(inputs);
MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << " new cnode is " << cnode->DebugString();
MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << ", new cnode is " << cnode->DebugString();
}
return cnode;
}
@ -1009,8 +1003,10 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
} else {
node_abs_map_.clear();
}
*ret = out_real;
// Record op info for judge whether the construct of cell has been changed
grad()->RecordGradOpInfo(op_exec_info, out_real);
grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real);
*ret = out_real;
}
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
@ -1233,7 +1229,7 @@ AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) {
node = MakeValueNode(obj, obj_id);
}
node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id;
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << ", id " << obj_id;
return node;
}
AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
@ -1305,15 +1301,14 @@ TopCellInfoPtr GradExecutor::GetTopCell(const std::string &cell_id) const {
return nullptr;
}
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
// Record op info for judge whether the construct of cell has been changed
if (!grad_flag()) {
MS_LOG(DEBUG) << "grad flag is set to false, no need to record op info";
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const py::object &ret) {
if (!grad_flag_) {
MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
return;
}
// Record input args info (weight or data)
MS_EXCEPTION_IF_NULL(op_exec_info);
std::string input_args_info;
// Record input args info (weight or data)
for (auto mask : op_exec_info->inputs_mask) {
if (mask) {
input_args_info += "w";
@ -1325,7 +1320,17 @@ void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
op_exec_info->op_info.clear();
const auto &curr_op_num = top_cell()->op_num();
op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
top_cell()->all_op_info() += "_" + op_exec_info->op_info;
// The out shape is added to determine those ops that change the shape
ValuePtr out_value = parse::data_converter::PyDataToValue(ret);
MS_EXCEPTION_IF_NULL(out_value);
auto out_abs = out_value->ToAbstract();
if (out_abs != nullptr) {
auto out_shape = out_abs->BuildShape()->ToString();
if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
op_exec_info->op_info += "-" + out_shape;
}
}
top_cell()->all_op_info() += "-" + op_exec_info->op_info;
top_cell()->set_op_num(curr_op_num + 1);
}
@ -1438,7 +1443,7 @@ void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, c
SetNodeMapInGraphInfoMap(curr_g_, GetId(out), ms_function_cnode);
// Record ms_function cnode info and update forward tensors
op_exec_info->op_name = graph_phase;
RecordGradOpInfo(op_exec_info);
RecordGradOpInfo(op_exec_info, out);
MS_LOG(DEBUG) << "Ms_function cnode op info: " << op_exec_info->op_info;
UpdateForwardTensorInfoInBpropGraph(op_exec_info, out);
// Add out and dout
@ -1455,13 +1460,13 @@ void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, c
}
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real) {
if (!grad_flag()) {
if (!grad_flag_) {
MS_LOG(DEBUG) << "The grad flag is false, no need to update forward op info in bprop graph";
return;
}
MS_EXCEPTION_IF_NULL(top_cell_);
MS_EXCEPTION_IF_NULL(op_exec_info);
auto op_info = op_exec_info->op_info;
const auto &op_info = op_exec_info->op_info;
MS_LOG(DEBUG) << "Current op info: " << op_info;
std::vector<tensor::TensorPtr> all_op_tensors;

View File

@ -181,7 +181,7 @@ class GradExecutor {
bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
std::string GetCellId(const py::object &obj, const py::args &args);
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const py::object &ret);
bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
void DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &op_out);