!19868 Add dynamic shape info to opinfo
Merge pull request !19868 from zjun/fix_r1.3
This commit is contained in:
commit
508f015652
|
@ -781,9 +781,7 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i
|
|||
// 4.Get output abstract
|
||||
bool prim_cache_hit = false;
|
||||
GetOpOutputAbstract(op_exec_info, args_spec_list, &prim_cache_hit);
|
||||
// 5.Record op info for dynamic graph judge
|
||||
grad()->RecordGradOpInfo(op_exec_info);
|
||||
// 6.Get output
|
||||
// 5.Get output
|
||||
GetOpOutput(op_exec_info, args_spec_list, cnode, prim_cache_hit, ret);
|
||||
}
|
||||
|
||||
|
@ -905,11 +903,7 @@ AnfNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_i
|
|||
if (grad()->need_construct_graph()) {
|
||||
auto id = GetId(obj);
|
||||
AnfNodePtr input_node = nullptr;
|
||||
bool required_grad = true;
|
||||
if (op_mask) {
|
||||
required_grad = meta_tensor->param_info()->requires_grad();
|
||||
}
|
||||
input_node = grad()->GetInput(obj, op_mask & required_grad);
|
||||
input_node = grad()->GetInput(obj, op_mask);
|
||||
// update abstract
|
||||
if (input_node != nullptr) {
|
||||
if (input_node->abstract() != nullptr) {
|
||||
|
@ -924,7 +918,7 @@ AnfNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_i
|
|||
CNodePtr cnode = nullptr;
|
||||
if (grad()->need_construct_graph()) {
|
||||
cnode = grad()->curr_g()->NewCNodeInOrder(inputs);
|
||||
MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << " new cnode is " << cnode->DebugString();
|
||||
MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << ", new cnode is " << cnode->DebugString();
|
||||
}
|
||||
return cnode;
|
||||
}
|
||||
|
@ -1009,8 +1003,10 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
|
|||
} else {
|
||||
node_abs_map_.clear();
|
||||
}
|
||||
*ret = out_real;
|
||||
// Record op info for judge whether the construct of cell has been changed
|
||||
grad()->RecordGradOpInfo(op_exec_info, out_real);
|
||||
grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real);
|
||||
*ret = out_real;
|
||||
}
|
||||
|
||||
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
|
||||
|
@ -1233,7 +1229,7 @@ AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) {
|
|||
node = MakeValueNode(obj, obj_id);
|
||||
}
|
||||
node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
|
||||
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id;
|
||||
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << ", id " << obj_id;
|
||||
return node;
|
||||
}
|
||||
AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
|
||||
|
@ -1305,15 +1301,14 @@ TopCellInfoPtr GradExecutor::GetTopCell(const std::string &cell_id) const {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
|
||||
// Record op info for judge whether the construct of cell has been changed
|
||||
if (!grad_flag()) {
|
||||
MS_LOG(DEBUG) << "grad flag is set to false, no need to record op info";
|
||||
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const py::object &ret) {
|
||||
if (!grad_flag_) {
|
||||
MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
|
||||
return;
|
||||
}
|
||||
// Record input args info (weight or data)
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
std::string input_args_info;
|
||||
// Record input args info (weight or data)
|
||||
for (auto mask : op_exec_info->inputs_mask) {
|
||||
if (mask) {
|
||||
input_args_info += "w";
|
||||
|
@ -1325,7 +1320,17 @@ void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
|
|||
op_exec_info->op_info.clear();
|
||||
const auto &curr_op_num = top_cell()->op_num();
|
||||
op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
|
||||
top_cell()->all_op_info() += "_" + op_exec_info->op_info;
|
||||
// The out shape is added to determine those ops that change the shape
|
||||
ValuePtr out_value = parse::data_converter::PyDataToValue(ret);
|
||||
MS_EXCEPTION_IF_NULL(out_value);
|
||||
auto out_abs = out_value->ToAbstract();
|
||||
if (out_abs != nullptr) {
|
||||
auto out_shape = out_abs->BuildShape()->ToString();
|
||||
if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
|
||||
op_exec_info->op_info += "-" + out_shape;
|
||||
}
|
||||
}
|
||||
top_cell()->all_op_info() += "-" + op_exec_info->op_info;
|
||||
top_cell()->set_op_num(curr_op_num + 1);
|
||||
}
|
||||
|
||||
|
@ -1438,7 +1443,7 @@ void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, c
|
|||
SetNodeMapInGraphInfoMap(curr_g_, GetId(out), ms_function_cnode);
|
||||
// Record ms_function cnode info and update forward tensors
|
||||
op_exec_info->op_name = graph_phase;
|
||||
RecordGradOpInfo(op_exec_info);
|
||||
RecordGradOpInfo(op_exec_info, out);
|
||||
MS_LOG(DEBUG) << "Ms_function cnode op info: " << op_exec_info->op_info;
|
||||
UpdateForwardTensorInfoInBpropGraph(op_exec_info, out);
|
||||
// Add out and dout
|
||||
|
@ -1455,13 +1460,13 @@ void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, c
|
|||
}
|
||||
|
||||
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real) {
|
||||
if (!grad_flag()) {
|
||||
if (!grad_flag_) {
|
||||
MS_LOG(DEBUG) << "The grad flag is false, no need to update forward op info in bprop graph";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(top_cell_);
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
auto op_info = op_exec_info->op_info;
|
||||
const auto &op_info = op_exec_info->op_info;
|
||||
MS_LOG(DEBUG) << "Current op info: " << op_info;
|
||||
|
||||
std::vector<tensor::TensorPtr> all_op_tensors;
|
||||
|
|
|
@ -181,7 +181,7 @@ class GradExecutor {
|
|||
bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
|
||||
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
|
||||
std::string GetCellId(const py::object &obj, const py::args &args);
|
||||
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
|
||||
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const py::object &ret);
|
||||
bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
|
||||
void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
|
||||
void DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &op_out);
|
||||
|
|
Loading…
Reference in New Issue