forked from mindspore-Ecosystem/mindspore
support anf node
This commit is contained in:
parent
c8f819811e
commit
ab34671105
|
@ -189,58 +189,73 @@ bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) {
|
|||
return *v1 == *v2;
|
||||
}
|
||||
|
||||
bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) {
|
||||
if (p1 == p2) {
|
||||
return true;
|
||||
}
|
||||
if (p1 == nullptr || p2 == nullptr) {
|
||||
bool IsParamInfoEqual(const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||
MS_EXCEPTION_IF_NULL(node1);
|
||||
MS_EXCEPTION_IF_NULL(node2);
|
||||
if (node1->isa<Parameter>() != node2->isa<Parameter>()) {
|
||||
return false;
|
||||
}
|
||||
return p1->key() == p2->key();
|
||||
|
||||
const auto &p1 = node1->cast<ParameterPtr>();
|
||||
const auto &p2 = node2->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(p1);
|
||||
MS_EXCEPTION_IF_NULL(p2);
|
||||
auto param_info1 = p1->param_info();
|
||||
auto param_info2 = p2->param_info();
|
||||
if (param_info1 == param_info2) {
|
||||
return true;
|
||||
}
|
||||
if (param_info1 == nullptr || param_info2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return param_info1->key() == param_info2->key();
|
||||
}
|
||||
|
||||
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, size_t node_index,
|
||||
const std::vector<AnfNodePtr> &new_anf_inputs, const TopCellInfoPtr &top_cell) {
|
||||
MS_EXCEPTION_IF_NULL(old_node_info);
|
||||
auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() +
|
||||
old_node_info->input_param_infos.size();
|
||||
if (old_input_size != new_anf_inputs.size() - 1) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old input size: " << old_input_size
|
||||
<< " new input_infos: " << (new_anf_inputs.size() - 1);
|
||||
bool IsCnodeInputsDynamic(const std::vector<AnfNodePtr> &old_anf_inputs, const std::vector<AnfNodePtr> &new_anf_inputs,
|
||||
size_t node_index, const TopCellInfoPtr &top_cell,
|
||||
const std::vector<size_t> &old_op_index_of_cnode_inputs) {
|
||||
if (old_anf_inputs.size() != new_anf_inputs.size()) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old input size: " << old_anf_inputs.size()
|
||||
<< " new input_infos: " << new_anf_inputs.size();
|
||||
return true;
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < new_anf_inputs.size(); i++) {
|
||||
const auto &new_anf_input = new_anf_inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(new_anf_input);
|
||||
const auto &old_anf_input = old_anf_inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(old_anf_input);
|
||||
|
||||
if (new_anf_input->isa<ValueNode>()) {
|
||||
const auto &value_iter = old_node_info->input_values.find(i);
|
||||
if (value_iter == old_node_info->input_values.end()) {
|
||||
if (!old_anf_input->isa<ValueNode>()) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) {
|
||||
if (!IsValuePtrEqual(GetValueNode(old_anf_input), GetValueNode(new_anf_input))) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, value is different.";
|
||||
return true;
|
||||
}
|
||||
} else if (new_anf_input->isa<CNode>()) {
|
||||
// Compare cnode abstract.
|
||||
const auto &node_iter = old_node_info->input_cnode_info.find(i);
|
||||
if (node_iter == old_node_info->input_cnode_info.end()) {
|
||||
if (!old_anf_input->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode.";
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t old_op_index = 0;
|
||||
AbstractBasePtr old_abs = nullptr;
|
||||
std::tie(old_op_index, old_abs) = node_iter->second;
|
||||
if (IsAbsDifferent(old_abs, new_anf_input->abstract())) {
|
||||
if (IsAbsDifferent(old_anf_input->abstract(), new_anf_input->abstract())) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, abs is different.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (i - 1 >= old_op_index_of_cnode_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "i - 1 is out of range, i - 1:" << (i - 1)
|
||||
<< " old_op_index_of_cnode_inputs.size:" << old_op_index_of_cnode_inputs.size();
|
||||
}
|
||||
|
||||
// Compare cnode edge.
|
||||
auto old_op_index = old_op_index_of_cnode_inputs[i - 1];
|
||||
MS_EXCEPTION_IF_NULL(top_cell);
|
||||
if (old_op_index != top_cell->get_op_index_by_cnode_hash(new_anf_input->hash(), node_index)) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index: " << old_op_index
|
||||
|
@ -254,16 +269,7 @@ bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, size_t
|
|||
<< " is none of value node, cnode and parameter.";
|
||||
}
|
||||
|
||||
const auto &node_iter = old_node_info->input_param_infos.find(i);
|
||||
if (node_iter == old_node_info->input_param_infos.end()) {
|
||||
MS_LOG(DEBUG) << "The " << i
|
||||
<< "th input is different, cur input is a parameter, old input is not a parameter.";
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto ¶m = new_anf_input->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (!IsParamInfoEqual(node_iter->second, param->param_info())) {
|
||||
if (!IsParamInfoEqual(new_anf_input, old_anf_input)) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, param info is different.";
|
||||
return true;
|
||||
}
|
||||
|
@ -272,27 +278,37 @@ bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, size_t
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, size_t node_index,
|
||||
const CNodePtr &new_cnode, const TopCellInfoPtr &top_cell) {
|
||||
bool IsDynamicDetectCnodeChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
|
||||
size_t node_index, const TopCellInfoPtr &top_cell) {
|
||||
MS_EXCEPTION_IF_NULL(old_node_info);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
auto old_anf_node = old_node_info->anf_node;
|
||||
if (!old_anf_node->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, new node is a cnode, old node is not a cnode";
|
||||
return true;
|
||||
}
|
||||
|
||||
auto old_cnode = old_anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(old_cnode);
|
||||
|
||||
// 2.Detect cnode prim
|
||||
auto old_prim = GetCNodePrimitive(old_cnode);
|
||||
auto new_prim = GetCNodePrimitive(new_cnode);
|
||||
if (!common::IsEqual(new_prim, old_node_info->prim)) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old prim: "
|
||||
<< (old_node_info->prim == nullptr ? "nullptr" : old_node_info->prim->name())
|
||||
if (!common::IsEqual(old_prim, new_prim)) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old prim: " << (old_prim == nullptr ? "nullptr" : old_prim->name())
|
||||
<< " new prim: " << (new_prim == nullptr ? "nullptr" : new_prim->name());
|
||||
return true;
|
||||
}
|
||||
|
||||
// 3.Detect output abs
|
||||
if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) {
|
||||
if (IsAbsDifferent(old_cnode->abstract(), new_cnode->abstract())) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, output_abs is different";
|
||||
return true;
|
||||
}
|
||||
|
||||
// 4.Detect inputs
|
||||
return IsCnodeInputsDynamic(old_node_info, node_index, new_cnode->inputs(), top_cell);
|
||||
return IsCnodeInputsDynamic(old_cnode->inputs(), new_cnode->inputs(), node_index, top_cell,
|
||||
old_node_info->op_index_of_cnode_inputs);
|
||||
}
|
||||
|
||||
FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph, bool need_renormalize) {
|
||||
|
@ -613,7 +629,9 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
// Update bprop grad stack
|
||||
if (input_args_info->grad_is_running && !bprop_grad_stack_.empty()) {
|
||||
if (!bprop_grad_stack_.top().second) {
|
||||
curr_g()->set_output(GetInput(input_args_info->out_value, out_id));
|
||||
auto output_node = GetInput(input_args_info->out_value, out_id);
|
||||
input_args_info->use_dynamic_shape_process |= CheckGraphDynamic(output_node);
|
||||
curr_g()->set_output(output_node);
|
||||
bprop_grad_stack_.pop();
|
||||
return;
|
||||
} else if (bprop_grad_stack_.top().first == input_args_info->cell_id) {
|
||||
|
@ -621,9 +639,13 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
}
|
||||
}
|
||||
// Just only dump the last forward graph
|
||||
if (is_top_cell_end && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
curr_g()->set_output(GetInput(input_args_info->out_value, out_id));
|
||||
PyNativeAlgo::Common::DumpGraphIR("fg.ir", curr_g());
|
||||
if (is_top_cell_end) {
|
||||
auto output_node = GetInput(input_args_info->out_value, out_id);
|
||||
input_args_info->use_dynamic_shape_process |= CheckGraphDynamic(output_node);
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
curr_g()->set_output(output_node);
|
||||
PyNativeAlgo::Common::DumpGraphIR("fg.ir", curr_g());
|
||||
}
|
||||
}
|
||||
// Reset grad flag and update output node of the outermost cell
|
||||
if (input_args_info->is_grad_topest_cell && is_top_cell_end) {
|
||||
|
@ -668,8 +690,7 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info,
|
|||
if (!input_args_info->grad_is_running || bprop_grad_stack_.top().second) {
|
||||
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
|
||||
}
|
||||
CheckGraphDynamic(cnode, top_cell()->op_index());
|
||||
top_cell()->IncreaseOpIndex();
|
||||
(void)CheckGraphDynamic(cnode);
|
||||
SaveOutputNodeMap(out_id, op_run_info, cnode);
|
||||
}
|
||||
|
||||
|
@ -1104,24 +1125,33 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
|
|||
ClearGradRes();
|
||||
return;
|
||||
}
|
||||
// High grad hit cache
|
||||
if (!top_cell()->vm_compile()) {
|
||||
SwitchTopCell();
|
||||
return;
|
||||
}
|
||||
|
||||
auto first_grad_fg = cur_run_bprop_graph;
|
||||
if (has_custom_bprop) {
|
||||
first_grad_fg = curr_g();
|
||||
MS_LOG(DEBUG) << "Bprop nested";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(first_grad_fg);
|
||||
// Because ConvertPrimToPrimPy will change first_grad_fg, when hit bprop graph cache
|
||||
// resource->func_graph() will be changed, abstract may be nullptr.
|
||||
first_grad_fg = BasicClone(first_grad_fg);
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
|
||||
auto cur_vm_compile = top_cell()->vm_compile();
|
||||
ValuePtrList weights_args;
|
||||
DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);
|
||||
|
||||
auto cnode = curr_g()->NewCNode(inputs);
|
||||
auto out_value = PyNativeAlgo::DataConvert::BaseRefToValue(out);
|
||||
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(out_value);
|
||||
top_cell()->SetNodeMapInGraphInfoMap(out_id, cnode);
|
||||
cnode->set_abstract(out_value->ToAbstract()->Broaden());
|
||||
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();
|
||||
|
||||
// High grad hit cache
|
||||
if (!cur_vm_compile) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Because ConvertPrimToPrimPy will change first_grad_fg, when hit bprop graph cache
|
||||
// resource->func_graph() will be changed, abstract may be nullptr.
|
||||
first_grad_fg = BasicClone(first_grad_fg);
|
||||
if (!opt::ConvertPrimToPrimPy(first_grad_fg)) {
|
||||
MS_LOG(EXCEPTION) << "Convert PrimitiveC to PrimitivePy failed";
|
||||
}
|
||||
|
@ -1136,11 +1166,6 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
|
|||
r->Clean();
|
||||
|
||||
MS_LOG(DEBUG) << "Get cur graph ptr " << curr_g().get();
|
||||
auto cnode = curr_g()->NewCNode(inputs);
|
||||
auto out_value = PyNativeAlgo::DataConvert::BaseRefToValue(out);
|
||||
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(out_value);
|
||||
top_cell()->SetNodeMapInGraphInfoMap(out_id, cnode);
|
||||
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();
|
||||
|
||||
// Get input values
|
||||
ValuePtrList input_args(forward_args);
|
||||
|
@ -1380,6 +1405,7 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
|
|||
// Create make tuple node and record to graph info map.
|
||||
auto cnode = curr_g()->NewCNode(inputs);
|
||||
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
|
||||
(void)CheckGraphDynamic(cnode);
|
||||
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
|
||||
return cnode;
|
||||
}
|
||||
|
@ -1408,8 +1434,7 @@ AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
|
|||
c_node->set_abstract(prim_abs);
|
||||
}
|
||||
}
|
||||
CheckGraphDynamic(c_node, top_cell()->op_index());
|
||||
top_cell()->IncreaseOpIndex();
|
||||
(void)CheckGraphDynamic(c_node);
|
||||
MS_LOG(DEBUG) << "Get input node " << c_node->ToString() << ", id " << obj_id;
|
||||
return c_node;
|
||||
}
|
||||
|
@ -1470,8 +1495,8 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) co
|
|||
cnode->set_abstract(op_run_info->base_op_run_info.abstract);
|
||||
SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
|
||||
DoOpGrad(op_run_info, cnode, op_run_info->out_value);
|
||||
CheckGraphDynamic(cnode, top_cell()->op_index());
|
||||
UpdateForwardTensorInfoInBpropGraph(op_run_info);
|
||||
(void)CheckGraphDynamic(cnode);
|
||||
}
|
||||
|
||||
void GradExecutor::AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
|
||||
|
@ -1771,54 +1796,43 @@ void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const {
|
|||
graph_executor->SetJitConfig(jit_config_dict);
|
||||
}
|
||||
|
||||
void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t node_idx,
|
||||
void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const AnfNodePtr &anf_node, const size_t node_idx,
|
||||
bool is_ms_function_node,
|
||||
const std::string &graph_phase) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto node_info = std::make_shared<DynamicDetectNodeInfo>();
|
||||
if (!is_ms_function_node) {
|
||||
node_info->prim = GetCNodePrimitive(cnode);
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
const auto &input_node = cnode->input(i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
|
||||
if (input_node->isa<ValueNode>()) {
|
||||
node_info->input_values[i] = GetValueNode(input_node);
|
||||
} else if (input_node->isa<CNode>()) {
|
||||
const auto &node_abs = input_node->abstract();
|
||||
auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash(), node_idx);
|
||||
node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs);
|
||||
} else {
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "input_node:" << input_node->fullname_with_scope()
|
||||
<< " is none of value node, cnode and parameter.";
|
||||
}
|
||||
const auto ¶m = input_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
node_info->input_param_infos[i] = param->param_info();
|
||||
}
|
||||
node_info->anf_node = anf_node;
|
||||
if (anf_node->isa<CNode>()) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
(void)std::transform(
|
||||
cnode->inputs().begin() + 1, cnode->inputs().end(), std::back_inserter(node_info->op_index_of_cnode_inputs),
|
||||
[this, node_idx](const AnfNodePtr &n) { return top_cell()->get_op_index_by_cnode_hash(n->hash(), node_idx); });
|
||||
}
|
||||
node_info->output_abs = cnode->abstract();
|
||||
} else {
|
||||
node_info->is_graph_node = true;
|
||||
node_info->graph_phase = graph_phase;
|
||||
}
|
||||
top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
|
||||
|
||||
if (anf_node->isa<CNode>()) {
|
||||
top_cell()->set_cnode_hash_with_op_index(anf_node->hash(), node_idx);
|
||||
}
|
||||
|
||||
(void)cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()].emplace_back(node_info);
|
||||
MS_LOG(DEBUG) << "Save node " << cnode->DebugString() << " firstly, node_idx: " << node_idx
|
||||
MS_LOG(DEBUG) << "Save node " << anf_node->DebugString() << " firstly, node_idx: " << node_idx
|
||||
<< ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase;
|
||||
}
|
||||
|
||||
bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
|
||||
bool GradExecutor::IsGraphDynamic(const AnfNodePtr &anf_node, const size_t node_idx, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (!is_cell_id_in_dynamic_detect_nodes_map_) {
|
||||
SaveDynamicDetectNodeInfoInFirstTime(cnode, node_idx, is_ms_function_node, graph_phase);
|
||||
SaveDynamicDetectNodeInfoInFirstTime(anf_node, node_idx, is_ms_function_node, graph_phase);
|
||||
// The net is regarded as a static net by default in the first time.
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Check node " << cnode->DebugString() << " node_idx: " << node_idx
|
||||
MS_LOG(DEBUG) << "Check node " << anf_node->DebugString() << " node_idx: " << node_idx
|
||||
<< ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase;
|
||||
const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()];
|
||||
if (node_idx >= dynamic_nodes.size()) {
|
||||
|
@ -1839,26 +1853,61 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t node_idx,
|
|||
return false;
|
||||
}
|
||||
|
||||
if (IsDynamicDetectNodeInfoChange(old_node_info, node_idx, cnode, top_cell())) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, node_idx: " << node_idx
|
||||
<< " is different, cnode: " << cnode->fullname_with_scope();
|
||||
return true;
|
||||
auto old_anf_node = old_node_info->anf_node;
|
||||
MS_EXCEPTION_IF_NULL(old_anf_node);
|
||||
if (anf_node->isa<CNode>()) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (IsDynamicDetectCnodeChange(old_node_info, cnode, node_idx, top_cell())) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, node_idx: " << node_idx
|
||||
<< " is different, cnode: " << cnode->fullname_with_scope();
|
||||
return true;
|
||||
}
|
||||
top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
|
||||
} else if (anf_node->isa<ValueNode>()) {
|
||||
if (!old_anf_node->isa<ValueNode>()) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, new node: " << anf_node->fullname_with_scope() << " is a value node,"
|
||||
<< " old node: " << old_anf_node->fullname_with_scope() << " is not a value node.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!IsValuePtrEqual(GetValueNode(old_anf_node), GetValueNode(anf_node))) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, new node: " << anf_node->fullname_with_scope()
|
||||
<< " old node: " << old_anf_node->fullname_with_scope() << " value is different.";
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
if (!anf_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf_node: " << anf_node->fullname_with_scope()
|
||||
<< " is none of value node, cnode and parameter.";
|
||||
}
|
||||
|
||||
if (!IsParamInfoEqual(anf_node, old_anf_node)) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, new node: " << anf_node->fullname_with_scope()
|
||||
<< " old node: " << old_anf_node->fullname_with_scope() << " is different.";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
|
||||
bool GradExecutor::CheckGraphDynamic(const AnfNodePtr &anf_node, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const {
|
||||
if (use_dynamic_shape_process_) {
|
||||
return;
|
||||
top_cell()->IncreaseOpIndex();
|
||||
return use_dynamic_shape_process_;
|
||||
}
|
||||
|
||||
use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase);
|
||||
const size_t node_idx = top_cell()->op_index();
|
||||
use_dynamic_shape_process_ = IsGraphDynamic(anf_node, node_idx, is_ms_function_node, graph_phase);
|
||||
top_cell()->IncreaseOpIndex();
|
||||
if (use_dynamic_shape_process_) {
|
||||
MS_LOG(DEBUG) << "Set use_dynamic_shape_process_: " << use_dynamic_shape_process_;
|
||||
cell_id_with_dynamic_detect_nodes_.clear();
|
||||
return use_dynamic_shape_process_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,13 +37,10 @@ using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
|
|||
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;
|
||||
|
||||
struct DynamicDetectNodeInfo {
|
||||
PrimitivePtr prim{nullptr};
|
||||
AbstractBasePtr output_abs{nullptr};
|
||||
AnfNodePtr anf_node;
|
||||
std::vector<size_t> op_index_of_cnode_inputs;
|
||||
bool is_graph_node{false};
|
||||
std::string graph_phase;
|
||||
mindspore::HashMap<size_t, std::pair<size_t, AbstractBasePtr>> input_cnode_info;
|
||||
mindspore::HashMap<size_t, ValuePtr> input_values;
|
||||
mindspore::HashMap<size_t, ParamInfoPtr> input_param_infos;
|
||||
};
|
||||
using DynamicDetectNodeInfoPtr = std::shared_ptr<DynamicDetectNodeInfo>;
|
||||
|
||||
|
@ -115,7 +112,7 @@ class GradExecutor {
|
|||
const std::vector<tensor::TensorPtr> &pre_tensors) const;
|
||||
void ClearRes();
|
||||
void WorkerJoin() { async_executor_->WorkerJoin(); }
|
||||
void CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node = false,
|
||||
bool CheckGraphDynamic(const AnfNodePtr &anf_node, bool is_ms_function_node = false,
|
||||
const std::string &graph_phase = "") const;
|
||||
|
||||
private:
|
||||
|
@ -193,9 +190,9 @@ class GradExecutor {
|
|||
AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
|
||||
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
|
||||
|
||||
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, size_t node_idx, bool is_ms_function_node,
|
||||
void SaveDynamicDetectNodeInfoInFirstTime(const AnfNodePtr &anf_node, const size_t node_idx, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const;
|
||||
bool IsGraphDynamic(const CNodePtr &cnode, size_t node_idx, bool is_ms_function_node,
|
||||
bool IsGraphDynamic(const AnfNodePtr &anf_node, const size_t node_idx, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const;
|
||||
bool grad_flag_{false};
|
||||
bool grad_is_running_{false};
|
||||
|
|
|
@ -359,8 +359,7 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
|
|||
|
||||
auto grad_exec_ptr = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
|
||||
MS_EXCEPTION_IF_NULL(grad_exec_ptr);
|
||||
grad_exec_ptr->CheckGraphDynamic(ms_function_cnode, op_run_info->op_index, true,
|
||||
op_run_info->base_op_run_info.op_name);
|
||||
(void)grad_exec_ptr->CheckGraphDynamic(ms_function_cnode, true, op_run_info->base_op_run_info.op_name);
|
||||
}
|
||||
|
||||
void MsFunction::SetMsFuncGraphParameters(const FuncGraphPtr &ms_func_graph) {
|
||||
|
|
|
@ -88,7 +88,6 @@ void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
|
|||
op_run_info->op_info += "-" + shape->ToString();
|
||||
}
|
||||
op_run_info->op_index = op_index_;
|
||||
++op_index_;
|
||||
}
|
||||
|
||||
void TopCellInfo::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile) {
|
||||
|
|
Loading…
Reference in New Issue