From 69280185c4a1cbf48c71617d753dca8535ed14e3 Mon Sep 17 00:00:00 2001 From: He Wei Date: Mon, 1 Aug 2022 15:16:00 +0800 Subject: [PATCH] Optimize pointer casting for compile framework Use raw pointers whenever possible to avoid shared_ptr overhead. --- .../ccsrc/frontend/optimizer/ad/dfunctor.cc | 42 +++-- .../ccsrc/frontend/optimizer/ad/kprim.cc | 20 +-- .../ccsrc/frontend/optimizer/expander.cc | 23 +-- .../frontend/optimizer/graph_transform.cc | 2 +- mindspore/ccsrc/frontend/optimizer/opt.cc | 76 ++++---- mindspore/ccsrc/frontend/optimizer/pattern.cc | 22 +-- mindspore/ccsrc/frontend/optimizer/py_pass.cc | 24 +-- .../frontend/optimizer/py_pass_manager.cc | 4 +- .../ccsrc/frontend/optimizer/recompute.cc | 12 +- .../pipeline/jit/compile_cache_manager.cc | 2 +- mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 8 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 31 ++-- .../ccsrc/pipeline/jit/pipeline_split.cc | 11 +- .../pipeline/jit/remove_value_node_dup.cc | 22 +-- .../jit/static_analysis/async_eval_result.cc | 11 +- .../pipeline/jit/static_analysis/evaluator.cc | 51 +++--- .../pipeline/jit/static_analysis/prim.cc | 162 +++++++++--------- .../jit/static_analysis/program_specialize.cc | 70 ++++---- .../jit/static_analysis/stack_frame.cc | 8 +- .../jit/static_analysis/static_analysis.cc | 26 +-- mindspore/ccsrc/pipeline/jit/validator.cc | 12 +- mindspore/core/abstract/abstract_function.cc | 6 +- mindspore/core/abstract/abstract_value.cc | 39 +++-- mindspore/core/abstract/abstract_value.h | 6 +- mindspore/core/abstract/analysis_context.cc | 6 +- mindspore/core/abstract/param_validator.cc | 4 +- mindspore/core/abstract/utils.cc | 14 +- mindspore/core/base/base_ref.cc | 6 +- mindspore/core/ir/anf.h | 6 +- mindspore/core/ir/dtype/ref.h | 7 +- mindspore/core/ir/func_graph.cc | 6 +- mindspore/core/ir/func_graph_cloner.cc | 18 +- mindspore/core/ir/func_graph_extends.cc | 2 +- mindspore/core/ir/graph_utils.cc | 61 +++---- mindspore/core/ir/graph_utils_extends.cc | 4 +- mindspore/core/ir/manager.cc | 4 +- mindspore/core/ir/value.h | 2 +- mindspore/core/ir/visitor.cc | 24 ++- mindspore/core/utils/ms_utils.h | 9 +- 39 files changed, 430 insertions(+), 433 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 7026103cfc4..a1ef17dc249 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -143,7 +143,7 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << "."; } mindspore::HashMap node_to_fg; - auto tuple_graphs = input->cast(); + auto tuple_graphs = input->cast_ptr(); for (size_t i = 1; i < tuple_graphs->size(); ++i) { auto graph = tuple_graphs->input(i); if (!IsValueNode(graph)) { @@ -191,7 +191,7 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) { if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) { MS_LOG(WARNING) << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation."; - auto output_cnode = node->cast(); + auto output_cnode = node->cast_ptr(); if (output_cnode->size() - 1 == 1) { return output_cnode->input(1); } @@ -217,16 +217,16 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) { return make_tuple; } if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - auto tuple_get_item = node->cast(); + auto tuple_get_item = node->cast_ptr(); auto inp = tuple_get_item->input(1); if (IsPrimitiveCNode(inp, prim::kPrimHookBackward) || IsPrimitiveCNode(inp, prim::kPrimCellBackwardHook)) { MS_LOG(WARNING) << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation."; constexpr size_t idx = 2; - auto v_node = tuple_get_item->input(idx)->cast(); + auto v_node = tuple_get_item->input(idx)->cast_ptr(); MS_EXCEPTION_IF_NULL(v_node); auto out_idx = GetValue(v_node->value()); - return inp->cast()->input(LongToSize(out_idx) + 1); + return inp->cast_ptr()->input(LongToSize(out_idx) + 1); } } return node; @@ -238,7 +238,7 @@ AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, con if (input_type == nullptr || !input_type->isa()) { return din; } - input_type = input_type->cast()->element(); + input_type = input_type->cast_ptr()->element(); MS_EXCEPTION_IF_NULL(input_type); if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) { return din; @@ -256,7 +256,7 @@ AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, con if (din_type == nullptr || !din_type->isa()) { return din; } - din_type = din_type->cast()->element(); + din_type = din_type->cast_ptr()->element(); MS_EXCEPTION_IF_NULL(din_type); if (din_type->type_id() != kNumberTypeComplex64 && din_type->type_id() != kNumberTypeComplex128) { return din; @@ -689,8 +689,8 @@ void DFunctor::MapValueObject() { AdjointPtr adjoint = nullptr; if (IsValueNode(node)) { // Primitive. - auto prim = GetValueNode(node); - if (GetValueNode(node) == prim::kPrimReturn || + auto prim = GetValuePtr(node); + if ((prim->Hash() == prim::kPrimReturn->hash() && prim->name() == prim::kPrimReturn->name()) || (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) || (prim->Hash() == prim::kPrimCellBackwardHook->Hash() && prim->name() == prim::kPrimCellBackwardHook->name())) { @@ -784,17 +784,15 @@ void DFunctor::BroadCastStopFlag() { while (need_cut_) { need_cut_ = false; for (auto &node : primal_graph_->nodes()) { - if (node->isa()) { - auto cnode = node->cast(); - if (!cnode->stop_gradient()) { - // Cut off the cnode only when it's not referred any more - if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) || - AllReferencesStopped(cnode)) { - MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << "."; - cnode->set_stop_gradient(true); - // The stop set changed, more cut required - need_cut_ = true; - } + auto cnode = dyn_cast(node); + if (cnode != nullptr && !cnode->stop_gradient()) { + // Cut off the cnode only when it's not referred any more + if (cnode->IsApply(prim::kPrimStopGradient) || cnode->IsApply(prim::kPrimUpdateState) || + AllReferencesStopped(cnode)) { + MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << "."; + cnode->set_stop_gradient(true); + // The stop set changed, more cut required + need_cut_ = true; } } } @@ -809,7 +807,7 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { } for (auto &kv : users) { auto &user = kv.first; - if (!user->isa() || !user->cast()->stop_gradient()) { + if (!user->isa() || !user->cast_ptr()->stop_gradient()) { return false; } } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index e6acc83bdd5..4a0be45e433 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -225,7 +225,7 @@ AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, c } else { python_ops_value = prim::GetPythonOps(iter->second.first); } - auto origin_cnode = origin_node->cast(); + auto origin_cnode = origin_node->cast_ptr(); MS_EXCEPTION_IF_NULL(origin_cnode); auto &origin_inputs = origin_cnode->inputs(); std::vector new_inputs{NewValueNode(python_ops_value)}; @@ -243,7 +243,7 @@ void ReplacePythonOps(const FuncGraphPtr &fg) { if (!node->isa()) { continue; } - auto cnode = node->cast(); + auto cnode = node->cast_ptr(); for (size_t i = 0; i < cnode->size(); ++i) { auto prim = GetCNodePrimitive(cnode->input(i)); if (prim == nullptr) { @@ -388,7 +388,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB if (prim->is_base()) { fn = GetBpropFunction(prim->name()); } else { - fn = prim->cast()->GetBpropFunction(); + fn = prim->cast_ptr()->GetBpropFunction(); if (py::isinstance(fn)) { fn = GetBpropFunction(prim->name()); } @@ -491,8 +491,8 @@ static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &mo auto output_cnode = output->cast(); if (output_cnode != nullptr) { // If output_cnode has the form like (make_tuple, x, y). - while (IsPrimitiveCNode(output_cnode, prim::kPrimDepend)) { - auto real_input = output_cnode->input(kRealInputIndexInDepend); + while (output_cnode->IsApply(prim::kPrimDepend)) { + const auto &real_input = output_cnode->input(kRealInputIndexInDepend); MS_EXCEPTION_IF_NULL(real_input); output_cnode = real_input->cast(); } @@ -555,7 +555,7 @@ void SetDumpFlag(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) { return; } auto attr = prim->GetAttr(kAttrDump); - if (attr != nullptr && attr->isa() && attr->cast()->value() == kValueTrue) { + if (attr != nullptr && attr->isa() && attr->cast_ptr()->value() == kValueTrue) { bprop_fg->set_flag(FUNC_GRAPH_FLAG_DUMP, true); } } @@ -660,7 +660,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr & // bprop_fg has been checked in caller if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { // Set bprop output as (env, dx, dy, dz, ...) - auto cbprop = bprop_fg->output()->cast(); + auto cbprop = bprop_fg->output()->cast_ptr(); auto &inputs = cbprop->inputs(); std::vector args; @@ -742,7 +742,7 @@ void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const Func const auto ¤t_primal_fg_params = current_primal_fg->parameters(); // The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}. for (size_t i = 0; i < current_primal_fg_params.size(); ++i) { - auto primal_parameter = dyn_cast(current_primal_fg_params[i]); + auto primal_parameter = dyn_cast_ptr(current_primal_fg_params[i]); MS_EXCEPTION_IF_NULL(primal_parameter); auto lifted = primal_parameter->template user_data(kLiftedUserDataKey); if (lifted == nullptr || !*lifted) { @@ -845,7 +845,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res if (cnode == users.end()) { MS_LOG(EXCEPTION) << "Fail to find cnode."; } - auto inputs_num = cnode->first->cast()->size() - 1; + auto inputs_num = cnode->first->cast_ptr()->size() - 1; auto func_graph = std::make_shared(); std::vector outputs; @@ -886,7 +886,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re if (cnode == users.end()) { MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString(); } - auto inputs_num = cnode->first->cast()->inputs().size() - 1; + auto inputs_num = cnode->first->cast_ptr()->inputs().size() - 1; auto effect_info = GetPrimEffectInfo(prim); // Don't add U or IO monad parameters as it will be added later. size_t monad_params_size = 0; diff --git a/mindspore/ccsrc/frontend/optimizer/expander.cc b/mindspore/ccsrc/frontend/optimizer/expander.cc index 32b78a1d223..0ab61eee772 100644 --- a/mindspore/ccsrc/frontend/optimizer/expander.cc +++ b/mindspore/ccsrc/frontend/optimizer/expander.cc @@ -33,22 +33,25 @@ namespace mindspore { /* namespace to support opt */ namespace opt { bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) { + static const std::map> op2attrs = { + {prim::kPrimBroadcastTo->name(), {kAttrShape}}, + {prim::kPrimReduceMax->name(), {kAttrKeepDims}}, + {prim::kPrimReduceMin->name(), {kAttrKeepDims}}, + {prim::kPrimReduceSum->name(), {kAttrKeepDims}}}; + auto todos = TopoSort(graph->get_return()); for (const auto &node : todos) { if (!node->isa() || !AnfUtils::IsRealKernel(node)) { continue; } auto primitive = GetCNodePrimitive(node); - if (!primitive || dyn_cast(primitive)) { + if (primitive == nullptr || primitive->isa()) { continue; } parallel::OperatorAttrs attrs; - std::map> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}}, - {prim::kPrimReduceMax->name(), {kAttrKeepDims}}, - {prim::kPrimReduceMin->name(), {kAttrKeepDims}}, - {prim::kPrimReduceSum->name(), {kAttrKeepDims}}}; - if (op2attrs.count(primitive->name()) != 0) { - for (auto &attr : op2attrs[primitive->name()]) { + auto iter = op2attrs.find(primitive->name()); + if (iter != op2attrs.end()) { + for (auto &attr : iter->second) { if (primitive->HasAttr(attr)) { (void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)}); } else { @@ -57,10 +60,10 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) { } } } - auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "")->cast(); - (void)new_prim->SetAttrs(primitive->attrs()); + auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), ""); + (void)new_prim->cast_ptr()->SetAttrs(primitive->attrs()); AnfNodePtrList inputs = {NewValueNode(new_prim)}; - auto cnode = dyn_cast(node); + auto cnode = dyn_cast_ptr(node); (void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend()); cnode->set_inputs(inputs); } diff --git a/mindspore/ccsrc/frontend/optimizer/graph_transform.cc b/mindspore/ccsrc/frontend/optimizer/graph_transform.cc index f8aa6c7a83f..a68b404b790 100644 --- a/mindspore/ccsrc/frontend/optimizer/graph_transform.cc +++ b/mindspore/ccsrc/frontend/optimizer/graph_transform.cc @@ -28,7 +28,7 @@ bool ContainSparseTensor(const abstract::AbstractBasePtr &abs) { return true; } if (abs->isa()) { - auto vec = abs->cast()->elements(); + auto vec = abs->cast_ptr()->elements(); return std::any_of(vec.begin(), vec.end(), ContainSparseTensor); } return false; diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc index 03191a971c1..815e8e53bf3 100644 --- a/mindspore/ccsrc/frontend/optimizer/opt.cc +++ b/mindspore/ccsrc/frontend/optimizer/opt.cc @@ -40,27 +40,20 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std: const std::vector &prims, const RenormAction &renorm_action, bool has_priority_pattern) { auto fn = [prims](const AnfNodePtr &node) -> bool { - if (!node->isa()) { + auto cnode = dyn_cast_ptr(node); + if (cnode == nullptr) { return false; } - - auto cnode = node->cast(); - auto inp0 = cnode->input(0); - auto prim0 = GetValueNode(inp0); - if (prim0 == nullptr) { + auto cnode_prim = GetValuePtr(cnode->input(0)); + if (cnode_prim == nullptr) { return false; } - - auto hash = prim0->Hash(); - auto const &name = prim0->name(); - for (auto &prim : prims) { - if (hash == prim->Hash() && name == prim->name()) { - return true; - } - } - return false; + auto hash = cnode_prim->Hash(); + const auto &name = cnode_prim->name(); + return std::any_of(prims.begin(), prims.end(), [&hash, &name](const PrimitivePtr &prim) { + return (prim->Hash() == hash) && (prim->name() == name); + }); }; - return std::make_shared(transform, name, fn, renorm_action, has_priority_pattern); } @@ -93,17 +86,15 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode return result; } -static bool isTraversable(const AnfNodePtr &node) { - if (node == nullptr) { - return false; - } +static inline bool isTraversable(const AnfNodePtr &node) { if (node->isa() || node->isa()) { return true; } - if (IsValueNode(node) || IsValueNode(node)) { - return true; - } - return false; + // FuncGraph or RefKey value node is traversable. + auto value_node = dyn_cast_ptr(node); + MS_EXCEPTION_IF_NULL(value_node); + const auto &value = value_node->value(); + return (value != nullptr) && (value->isa() || value->isa()); } static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, @@ -131,34 +122,38 @@ static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &n } static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque *todo, bool change) { - if (IsValueNode(node)) { - (*todo).emplace_back(GetValueNode(node)->output()); + auto fg = GetValuePtr(node); + if (fg != nullptr) { + (void)todo->emplace_back(fg->output()); } if (change) { - (*todo).emplace_back(node); + (void)todo->emplace_back(node); } else { - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo)); + auto cnode = dyn_cast_ptr(node); + if (cnode != nullptr) { + const auto &inputs = cnode->inputs(); + (void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend()); } } } static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque *todo, bool change, const SubstitutionPtr &substitution) { - if (IsValueNode(node)) { - (*todo).emplace_back(GetValueNode(node)->output()); + auto fg = GetValuePtr(node); + if (fg != nullptr) { + (void)todo->emplace_back(fg->output()); } // If there is a priority pattern in substitution, don't transform the new node, // otherwise some nodes may match the wrong patterns. if (change && substitution != nullptr && !substitution->has_priority_pattern_) { - (*todo).emplace_back(node); + (void)todo->emplace_back(node); } else { - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo)); + auto cnode = dyn_cast_ptr(node); + if (cnode != nullptr) { + const auto &inputs = cnode->inputs(); + (void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend()); } } } @@ -194,12 +189,11 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con FuncGraphManagerPtr manager = optimizer->manager(); auto seen = NewSeenGeneration(); std::deque todo; - todo.emplace_back(func_graph->output()); + (void)todo.emplace_back(func_graph->output()); bool changes = false; - auto &all_nodes = manager->all_nodes(); while (!todo.empty()) { - AnfNodePtr node = todo.front(); + AnfNodePtr node = std::move(todo.front()); todo.pop_front(); if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { @@ -373,13 +367,13 @@ bool SimpleRewriter::Run() { continue; } node->seen_ = seen; - auto cnode = node->cast(); + auto cnode = node->cast_ptr(); if (cnode != nullptr) { for (auto &input : cnode->inputs()) { add_todo(input); } } else { - auto fg = GetValueNode(node); + auto fg = GetValuePtr(node); if (fg != nullptr) { add_todo(fg->output()); } diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.cc b/mindspore/ccsrc/frontend/optimizer/pattern.cc index 96437c98a3f..2b6c16066f7 100644 --- a/mindspore/ccsrc/frontend/optimizer/pattern.cc +++ b/mindspore/ccsrc/frontend/optimizer/pattern.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ MatchResultPtr Call::match(const AnfNodePtr &node) { } MatchResultPtr res = std::make_shared(); // IsPrimitiveCNode - auto cnode = node->cast(); + auto cnode = node->cast_ptr(); MS_EXCEPTION_IF_NULL(cnode); // Check Primitive ValueNode if (prim_pattern_ != nullptr) { @@ -120,20 +120,16 @@ MatchResultPtr Any::match(const AnfNodePtr &node) { } MatchResultPtr Imm::match(const AnfNodePtr &node) { - if (!IsValueNode(node)) { + auto value_ptr = GetValuePtr(node); + if (value_ptr == nullptr) { return nullptr; } - // Check value - auto value_node = node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value_ptr = value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(value_ptr); - if (value_ptr->value() == value_) { - MatchResultPtr res = std::make_shared(); - res->add_entry(shared_from_base(), node); - return res; + if (value_ptr->value() != value_) { + return nullptr; } - return nullptr; + MatchResultPtr res = std::make_shared(); + res->add_entry(shared_from_base(), node); + return res; } AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) { diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index c236d2de981..d86ac851c87 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ bool IsTraversable(const AnfNodePtr &node) { AnfNodePtr BuildPrimitive(const PatternPtr &pattern) { // Build up AnfNode from primitive - auto prim_pattern = pattern->cast(); + auto prim_pattern = pattern->cast_ptr(); MS_EXCEPTION_IF_NULL(prim_pattern); PrimitivePyPtr prim = prim_pattern->matched_primitive(); MS_EXCEPTION_IF_NULL(prim); @@ -67,7 +67,7 @@ AnfNodePtr BuildPrimitive(const PatternPtr &pattern) { AnfNodePtr BuildNewTensor(const PatternPtr &pattern) { // Build a ValueNode from TensorPtr - auto new_tensor_pattern = pattern->cast(); + auto new_tensor_pattern = pattern->cast_ptr(); MS_EXCEPTION_IF_NULL(new_tensor_pattern); auto input_tensor = new_tensor_pattern->input_tensor(); MS_EXCEPTION_IF_NULL(input_tensor); @@ -76,7 +76,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern) { AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg, const FuncGraphPtr &top_graph) { - auto call_pattern = pattern->cast(); + auto call_pattern = pattern->cast_ptr(); MS_EXCEPTION_IF_NULL(call_pattern); auto prim = call_pattern->prim_value(); if (prim != nullptr) { @@ -88,7 +88,7 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP } AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) { - auto new_para_pattern = pattern->cast(); + auto new_para_pattern = pattern->cast_ptr(); MS_EXCEPTION_IF_NULL(new_para_pattern); if (!new_para_pattern->built()) { static int64_t parameter_id = 0; @@ -122,7 +122,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re } AnfNodePtr BuildImmNode(const PatternPtr &pattern) { - auto imm_pattern = pattern->cast(); + auto imm_pattern = pattern->cast_ptr(); MS_EXCEPTION_IF_NULL(imm_pattern); auto value = imm_pattern->value(); auto scalar_value_ptr = std::make_shared(value); @@ -134,7 +134,7 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr auto target_node = res->get_node(pattern); if (target_node != nullptr) { // If pattern is NewParameter, check whether it shouldn't last and is not built - auto new_para = pattern->cast(); + auto new_para = pattern->cast_ptr(); if (new_para == nullptr || new_para->should_last() || new_para->built()) { return target_node; } @@ -218,20 +218,20 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, const string ¶m_name, MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr."; } MS_EXCEPTION_IF_NULL(param); - auto param_node = param->cast(); + auto param_node = param->cast_ptr(); MS_EXCEPTION_IF_NULL(param_node); param_node->set_default_param(param_value); } void Reset(const PatternPtr &pattern) { if (pattern->isa()) { - auto prim_pattern = pattern->cast(); + auto prim_pattern = pattern->cast_ptr(); prim_pattern->reset(); } else if (pattern->isa()) { - auto new_param_pattern = pattern->cast(); + auto new_param_pattern = pattern->cast_ptr(); new_param_pattern->reset(); } else if (pattern->isa()) { - auto call_with_pattern = pattern->cast(); + auto call_with_pattern = pattern->cast_ptr(); for (const auto &sub_pattern : call_with_pattern->inputs()) { Reset(sub_pattern); } @@ -257,7 +257,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) MS_EXCEPTION_IF_NULL(dst_pattern_); if (src_pattern_ == nullptr) { // Add NewParameter - auto new_para_pattern = dst_pattern_->cast(); + auto new_para_pattern = dst_pattern_->cast_ptr(); if (new_para_pattern == nullptr) { MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null."; } diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc index 207aa44e1d7..99147a62615 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,7 +81,7 @@ void PyPassManager::GenNewParameter(const PatternPtr ¶meter) { auto cur_pg = GetPassGroup(Phase::OPT); MS_EXCEPTION_IF_NULL(cur_pg); cur_pg->SetRunOnlyOnce(true); - auto new_para_pattern = parameter->cast(); + auto new_para_pattern = parameter->cast_ptr(); MS_EXCEPTION_IF_NULL(new_para_pattern); auto pass_name = new_para_pattern->para_name(); new_para_pattern->set_last(true); diff --git a/mindspore/ccsrc/frontend/optimizer/recompute.cc b/mindspore/ccsrc/frontend/optimizer/recompute.cc index 680deae476d..45ec1a16d3a 100644 --- a/mindspore/ccsrc/frontend/optimizer/recompute.cc +++ b/mindspore/ccsrc/frontend/optimizer/recompute.cc @@ -59,7 +59,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) { ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); + auto cnode = node->cast_ptr(); if (cnode == nullptr) { return nullptr; } @@ -222,7 +222,7 @@ bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) { return has_grad_inputs_map->find(node)->second; } - auto cnode = node->cast(); + auto cnode = node->cast_ptr(); if (cnode == nullptr) { (void)has_grad_inputs_map->emplace(node, false); return false; @@ -230,7 +230,7 @@ bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap const auto &inputs = cnode->inputs(); for (size_t i = 0; i < inputs.size(); ++i) { // For the pipeline split case, the forward pass may depend on the backward pass. - if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && i == kDependAttachNodeIndex) { + if (cnode->IsApply(prim::kPrimDepend) && i == kDependAttachNodeIndex) { continue; } if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) { @@ -320,7 +320,7 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector &o std::vector tuple_getitem_output_nodes; GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes); for (const auto &output_node : tuple_getitem_output_nodes) { - auto output_cnode = output_node->cast(); + auto output_cnode = output_node->cast_ptr(); MS_EXCEPTION_IF_NULL(output_cnode); output_cnode->AddAttr(kAttrRecompute, MakeValue(true)); } @@ -362,7 +362,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod for (size_t i = 0; i < origin_node->size(); ++i) { auto input = origin_node->input(i); if (i == 0 && IsPrimitive(input, prim::kPrimAllGather)) { - auto prim = GetValueNode(input); + auto prim = GetValuePtr(input); auto instance_name = prim->instance_name(); bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos; int64_t fusion_id = prim->HasAttr(kAttrFusion) ? GetValue(prim->GetAttr(kAttrFusion)) : 0; @@ -426,7 +426,7 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSe for (const auto &target_node : target_nodes) { MS_EXCEPTION_IF_NULL(target_node); MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input"; - auto target_cnode = target_node->cast(); + auto target_cnode = target_node->cast_ptr(); MS_EXCEPTION_IF_NULL(target_cnode); std::vector new_target_inputs; for (const auto &input : target_cnode->inputs()) { diff --git a/mindspore/ccsrc/pipeline/jit/compile_cache_manager.cc b/mindspore/ccsrc/pipeline/jit/compile_cache_manager.cc index e44aa5eb839..9e89183b1ec 100644 --- a/mindspore/ccsrc/pipeline/jit/compile_cache_manager.cc +++ b/mindspore/ccsrc/pipeline/jit/compile_cache_manager.cc @@ -278,7 +278,7 @@ FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr & // The value of attr "shared_name" will changed every time. auto cnodes = fg->GetOrderedCnodes(); for (const auto &cnode : cnodes) { - auto prim = GetValueNode(cnode->input(0)); + auto prim = GetValuePtr(cnode->input(0)); if (prim != nullptr && prim->HasAttr("shared_name")) { prim->set_attr("shared_name", MakeValue(queue_name)); break; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 533af8fd1e2..882393030d4 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -22,6 +22,7 @@ #include #include "ir/param_info.h" +#include "ir/value.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/parse.h" #include "include/common/utils/python_adapter.h" @@ -31,6 +32,7 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass/symbol_resolver.h" #include "include/common/debug/anf_dump_utils.h" +#include "utils/log_adapter.h" namespace mindspore { namespace parse { @@ -452,7 +454,7 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object namespace_obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); auto new_namespace = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj); - std::string attr_as_string = GetValueNode(attr)->value(); + const std::string &attr_as_string = GetValuePtr(attr)->value(); auto new_symbol = std::make_shared(attr_as_string); MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString(); @@ -489,7 +491,9 @@ AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py: } else if (count_msclass == sequence_size) { // Resolve MsClass instances. for (size_t i = 0; i < sequence_size; ++i) { - auto attr_str = GetValue(GetValueNode(attr)); + auto attr_str_ptr = GetValuePtr(attr); + MS_EXCEPTION_IF_NULL(attr_str_ptr); + const auto &attr_str = attr_str_ptr->value(); auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, get_attr_node); (void)inputs.emplace_back(res); } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index c78636afc5b..ad9337639d7 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -214,7 +214,7 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) { return; } - auto abs_sequence = abs->cast(); + auto abs_sequence = abs->cast_ptr(); if (abs_sequence != nullptr) { const auto &elements = abs_sequence->elements(); for (auto &ele : elements) { @@ -223,7 +223,7 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) { return; } - auto abs_dict = abs->cast(); + auto abs_dict = abs->cast_ptr(); if (abs_dict != nullptr) { const auto &elements = abs_dict->elements(); for (auto &ele : elements) { @@ -590,16 +590,15 @@ void GraphExecutorPy::GetWeightInfo( auto x = root_node->input(1); MS_EXCEPTION_IF_NULL(x); if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) { - weight_name = weight_node->cast()->input(1)->cast()->name(); + weight_name = weight_node->cast_ptr()->input(1)->cast_ptr()->name(); } else { - auto para = weight_node->cast(); + auto para = weight_node->cast_ptr(); MS_EXCEPTION_IF_NULL(para); weight_name = para->name(); } // find the fakequant from input int64_t count = 0; const int64_t max_depth = 5; - CNodePtr cnode = nullptr; auto is_quant_cnode = [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) || @@ -610,7 +609,7 @@ void GraphExecutorPy::GetWeightInfo( if (count >= max_depth) { break; } - cnode = x->cast(); + auto cnode = x->cast_ptr(); if (cnode == nullptr || cnode->size() <= 1) { break; } @@ -624,9 +623,9 @@ void GraphExecutorPy::GetWeightInfo( if (!is_quant_cnode(x)) { return; } - cnode = x->cast(); + auto cnode = x->cast_ptr(); constexpr size_t expect_input_size = 4; - if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != expect_input_size) { + if (cnode == nullptr || cnode->IsApply(prim::kPrimLoad) || cnode->size() != expect_input_size) { return; } const size_t fakequant_index = 2; @@ -636,18 +635,18 @@ void GraphExecutorPy::GetWeightInfo( } std::string fakequant_min_node_name; if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) { - fakequant_min_node_name = fakequant_min_node->cast()->input(1)->cast()->name(); + fakequant_min_node_name = fakequant_min_node->cast_ptr()->input(1)->cast_ptr()->name(); } else { - auto param = fakequant_min_node->cast(); + auto param = fakequant_min_node->cast_ptr(); MS_EXCEPTION_IF_NULL(param); fakequant_min_node_name = param->name(); } - auto quant_op_value = cnode->input(0)->cast()->value(); + const auto &quant_op_value = cnode->input(0)->cast_ptr()->value(); MS_EXCEPTION_IF_NULL(quant_op_value); if (!quant_op_value->isa()) { return; } - auto quant_op = quant_op_value->cast(); + auto quant_op = quant_op_value->cast_ptr(); (*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name); } @@ -677,7 +676,7 @@ std::map> GraphExecut } auto weight = root_node->input(weight_index); if (!is_quant_cnode(weight)) { - auto tuple_node = weight->cast(); + auto tuple_node = weight->cast_ptr(); if (tuple_node != nullptr) { auto fake_node = tuple_node->input(1); if (!is_quant_cnode(fake_node)) { @@ -688,7 +687,7 @@ std::map> GraphExecut } } // get parameter weight's name - auto cnode = weight->cast(); + auto cnode = weight->cast_ptr(); MS_EXCEPTION_IF_NULL(cnode); auto weight_node = cnode->input(weight_index); if (!weight_node->isa() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) { @@ -1207,7 +1206,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef // Maybe some default parameter for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) { MS_EXCEPTION_IF_NULL(graph_params[i]); - auto param_ptr = (graph_params[i])->cast(); + auto param_ptr = (graph_params[i])->cast_ptr(); MS_EXCEPTION_IF_NULL(param_ptr); if (!param_ptr->has_default()) { MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; @@ -1346,7 +1345,7 @@ void GraphExecutorPy::UpdataParamNodeDefaultInput( auto ¶ms = func_graph->parameters(); for (const auto ¶m : params) { MS_EXCEPTION_IF_NULL(param); - auto param_cast = param->cast(); + auto param_cast = param->cast_ptr(); MS_EXCEPTION_IF_NULL(param_cast); auto iter = params_value.find(param_cast->name()); if (iter != params_value.end()) { diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc index 201602637a7..5403ce4e404 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc @@ -77,8 +77,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num static bool HasVirtualDataset(const std::vector &all_nodes) { for (auto &node : all_nodes) { - auto cnode = node->cast(); - if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) { + if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) { return true; } } @@ -96,7 +95,7 @@ static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const F CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); MS_EXCEPTION_IF_NULL(tuple_get_item); tuple_get_item->set_scope(node->scope()); - auto input_abstract_tuple = node->abstract()->cast(); + auto input_abstract_tuple = node->abstract()->cast_ptr(); MS_EXCEPTION_IF_NULL(input_abstract_tuple); auto tuple_get_item_abstract = input_abstract_tuple->elements()[index]; MS_EXCEPTION_IF_NULL(tuple_get_item_abstract); @@ -134,7 +133,7 @@ static std::set FindForwardGraph(const FuncGraphPtr &root, const s } std::set input_parameters; for (auto &anf_param : root->parameters()) { - auto param = anf_param->cast(); + auto param = anf_param->cast_ptr(); if (!param->has_default()) { (void)input_parameters.insert(anf_param); } @@ -143,7 +142,7 @@ static std::set FindForwardGraph(const FuncGraphPtr &root, const s auto node_users_map = root->manager()->node_users(); auto node_users = node_users_map[input_parameter]; for (auto node_user : node_users) { - auto cnode = node_user.first->cast(); + auto cnode = node_user.first->cast_ptr(); if (IsValueNode(cnode->inputs()[0]) || (IsValueNode(cnode->inputs()[0]) && !root->has_flag(parallel::kTraining))) { (void)graph_sets.insert(cnode->func_graph()); @@ -155,7 +154,7 @@ static std::set FindForwardGraph(const FuncGraphPtr &root, const s if (!node->isa()) { continue; } - auto cnode = node->cast(); + auto cnode = node->cast_ptr(); if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode(cnode->input(0))) { continue; } diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc index 93eafb2b743..60a2811e1d1 100644 --- a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,13 @@ namespace mindspore { namespace pipeline { +static inline bool IsSameValue(Value *v1, Value *v2) { + if (v1->isa() && v2->isa()) { + return static_cast(v1)->ValueEqual(*(static_cast(v2))); + } + return *v1 == *v2; +} + void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, HashValue *const hash_value) { MS_EXCEPTION_IF_NULL(manager); @@ -35,7 +42,7 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has if (IsValueNode(node)) { return; } - const auto &to_check_value = GetValueNode(node); + auto to_check_value = GetValuePtr(node); MS_EXCEPTION_IF_NULL(to_check_value); // Calculate hash value. @@ -62,20 +69,13 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has if (v == node) { return; } - const auto &existed_value = GetValueNode(v); + auto existed_value = GetValuePtr(v); MS_EXCEPTION_IF_NULL(existed_value); - auto equal = [&]() -> bool { - if (existed_value->isa() && to_check_value->isa()) { - return existed_value->cast()->ValueEqual(*(to_check_value->cast())); - } - return *existed_value == *to_check_value; - }; - if (equal()) { + if (IsSameValue(existed_value, to_check_value)) { (void)manager->Replace(node, v); return; } } - // Meet for the first time, append node to bucket. bucket.emplace_back(node); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc index 3fa597b7dcb..e721de3b2b4 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2021-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -179,8 +179,9 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const const std::size_t offset) { if (abs->isa()) { return abs->cast(); - } else if (abs->isa()) { - const auto &abs_seq = abs->cast(); + } + if (abs->isa()) { + auto abs_seq = abs->cast_ptr(); MS_EXCEPTION_IF_NULL(abs_seq); const auto &elements = abs_seq->elements(); if (offset >= index.size()) { @@ -190,12 +191,12 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const MS_LOG(EXCEPTION) << "At offset" << offset << ", elements size of AsyncAbstract result: " << abs->ToString() << " is less than or equal to index: " << index[offset]; } - const auto &resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1); + auto resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1); if (!resolved->isa()) { MS_LOG(EXCEPTION) << "AsyncAbstract result cannot be resolved to AbstractFuncAtom, but: " << resolved->ToString(); } MS_LOG(DEBUG) << "Return abstract: " << resolved->ToString(); - return resolved->cast(); + return resolved; } MS_LOG(EXCEPTION) << "AsyncAbstract cannot resolved to AbstractFuncAtom or AbstractSeqeunce, but: " << abs->ToString(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index a0547adae32..1deb3b43139 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -61,7 +61,7 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, } // namespace bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) { - auto new_sequence = dyn_cast(arg); + auto new_sequence = dyn_cast_ptr(arg); if (new_sequence != nullptr && new_sequence->sequence_nodes() != nullptr && new_sequence->size() != 0) { static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance(); auto prev_result = cache_mgr.GetValue(conf); @@ -69,7 +69,7 @@ bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) return false; } auto prev_abs = prev_result->abstract(); - auto old_sequence = dyn_cast(prev_abs); + auto old_sequence = dyn_cast_ptr(prev_abs); if (old_sequence != nullptr && (old_sequence->sequence_nodes() == nullptr || old_sequence->sequence_nodes()->empty()) && *arg == *prev_abs) { MS_LOG(DEBUG) << "Always eval"; @@ -255,7 +255,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr << fg->ToString() << "();"; } - auto func_graph_evaluator = dyn_cast(shared_from_base()); + auto func_graph_evaluator = mindspore::cast(this); if (func_graph_evaluator != nullptr) { if (engine->root_func_graph() == func_graph_evaluator->func_graph()) { engine->set_root_context(context); @@ -551,7 +551,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg AbstractBasePtrList bparams; bparams.push_back(SensitivityTransform(primal_func_)); // Check if primal func graph has the primitive returned sparse result in its bprop(). - auto real_primal_func = dyn_cast(primal_func_); + auto real_primal_func = dyn_cast_ptr(primal_func_); MS_EXCEPTION_IF_NULL(real_primal_func); FuncGraphPtr primal_func_graph = real_primal_func->func_graph(); MS_EXCEPTION_IF_NULL(primal_func_graph); @@ -625,7 +625,7 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_ MS_LOG(EXCEPTION) << "The orig_abs should be AbstractTensor when axis is " << *axis << ", but got a " << orig_abs->ToString() << "."; } - ShapeVector orig_shape = dyn_cast(orig_abs->BuildShape())->shape(); + ShapeVector orig_shape = dyn_cast_ptr(orig_abs->BuildShape())->shape(); int shape_len = SizeToInt(orig_shape.size()); if (*axis < -shape_len || *axis >= shape_len) { MS_LOG(EXCEPTION) << "The axis: " << *axis << " in 'in_axes' is out of bounds for array of dimension [" @@ -649,22 +649,21 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, const ValuePtr &in_axes, int *axis_size) { MS_EXCEPTION_IF_NULL(physical_view_abs); MS_EXCEPTION_IF_NULL(in_axes); - auto physical_view_abs_sequence = dyn_cast(physical_view_abs); + auto physical_view_abs_sequence = dyn_cast_ptr(physical_view_abs); if (physical_view_abs_sequence != nullptr) { AbstractBasePtrList abs_list = physical_view_abs_sequence->elements(); AbstractBasePtrList logical_view_abs_list; - auto in_axes_seq = dyn_cast(in_axes); + auto in_axes_seq = dyn_cast_ptr(in_axes); int index = 0; - (void)std::transform( - abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list), - [&axis_size, &index, &in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr { - ValuePtr sub_in_axes = in_axes; - if (in_axes->isa()) { - sub_in_axes = (*in_axes_seq)[index]; - index++; - } - return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size); - }); + (void)std::transform(abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list), + [&axis_size, &index, in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr { + ValuePtr sub_in_axes = in_axes; + if (in_axes->isa()) { + sub_in_axes = (*in_axes_seq)[index]; + index++; + } + return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size); + }); if (physical_view_abs->isa()) { return std::make_shared(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes()); } @@ -672,7 +671,7 @@ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, cons } ValuePtr in_axis = in_axes; if (in_axis->isa()) { - int axis = dyn_cast(in_axis)->value(); + int axis = dyn_cast_ptr(in_axis)->value(); auto logical_view_abs = ReduceDim(&axis, physical_view_abs, axis_size); return logical_view_abs; } @@ -689,7 +688,7 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s AbstractBasePtr out_abs = nullptr; ShapeVector orig_shape; if (orig_abs->isa()) { - orig_shape = dyn_cast(orig_abs->BuildShape())->shape(); + orig_shape = dyn_cast_ptr(orig_abs->BuildShape())->shape(); } int shape_len = SizeToInt(orig_shape.size() + 1); if (*axis < -shape_len || *axis >= shape_len) { @@ -713,11 +712,11 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, const ValuePtr &out_axes, int axis_size) { MS_EXCEPTION_IF_NULL(logical_view_abs); - auto logical_view_abs_sequence = dyn_cast(logical_view_abs); + auto logical_view_abs_sequence = dyn_cast_ptr(logical_view_abs); if (logical_view_abs_sequence != nullptr) { AbstractBasePtrList logical_view_abs_list = logical_view_abs_sequence->elements(); AbstractBasePtrList physical_view_abs_list; - auto out_axes_seq = dyn_cast(out_axes); + auto out_axes_seq = dyn_cast_ptr(out_axes); if (out_axes_seq != nullptr) { if (logical_view_abs_list.size() != out_axes_seq->size()) { MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': " @@ -727,7 +726,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons int index = 0; (void)std::transform( logical_view_abs_list.begin(), logical_view_abs_list.end(), std::back_inserter(physical_view_abs_list), - [&axis_size, &index, &out_axes_seq, out_axes](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { + [&axis_size, &index, out_axes_seq, out_axes](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { ValuePtr sub_out_axes = out_axes; if (out_axes->isa()) { sub_out_axes = (*out_axes_seq)[index]; @@ -737,7 +736,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons return GetPhysicalViewAbs(arg_spec, sub_out_axes, axis_size); } if (sub_out_axes->isa()) { - int axis = dyn_cast(sub_out_axes)->value(); + int axis = dyn_cast_ptr(sub_out_axes)->value(); return ExtendDim(&axis, arg_spec, axis_size); } else if (sub_out_axes->isa()) { return arg_spec; @@ -763,7 +762,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons } int axis = 0; - auto axis_int_ptr = dyn_cast(sub_out_axes); + auto axis_int_ptr = dyn_cast_ptr(sub_out_axes); if (axis_int_ptr != nullptr) { axis = LongToInt(axis_int_ptr->value()); } else { @@ -787,9 +786,9 @@ EvalResultPtr VmapEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList & int axis_size = -1; int index = 0; auto in_axes = in_axes_; - auto in_axes_seq = dyn_cast(in_axes); + auto in_axes_seq = dyn_cast_ptr(in_axes); (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list), - [&axis_size, &index, &in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr { + [&axis_size, &index, in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); AbstractBasePtr abs = conf->ObtainEvalResult()->abstract(); // Drop the side effect tag parameters, because it has no mapping axis. diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index c4454dd316f..98e196f0e28 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -61,7 +61,7 @@ std::pair InterpretAbstractBoolChecker(const AbstractBasePtr &cond) auto value = cond->BuildValue(); if (value->isa()) { is_interpret = true; - auto interpreted_obj = value->cast(); + auto interpreted_obj = value->cast_ptr(); py::object obj = interpreted_obj->obj(); constexpr char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse"; constexpr char PYTHON_MOD_CHECK_OBJ_BOOL[] = "check_obj_bool"; @@ -114,10 +114,10 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt }); // Do undetermined infer firstly. - auto do_signature = prim_->cast(); + auto do_signature = prim_->cast_ptr(); MS_EXCEPTION_IF_NULL(do_signature); auto &func = do_signature->function(); - auto do_signature_func = dyn_cast(func); + auto do_signature_func = dyn_cast_ptr(func); if (do_signature_func != nullptr) { if (prims_to_skip_undetermined_infer.find(do_signature_func->name()) == prims_to_skip_undetermined_infer.end()) { auto ret_abstract = EvalUndeterminedArgs(args_spec_list); @@ -168,11 +168,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) { MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]); if (specialize_args_before_unpack[index]->isa()) { - auto arg_tuple = specialize_args_before_unpack[index]->cast(); + auto arg_tuple = specialize_args_before_unpack[index]->cast_ptr(); std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; }); } else if (specialize_args_before_unpack[index]->isa()) { - auto arg_dict = specialize_args_before_unpack[index]->cast(); + auto arg_dict = specialize_args_before_unpack[index]->cast_ptr(); auto dict_elems = arg_dict->elements(); (void)std::transform( dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args), @@ -197,9 +197,9 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; } - auto unpack_graph = prim_->cast(); + auto unpack_graph = prim_->cast_ptr(); MS_EXCEPTION_IF_NULL(unpack_graph); - auto out_node = out_conf->node()->cast(); + auto out_node = out_conf->node()->cast_ptr(); MS_EXCEPTION_IF_NULL(out_node); const auto &out_node_inputs = out_node->inputs(); if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) { @@ -219,11 +219,11 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt MS_LOG(EXCEPTION) << "args_spec_list can't be empty."; } MS_EXCEPTION_IF_NULL(args_spec_list[0]); - auto fn = args_spec_list[0]->cast(); + auto fn = args_spec_list[0]->cast_ptr(); if (fn == nullptr) { MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); } - auto real_fn = fn->cast(); + auto real_fn = fn->cast_ptr(); MS_EXCEPTION_IF_NULL(real_fn); FuncGraphPtr forward_graph = real_fn->func_graph(); MS_EXCEPTION_IF_NULL(forward_graph); @@ -255,14 +255,14 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac MS_EXCEPTION_IF_NULL(func_graph); AnfNodePtr target_node = source_node; if (node_type->isa()) { - auto x = node_type->cast(); + auto x = node_type->cast_ptr(); if (x->element()->BuildType()->isa()) { auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); MS_EXCEPTION_IF_NULL(cast); target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type}); } } else if (node_type->isa()) { - auto x = node_type->cast(); + auto x = node_type->cast_ptr(); auto &items = x->elements(); std::vector nodes; nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); @@ -276,7 +276,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac } target_node = func_graph->NewCNode(nodes); } else if (node_type->isa()) { - auto x = node_type->cast(); + auto x = node_type->cast_ptr(); auto &items = x->elements(); std::vector dict_key_nodes; std::vector dict_value_nodes; @@ -293,7 +293,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(dict_key_nodes)), func_graph->NewCNode(std::move(dict_value_nodes))}); } else if (node_type->isa()) { - auto x = node_type->cast(); + auto x = node_type->cast_ptr(); std::string kwarg_key = x->get_key(); AnfNodePtr kwarg_value_node = func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node}); @@ -336,7 +336,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph()); if (new_node->isa()) { - auto new_cnode = new_node->cast(); + auto new_cnode = new_node->cast_ptr(); new_cnode->CloneCNodeInfo(out_node); } return engine->ForwardConfig(out_conf, fn_conf); @@ -439,7 +439,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_conver if (dyn_shape_value) { for (size_t i = 0; i < len; i++) { if (!res[i].contains(py::str(ATTR_SHAPE_VALUE))) { - auto const_abstract_value = arg_tuple->elements()[i]->cast(); + auto const_abstract_value = arg_tuple->elements()[i]->cast_ptr(); MS_EXCEPTION_IF_NULL(const_abstract_value); auto const_value = const_abstract_value->BuildValue(); MS_EXCEPTION_IF_NULL(const_value); @@ -547,7 +547,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert if (shape_value) { for (size_t i = 0; i < len; i++) { if (!res[i].contains(py::str(ATTR_SHAPE_VALUE))) { - auto const_abstract_value = arg_list->elements()[i]->cast(); + auto const_abstract_value = arg_list->elements()[i]->cast_ptr(); MS_EXCEPTION_IF_NULL(const_abstract_value); auto const_value = const_abstract_value->BuildValue(); MS_EXCEPTION_IF_NULL(const_value); @@ -565,7 +565,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert } void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) { - auto arg_tensor = dyn_cast(abs_base); + auto arg_tensor = dyn_cast_ptr(abs_base); MS_EXCEPTION_IF_NULL(dic); MS_EXCEPTION_IF_NULL(arg_tensor); if (only_convert_value) { @@ -601,16 +601,16 @@ py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_ return py::none(); } if (prim->isa()) { - auto do_sig_prim = prim->cast(); + auto do_sig_prim = prim->cast_ptr(); auto value = do_sig_prim->function(); if (!value->isa()) { return py::none(); } - auto prim_py = value->cast(); + auto prim_py = value->cast_ptr(); return prim_py->GetPyObj(); } if (prim->isa()) { - auto prim_py = prim->cast(); + auto prim_py = prim->cast_ptr(); return prim_py->GetPyObj(); } return py::none(); @@ -621,7 +621,7 @@ bool IsCallInstance(const PartialAbstractClosurePtr &partial_abs) { if (!fn->isa()) { return false; } - auto abs_prim = fn->cast(); + auto abs_prim = fn->cast_ptr(); auto prim = abs_prim->prim(); if (prim->name() == prim::kPrimCallInstance->name()) { return true; @@ -643,14 +643,14 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict * auto value = args[0]->BuildValue(); MS_EXCEPTION_IF_NULL(value); if (IsCallInstance(partial_abs)) { - auto value_obj = value->cast(); + auto value_obj = value->cast_ptr(); if (value_obj != nullptr) { (*dic)[ATTR_DTYPE] = std::make_shared(); (*dic)[ATTR_VALUE] = value_obj->obj(); return; } } - auto value_obj = value->cast(); + auto value_obj = value->cast_ptr(); if (value_obj != nullptr) { (*dic)[ATTR_DTYPE] = std::make_shared(); (*dic)[ATTR_VALUE] = value_obj->obj(); @@ -708,35 +708,35 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv } else if (abs_base->isa()) { return AbstractListToPython(abs_base, only_convert_value); } else if (abs_base->isa()) { - auto arg_slice = dyn_cast(abs_base); + auto arg_slice = dyn_cast_ptr(abs_base); ShapeVector shape; dic[ATTR_SHAPE] = shape; dic[ATTR_DTYPE] = arg_slice->BuildType(); dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue()); } else if (abs_base->isa()) { - auto arg = dyn_cast(abs_base); + auto arg = dyn_cast_ptr(abs_base); dic[ATTR_SHAPE] = arg->shape()->shape(); dic[ATTR_DTYPE] = arg->BuildType(); dic[ATTR_VALUE] = BuildValue(arg->BuildValue()); } else if (abs_base->isa()) { - auto arg = dyn_cast(abs_base); + auto arg = dyn_cast_ptr(abs_base); AbstractBasePtrList sparse_shape = arg->shape()->elements(); ShapeVector sparse_shape_vector; (void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector), [](const AbstractBasePtr &e) -> int64_t { - ValuePtr value = e->cast()->BuildValue(); + ValuePtr value = e->cast_ptr()->BuildValue(); return GetValue(value); }); dic[ATTR_SHAPE] = sparse_shape_vector; dic[ATTR_DTYPE] = arg->BuildType(); dic[ATTR_VALUE] = BuildValue(arg->BuildValue()); } else if (abs_base->isa()) { - auto arg = dyn_cast(abs_base); + auto arg = dyn_cast_ptr(abs_base); AbstractBasePtrList sparse_shape = arg->shape()->elements(); ShapeVector sparse_shape_vector; (void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector), [](const AbstractBasePtr &e) -> int64_t { - ValuePtr value = e->cast()->BuildValue(); + ValuePtr value = e->cast_ptr()->BuildValue(); return GetValue(value); }); dic[ATTR_SHAPE] = sparse_shape_vector; @@ -753,7 +753,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv } else if (abs_base->isa()) { ConvertAbstractFunctionToPython(abs_base, &dic); } else if (abs_base->isa()) { - auto arg = dyn_cast(abs_base); + auto arg = dyn_cast_ptr(abs_base); dic[ATTR_SHAPE] = py::none(); dic[ATTR_DTYPE] = arg->BuildType(); dic[ATTR_VALUE] = py::none(); @@ -804,7 +804,7 @@ void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBa << "]'s attribute[output_num]:" << output_num << " not matches the infer result " << res_spec->ToString(); } else if (res_spec->isa() && - (res_spec->cast()->size() != LongToSize(output_num))) { + (res_spec->cast_ptr()->size() != LongToSize(output_num))) { MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString() << "]'s attribute[output_num]:" << output_num << " not matches the infer result " << res_spec->ToString(); @@ -833,7 +833,7 @@ void SetShapeValue(const AbstractBasePtr &tensor, const py::object &output) { if (!converted) { MS_LOG(EXCEPTION) << "Convert shape max value data failed"; } - auto abs_tensor = dyn_cast(tensor); + auto abs_tensor = dyn_cast_ptr(tensor); abs_tensor->set_value_range(min_value, max_value); } if (!output.contains(py::str(ATTR_SHAPE_VALUE))) { @@ -849,7 +849,7 @@ void SetShapeValue(const AbstractBasePtr &tensor, const py::object &output) { if (!converted) { MS_LOG(EXCEPTION) << "Convert shape value data failed"; } - auto abs_tensor = dyn_cast(tensor); + auto abs_tensor = dyn_cast_ptr(tensor); abs_tensor->set_shape_value(shape_value); } @@ -1023,7 +1023,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &, MS_EXCEPTION_IF_NULL(res_spec); if (res_spec->isa()) { // Replace to tensor constant node in specialize - auto res_tensor = res_spec->cast(); + auto res_tensor = res_spec->cast_ptr(); res_tensor->set_value(converted_ret); } return std::make_shared(res_spec, std::make_shared(added_attrs)); @@ -1248,8 +1248,8 @@ EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Ab // Check if all arguments are scalar type. MS_EXCEPTION_IF_NULL(arg); if (arg->isa()) { - auto arg_scalar = dyn_cast(arg); - auto arg_value = arg_scalar->GetValueTrack(); + auto arg_scalar = dyn_cast_ptr(arg); + const auto &arg_value = arg_scalar->GetValueTrack(); value_list.push_back(arg_value); } else { // Raise TypeError Expected Scalar. @@ -1335,18 +1335,18 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_ AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); // Create new cnode std::vector input = {NewValueNode(prim::kPrimPartial)}; - auto func_graph_func = dyn_cast(abstract); + auto func_graph_func = dyn_cast_ptr(abstract); if (func_graph_func != nullptr) { FuncGraphPtr fg = func_graph_func->func_graph(); input.push_back(NewValueNode(fg)); } else { - auto prim_func = dyn_cast(abstract); + auto prim_func = dyn_cast_ptr(abstract); MS_EXCEPTION_IF_NULL(prim_func); PrimitivePtr prim = prim_func->prim(); input.push_back(NewValueNode(prim)); } - AnfNodeConfigPtr conf = dyn_cast(data_conf); + auto conf = dyn_cast_ptr(data_conf); MS_EXCEPTION_IF_NULL(conf); input.push_back(conf->node()); MS_EXCEPTION_IF_NULL(old_conf); @@ -1367,15 +1367,15 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con MS_EXCEPTION_IF_NULL(attr_value); ValuePtr item_value = attr_value; if (item_value->isa()) { - item_value = std::make_shared(item_value->cast()->value()); + item_value = std::make_shared(item_value->cast_ptr()->value()); } if (!item_value->isa()) { MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString(); } // item_name to func addr from obj_map - parse::SymbolPtr symbol = item_value->cast(); - parse::NameSpacePtr name_space = data_value->cast(); + auto symbol = item_value->cast(); + auto name_space = data_value->cast(); MS_EXCEPTION_IF_NULL(out_conf); auto out_node = out_conf->node(); FuncGraphPtr func_graph = out_node->func_graph(); @@ -1429,12 +1429,12 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &, if (!item_value->isa()) { MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString(); } - std::string item_name = item_value->cast()->value(); + const auto &item_name = item_value->cast_ptr()->value(); // Get ms_class object. if (!data_value->isa()) { MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString(); } - auto ms_class = data_value->cast(); + auto ms_class = data_value->cast_ptr(); MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << "."; // Get the attr/method of ms_class object. @@ -1461,7 +1461,7 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AnalysisEnginePtr &engi if (python_obj == nullptr) { return nullptr; } - auto wrapper_obj = dyn_cast(python_obj); + auto wrapper_obj = dyn_cast_ptr(python_obj); MS_EXCEPTION_IF_NULL(wrapper_obj); py::object real_python_obj = wrapper_obj->obj(); MS_LOG(DEBUG) << "item_value: " << item_value->ToString() << ", func_value: " << func_value->ToString() @@ -1485,7 +1485,7 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt if (!item_value->isa()) { MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString(); } - std::string item_name = item_value->cast()->value(); + std::string item_name = item_value->cast_ptr()->value(); REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD; Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); if (require.empty()) { @@ -1520,7 +1520,7 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) { if (!abs->isa()) { return nullptr; } - auto partial_abs = abs->cast(); + auto partial_abs = abs->cast_ptr(); auto fn = partial_abs->fn(); if (!fn->isa()) { return nullptr; @@ -1568,7 +1568,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt if (class_value != nullptr) { return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf); } - auto data_func_graph = dyn_cast(data_args); + auto data_func_graph = dyn_cast_ptr(data_args); if (data_func_graph != nullptr) { auto res = GetEvaluatedValueForCellAttrOrMethod(engine, item_value, data_func_graph->func_graph(), data_conf, out_conf); @@ -1642,7 +1642,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator { if (args_conf_list.size() != 1) { MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); } - AnfNodeConfigPtr node_conf = dyn_cast(args_conf_list[0]); + auto node_conf = dyn_cast_ptr(args_conf_list[0]); MS_EXCEPTION_IF_NULL(node_conf); MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult()); AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract(); @@ -1662,7 +1662,7 @@ static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root_g = root_g_set.back(); for (auto ¶m_node : root_g->parameters()) { auto param = param_node->cast(); - if (param && name == param->name()) { + if (param != nullptr && param->name() == name) { return param; } } @@ -1680,7 +1680,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { return nullptr; } static TypePtr type = std::make_shared(); - auto node_conf = dyn_cast(args_conf_list[0]); + auto node_conf = dyn_cast_ptr(args_conf_list[0]); if (node_conf == nullptr) { MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; return nullptr; @@ -1688,7 +1688,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult()); AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract(); MS_EXCEPTION_IF_NULL(abs); - AbstractRefPtr ref_abs = abs->cast(); + auto ref_abs = abs->cast_ptr(); if (ref_abs == nullptr) { MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); return nullptr; @@ -1701,11 +1701,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { // Only after that funcgrpah is inlined, the RefToEmbed CNode should be evaluated to specific SymbolicKey. bool ifEmbedIsWeight = false; if (node_conf->node() != nullptr && node_conf->node()->isa()) { - auto param = node_conf->node()->cast(); + auto param = node_conf->node()->cast_ptr(); MS_EXCEPTION_IF_NULL(param); ifEmbedIsWeight = param->has_default(); } - auto refkey = ref_abs->ref_key_value()->cast(); + auto refkey = ref_abs->ref_key_value()->cast_ptr(); if (refkey == nullptr || !ifEmbedIsWeight) { auto ret = std::make_shared(type); auto ref_value = ref_abs->ref(); @@ -1788,12 +1788,12 @@ class ResolveEvaluator : public TransitionPrimEvaluator { bool IsContainUndetermined(const AbstractBasePtr &arg) { if (arg->isa()) { - auto seq_arg = arg->cast(); + auto seq_arg = arg->cast_ptr(); return std::any_of(seq_arg->elements().begin(), seq_arg->elements().end(), IsContainUndetermined); } if (arg->isa()) { - auto kw_arg = arg->cast(); + auto kw_arg = arg->cast_ptr(); return IsContainUndetermined(kw_arg->get_arg()); } @@ -1824,7 +1824,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { ValuePtr value_track = arg_class_type->GetValueTrack(); MS_EXCEPTION_IF_NULL(value_track); - parse::PyObjectWrapperPtr type_obj = dyn_cast(value_track); + auto type_obj = dyn_cast_ptr(value_track); if (type_obj == nullptr) { MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; } @@ -1911,7 +1911,7 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator { } ValuePtr value_track = arg_cls->GetValueTrack(); MS_EXCEPTION_IF_NULL(value_track); - parse::MsClassObjectPtr ms_class = dyn_cast(value_track); + auto ms_class = dyn_cast_ptr(value_track); if (ms_class == nullptr) { MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject."; } @@ -1933,7 +1933,7 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator { // Replace net with net.__call__ AnfNodePtr old_node = out_conf->node(); MS_EXCEPTION_IF_NULL(old_node); - CNodePtr old_cnode = dyn_cast(old_node); + auto old_cnode = dyn_cast_ptr(old_node); MS_EXCEPTION_IF_NULL(old_cnode); std::vector inputs = {NewValueNode(call_func_graph)}; for (size_t i = 1; i < old_cnode->size(); i++) { @@ -1966,7 +1966,7 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator { ValuePtr value_track = args_spec_list[0]->GetValueTrack(); MS_EXCEPTION_IF_NULL(value_track); - std::shared_ptr script_obj = dyn_cast(value_track); + auto script_obj = dyn_cast_ptr(value_track); if (script_obj == nullptr) { MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; } @@ -2027,12 +2027,12 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator { const auto &element_name = element.first; const auto &element_abs = element.second; if (element_abs->isa()) { - const auto &element_abs_fn = element_abs->cast(); - const auto &fg = element_abs_fn->func_graph(); + auto element_abs_fn = element_abs->cast_ptr(); + auto fg = element_abs_fn->func_graph(); MS_EXCEPTION_IF_NULL(fg); auto wrapper_obj = fg->python_obj(); if (wrapper_obj != nullptr && wrapper_obj->isa()) { - auto fn_py_obj = wrapper_obj->cast()->obj(); + auto fn_py_obj = wrapper_obj->cast_ptr()->obj(); (*global_params_dict)[py::str(element_name)] = fn_py_obj; MS_LOG(DEBUG) << "Found global python function object for " << element_name << ", add it to global dict."; } @@ -2138,10 +2138,10 @@ class PartialEvaluator : public Evaluator { auto func = CheckArg("partial", args_spec_list, 0); // Sometimes, node[0] in out_conf becomes phi0; if (func->isa()) { - auto prim_func = dyn_cast(func); + auto prim_func = dyn_cast_ptr(func); MS_EXCEPTION_IF_NULL(prim_func->prim()); if (prim_func->prim()->isa()) { - prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast(prim_func->prim()); + auto do_signature_prim = dyn_cast_ptr(prim_func->prim()); return HandleDoSignature(engine, do_signature_prim->function(), out_conf); } } @@ -2179,7 +2179,7 @@ class PartialEvaluator : public Evaluator { MS_EXCEPTION_IF_NULL(engine); MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf->node()); - auto cnode = out_conf->node()->cast(); + auto cnode = out_conf->node()->cast_ptr(); if (cnode == nullptr) { MS_LOG(EXCEPTION) << "Cnode is nullptr"; } @@ -2234,7 +2234,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator { std::string exception_string; // Processed in units of nodes. Raise ValueError(xxxx) size_t index_begin = 2; - auto cnode = node->cast(); + auto cnode = node->cast_ptr(); MS_EXCEPTION_IF_NULL(cnode); auto inputs = cnode->inputs(); bool need_out_symbol = inputs.size() > 3; @@ -2268,17 +2268,17 @@ class RaiseEvaluator : public TransitionPrimEvaluator { bool CheckNeedSymbol(const AnfNodePtr &, const AbstractBasePtr &abs) const { bool need_symbol = false; if (abs->isa()) { - auto scalar = abs->cast(); + auto scalar = abs->cast_ptr(); auto scalar_value = scalar->BuildValue(); if (scalar_value->isa()) { need_symbol = true; } } else if (abs->isa()) { - auto abs_list = abs->cast(); + auto abs_list = abs->cast_ptr(); const auto &elements = abs_list->elements(); for (auto &element : elements) { if (element->isa()) { - auto scalar = element->cast(); + auto scalar = element->cast_ptr(); auto scalar_value = scalar->BuildValue(); if (scalar_value->isa()) { need_symbol = true; @@ -2310,7 +2310,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator { std::string GetTupleString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) { std::string exception_str; // Process raise ValueError("str") - auto arg_tuple = arg->cast(); + auto arg_tuple = arg->cast_ptr(); const auto &arg_tuple_elements = arg_tuple->elements(); if (arg_tuple_elements.size() == 0) { MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty."; @@ -2334,7 +2334,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator { std::string GetListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) { std::string exception_str; // Process raise ValueError("str") - auto arg_list = arg->cast(); + auto arg_list = arg->cast_ptr(); const auto &arg_list_elements = arg_list->elements(); if (arg_list_elements.size() == 0) { MS_LOG(EXCEPTION) << "The arg_list_elements can't be empty."; @@ -2358,7 +2358,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator { std::string GetExceptionType(const AbstractBasePtr &abs) const { std::string str; if (abs->isa()) { - auto scalar = abs->cast(); + auto scalar = abs->cast_ptr(); auto scalar_value = scalar->BuildValue(); if (scalar_value->isa()) { str = GetValue(scalar_value); @@ -2491,8 +2491,8 @@ namespace { bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); - auto x_tuple = dyn_cast(x); - auto model_tuple = dyn_cast(model); + auto x_tuple = dyn_cast_ptr(x); + auto model_tuple = dyn_cast_ptr(model); if (x_tuple == nullptr || model_tuple == nullptr) { return false; @@ -2518,8 +2518,8 @@ bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) { bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); - auto x_tensor = dyn_cast(x); - auto model_tensor = dyn_cast(model); + auto x_tensor = dyn_cast_ptr(x); + auto model_tensor = dyn_cast_ptr(model); if (x_tensor == nullptr || model_tensor == nullptr) { return false; @@ -2535,8 +2535,8 @@ bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) { bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); - auto x_list = dyn_cast(x); - auto model_list = dyn_cast(model); + auto x_list = dyn_cast_ptr(x); + auto model_list = dyn_cast_ptr(model); if (x_list == nullptr || model_list == nullptr) { return false; @@ -2563,10 +2563,10 @@ bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) { inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); - if (dyn_cast(x) == nullptr) { + if (dyn_cast_ptr(x) == nullptr) { return false; } - TypePtr x_type = x->GetTypeTrack(); + auto &x_type = x->GetTypeTrack(); return IsSubType(x_type, model); } } // namespace diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index b9ae59eba4e..34c75cce4f2 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -63,7 +63,7 @@ bool CanSpecializeValueNode(const AnfNodePtr &node) { } if (IsValueNode(node)) { if (node->abstract() != nullptr) { - auto abs_func = node->abstract()->cast(); + auto abs_func = node->abstract()->cast_ptr(); // If this funcgraph had specialized in ProcessCNode of FirstPass, // then ignore it. if (abs_func != nullptr && abs_func->specialized()) { @@ -110,12 +110,12 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) { } // Handle MakeTuple/MakeList CNode. - auto cnode = dyn_cast(node); + auto cnode = dyn_cast_ptr(node); if (cnode != nullptr) { if (pos + 1 >= cnode->inputs().size()) { continue; } - auto input_value = GetValueNode(cnode->input(pos + 1)); + auto input_value = GetValuePtr(cnode->input(pos + 1)); if (input_value == nullptr || input_value->value() != kDeadNodeName) { continue; } @@ -129,7 +129,7 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) { // Change the abstract. (*flags)[pos] = false; // Change the use flag as 0. - auto sequence_abs = dyn_cast(node->abstract()); + auto sequence_abs = dyn_cast_ptr(node->abstract()); if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) { MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString() << ", node: " << node->DebugString(recursive_level); @@ -138,13 +138,13 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) { } // Handle ValueTuple/ValueList. if (IsValueNode(node) || IsValueNode(node)) { - auto sequence_value = GetValueNode(node); + auto sequence_value = GetValuePtr(node); MS_EXCEPTION_IF_NULL(sequence_value); if (pos >= sequence_value->value().size()) { continue; } ValuePtr element_value = sequence_value->value()[pos]; - auto element_str_value = element_value->cast(); + auto element_str_value = element_value->cast_ptr(); if (element_str_value == nullptr || element_str_value->value() != kDeadNodeName) { continue; } @@ -157,7 +157,7 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) { // Change the abstract. (*flags)[pos] = false; // Change the use flag as 0. - auto sequence_abs = dyn_cast(node->abstract()); + auto sequence_abs = dyn_cast_ptr(node->abstract()); if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) { constexpr int recursive_level = 2; MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString() @@ -250,7 +250,7 @@ AbstractBasePtr ProgramSpecializer::SpecializeAbstractFuncRecursively(const Abst auto build_new_abs = [this, &func_atoms](const AbstractFuncAtomPtr &poss) { auto resolved_atom = poss; if (poss->isa()) { - const auto &async_abs_func = poss->cast(); + auto async_abs_func = poss->cast_ptr(); const auto &resolved_func = async_abs_func->GetUnique(); resolved_atom = resolved_func->cast(); MS_EXCEPTION_IF_NULL(resolved_atom); @@ -306,7 +306,7 @@ void ProgramSpecializer::SpecializeCNodeInput0FuncGraph() { if (!node->isa()) { continue; } - auto &input0 = node->cast()->input(0); + auto &input0 = node->cast_ptr()->input(0); MS_EXCEPTION_IF_NULL(input0); if (IsValueNode(input0)) { continue; @@ -383,7 +383,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) { MS_EXCEPTION_IF_NULL(node); - auto c_node = node->cast(); + auto c_node = node->cast_ptr(); MS_EXCEPTION_IF_NULL(c_node); auto inputs = c_node->inputs(); std::vector new_inputs; @@ -403,7 +403,7 @@ void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const An return new_inp; }); - auto c_new_node = new_node->cast(); + auto c_new_node = new_node->cast_ptr(); MS_EXCEPTION_IF_NULL(c_new_node); c_new_node->set_inputs(new_inputs); } @@ -534,7 +534,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node, if (new_node == old_node) { return; } - AbstractSequencePtr old_sequence_abs = dyn_cast(old_abs); + auto old_sequence_abs = dyn_cast_ptr(old_abs); if (old_sequence_abs == nullptr || old_sequence_abs->sequence_nodes() == nullptr || old_sequence_abs->sequence_nodes()->empty()) { MS_LOG(DEBUG) << "No sequence node in old abs, " << old_node->DebugString() << " --> " << new_node->DebugString(); @@ -590,7 +590,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node, } MS_LOG(ERROR) << "New abstract, " << old_node->DebugString() << " --> " << new_node->DebugString() << ", elements_use_flags: " << (*flags); - AbstractSequencePtr new_sequence_abs = dyn_cast(new_abs); + auto new_sequence_abs = dyn_cast_ptr(new_abs); if (new_sequence_abs == nullptr) { MS_LOG(EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString(); } @@ -608,7 +608,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node, template void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer *const specializer) { const auto &old_input = cnode->input(index); - auto sequence_value = GetValueNode>(old_input); + auto sequence_value = GetValuePtr(old_input); if (sequence_value == nullptr) { return; } @@ -625,7 +625,7 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial } for (size_t i = 0; i < sequence_value_size; ++i) { ValuePtr old_sequence_value = sequence_value->value()[i]; - auto old_sequence_str_value = old_sequence_value->cast(); + auto old_sequence_str_value = old_sequence_value->cast_ptr(); if (!(*flags)[i]) { auto zero = MakeValue(0); (void)elements.emplace_back(zero); @@ -643,7 +643,7 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial auto new_sequence_value = std::make_shared(elements); auto new_input = NewValueNode(new_sequence_value); auto new_input_abs = new_sequence_value->ToAbstract(); - AbstractSequencePtr new_sequence_abs = dyn_cast(new_input_abs); + auto new_sequence_abs = dyn_cast(new_input_abs); MS_EXCEPTION_IF_NULL(new_sequence_abs); std::shared_ptr sequence_nodes = std::make_shared(); (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_input)); @@ -704,7 +704,7 @@ void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) co (void)inputs.emplace_back(cnode->input(0)); for (size_t i = 0; i < (*flags).size(); ++i) { auto old_input = cnode->input(i + 1); - auto old_input_value = GetValueNode(old_input); + auto old_input_value = GetValuePtr(old_input); if (!(*flags)[i]) { auto zero_value = NewValueNode(MakeValue(0)); zero_value->set_abstract(std::make_shared(std::make_shared(0))); @@ -754,7 +754,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { MS_LOG(EXCEPTION) << "Fail to get abstract value with " << conf->ToString() << ", for " << new_node->DebugString(); } if (new_node->isa() && new_node->abstract()->isa()) { - auto partial_abstract = dyn_cast(new_node->abstract()); + auto partial_abstract = dyn_cast_ptr(new_node->abstract()); if (partial_abstract->node() == node) { partial_abstract->set_node(new_node); } @@ -768,8 +768,8 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { } static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0"); auto attrs = conf->ObtainEvalResult()->attribute(); - auto c_old = node->cast(); - auto c_new = new_node->cast(); + auto c_old = node->cast_ptr(); + auto c_new = new_node->cast_ptr(); MS_EXCEPTION_IF_NULL(c_new); auto new_inputs = c_new->inputs(); auto old_inputs = c_old->inputs(); @@ -786,7 +786,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { bool ignore_build_value = false; AnfNodePtr replace_node = nullptr; if (specializer_->engine()->check_isolated_side_effect()) { - auto cnode_input = dyn_cast(node_input); + auto cnode_input = dyn_cast_ptr(node_input); ignore_build_value = (cnode_input != nullptr && cnode_input->has_isolated_side_effect_node()); if (ignore_build_value) { MS_LOG(INFO) << "Don't build value node for CNode which contains isolated side-effect inputs, node: " @@ -878,7 +878,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con const AbstractBasePtr &abs, const AbstractBasePtrList &argvals) { MS_EXCEPTION_IF_NULL(abs); MS_EXCEPTION_IF_NULL(func); - AbstractFunctionPtr real_a = dyn_cast(abs); + auto real_a = dyn_cast_ptr(abs); MS_EXCEPTION_IF_NULL(real_a); AbstractFunctionPtr func_abs = real_a->GetUnique(); @@ -905,7 +905,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con // Set the flag, so this MetaFuncGraph will be Re-AutoMonaded. MS_EXCEPTION_IF_NULL(func_abs); if (func_abs->isa()) { - auto specialized_fg = GetValueNode(repl); + auto specialized_fg = GetValuePtr(repl); if (specialized_fg != nullptr && (argvals.size() > 1) && argvals.back() != nullptr && argvals.back()->isa()) { specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); @@ -924,7 +924,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode MS_EXCEPTION_IF_NULL(errcode); *errcode = kSpecializeSuccess; - auto real_func = dyn_cast(func_abs); + auto real_func = dyn_cast_ptr(func_abs); if (real_func != nullptr) { return BuildValueNode(real_func->prim(), abs); } @@ -942,7 +942,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode argvals = result.first; AbstractBasePtr unique_output = result.second; - auto prim_func = dyn_cast(func_abs); + auto prim_func = dyn_cast_ptr(func_abs); if (prim_func != nullptr) { auto type_func = std::make_shared(prim_func->prim(), argvals, unique_output); return BuildValueNode(prim_func->prim(), type_func); @@ -993,7 +993,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &c AbstractBasePtrList args; auto backed_fnval = fnval; if (fnval->isa()) { - auto partial_closure = dyn_cast(fnval); + auto partial_closure = dyn_cast_ptr(fnval); backed_fnval = partial_closure->fn(); args = partial_closure->args(); } @@ -1133,7 +1133,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) { // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...) const size_t arg_start_index = 2; while (IsPrimitiveCNode(func, prim::kPrimPartial)) { - std::vector inputs = func->cast()->inputs(); + std::vector inputs = func->cast_ptr()->inputs(); // First element is partial, second is func so arg is start from 2 (void)args.insert(args.cbegin(), inputs.cbegin() + SizeToInt(arg_start_index), inputs.cend()); func = inputs[1]; @@ -1222,10 +1222,10 @@ bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argv return true; } if (func->isa() && argvals.empty()) { - auto meta_func_graph_wrapper = dyn_cast(func); + auto meta_func_graph_wrapper = dyn_cast_ptr(func); auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph(); if (meta_func_graph != nullptr && meta_func_graph->isa()) { - auto do_signature = dyn_cast(meta_func_graph); + auto do_signature = dyn_cast_ptr(meta_func_graph); if (do_signature != nullptr && do_signature->function()->isa()) { MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY."; return true; @@ -1331,7 +1331,7 @@ AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNode const AbstractFunctionPtr &abs) { ValuePtr value = nullptr; if (abs->isa()) { - auto real_fn = dyn_cast(abs); + auto real_fn = dyn_cast_ptr(abs); // For primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one if (attrs != nullptr) { value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); @@ -1339,20 +1339,20 @@ AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNode value = real_fn->prim(); } } else if (abs->isa()) { - auto real_fn = dyn_cast(abs); + auto real_fn = dyn_cast_ptr(abs); value = real_fn->meta_func_graph(); } else if (abs->isa()) { - auto real_fn = dyn_cast(abs); + auto real_fn = dyn_cast_ptr(abs); value = real_fn->func_graph(); } else { return nullptr; } MS_EXCEPTION_IF_NULL(value); - if (!value->isa() || value->cast()->parent() == nullptr || - (IsValueNode(origin_node) && IsVisible(func_graph_, value->cast()->parent()))) { + if (!value->isa() || value->cast_ptr()->parent() == nullptr || + (IsValueNode(origin_node) && IsVisible(func_graph_, value->cast_ptr()->parent()))) { return BuildValueNode(value, ival); } else if (IsPrimitiveCNode(cnode, prim::kPrimJ) && origin_node->isa() && - !value->cast()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) { + !value->cast_ptr()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) { // Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph. MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString(); return BuildValueNode(value, ival); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc index 6c9275d36aa..b1b1af8c92f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc @@ -43,7 +43,7 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr MS_EXCEPTION_IF_NULL(graph_func); MS_EXCEPTION_IF_NULL(fg_evaluator); AnalysisContextPtr parent_context = nullptr; - auto func_graph_abs = dyn_cast(graph_func); + auto func_graph_abs = dyn_cast_ptr(graph_func); if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure. parent_context = func_graph_abs->context(); } else if (graph_func->isa()) { // Or DummyContext for MetaFuncGraphAbstractClosure. @@ -164,7 +164,7 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) { << ", current_context_: " << current_context_->ToString(); AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph()); EvalResultPtr node_eval_result = nullptr; - const auto &fg_evaluator = dyn_cast(evaluator()); + auto fg_evaluator = dyn_cast_ptr(evaluator()); if (fg_evaluator == nullptr) { MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator()->ToString(); } @@ -194,7 +194,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last // Check if child func graph contains isolated side-effect. if (engine->check_isolated_side_effect()) { if (last_stack_frame->func_graph()->has_isolated_side_effect_node()) { - auto cnode = dyn_cast(CurrentNode()); + auto cnode = dyn_cast_ptr(CurrentNode()); MS_EXCEPTION_IF_NULL(cnode); cnode->set_has_isolated_side_effect_node(true); cnode->func_graph()->set_has_isolated_side_effect_node(true); @@ -205,7 +205,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last auto evaluator = last_stack_frame->evaluator(); MS_EXCEPTION_IF_NULL(evaluator); evaluator->evaluator_cache_mgr()->SetValue(last_stack_frame->args_abs_list(), result); - const auto &fg_evaluator = dyn_cast(evaluator); + auto fg_evaluator = dyn_cast_ptr(evaluator); if (fg_evaluator == nullptr) { MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator->ToString(); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 49b66cb7449..fc380544161 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -315,7 +315,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf return std::make_shared(possible_func->Clone(), std::make_shared()); } - AbstractFunctionPtr func = dyn_cast(possible_func); + auto func = dyn_cast_ptr(possible_func); if (func == nullptr) { CheckInterpretedObject(possible_func); MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << "."; @@ -357,11 +357,11 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list); // Check if func graph contains isolated side-effect, and sync. if (check_isolated_side_effect()) { - FuncGraphAbstractClosurePtr func_graph_abs = dyn_cast(func); + auto func_graph_abs = mindspore::cast(func); if (func_graph_abs != nullptr) { contains_isolated_side_effect |= func_graph_abs->func_graph()->has_isolated_side_effect_node(); } - MetaFuncGraphAbstractClosurePtr meta_func_graph_abs = dyn_cast(func); + auto meta_func_graph_abs = mindspore::cast(func); if (meta_func_graph_abs != nullptr) { contains_isolated_side_effect |= meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node(); } @@ -652,7 +652,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c // again, so update the config_map with new_conf; anfnode_config_map_[orig_conf] = new_conf; MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->ToString() << ", to new_conf: " << new_conf->ToString(); - auto old_cnode = orig_conf->node()->cast(); + auto old_cnode = orig_conf->node()->cast_ptr(); auto new_cnode = new_conf->node()->cast(); if (old_cnode != nullptr && new_cnode != nullptr) { if (old_cnode->func_graph() == new_cnode->func_graph()) { @@ -757,20 +757,20 @@ std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBa << spec->ToString() << ", and that of the previous branch is " << last_spec->ToString() << ".\n" << "The node is " << node->DebugString(recursive_level); if (node->isa()) { - auto cnode = node->cast()->input(0); + auto cnode = node->cast_ptr()->input(0); if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { // {prim::kPrimSwitch, cond, true_branch, false_branch} constexpr int true_index = 2; constexpr int false_index = 3; - auto inputs = cnode->cast()->inputs(); + const auto &inputs = cnode->cast_ptr()->inputs(); buffer << ", true branch: " << inputs.at(true_index)->ToString() << ", false branch: " << inputs.at(false_index)->ToString(); } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) { // {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}} constexpr int branch_index = 2; - auto tuple_node = cnode->cast()->input(branch_index); + const auto &tuple_node = cnode->cast_ptr()->input(branch_index); if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) { - auto tuple_inputs = tuple_node->cast()->inputs(); + const auto &tuple_inputs = tuple_node->cast_ptr()->inputs(); for (size_t i = 1; i < tuple_inputs.size(); i++) { buffer << ", branch" << i << ": " << tuple_inputs.at(i); } @@ -827,7 +827,7 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) { return true; } if (abstract->isa()) { - auto elements = abstract->cast()->elements(); + auto elements = abstract->cast_ptr()->elements(); if (std::any_of(elements.begin(), elements.end(), [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) { return true; @@ -888,7 +888,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs, const std::vector &pending_async_abstract_list, const std::vector &index) { - const auto sequence_abs = dyn_cast(orig_abs); + auto sequence_abs = dyn_cast_ptr(orig_abs); if (sequence_abs != nullptr) { const auto &orig_elements = sequence_abs->elements(); AbstractBasePtrList new_elements; @@ -1153,7 +1153,7 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0"); if (enable_eliminate_unused_element && value->isa()) { auto abs = value->ToAbstract(); - auto sequence_abs = dyn_cast(abs); + auto sequence_abs = dyn_cast_ptr(abs); MS_EXCEPTION_IF_NULL(sequence_abs); if (anf_node != nullptr) { SetSequenceNodeElementsUseFlags(anf_node, std::make_shared>(sequence_abs->elements().size())); @@ -1179,12 +1179,12 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi if (evaluator == nullptr) { MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ")."; } - auto trivial_evaluator = dyn_cast(evaluator); + auto trivial_evaluator = dyn_cast_ptr(evaluator); if (trivial_evaluator != nullptr) { return trivial_evaluator->EvalPrim(nullptr, arg_specs); } // Support MakeTuple/MakeList ops in PyNative mode. - auto transition_evaluator = dyn_cast(evaluator); + auto transition_evaluator = dyn_cast_ptr(evaluator); if (transition_evaluator != nullptr && (transition_evaluator->isa() || transition_evaluator->isa())) { return transition_evaluator->EvalPrim(nullptr, arg_specs, nullptr, nullptr); diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index f9b6837b8f1..38f82cd7482 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -146,15 +146,15 @@ void ValidateValueNode(const AnfNodePtr &node) { } void CheckValueTuple(const AnfNodePtr &node) { - const auto &value_node = node->cast(); + auto value_node = node->cast_ptr(); MS_EXCEPTION_IF_NULL(value_node); - const auto value = value_node->value(); + const auto &value = value_node->value(); MS_EXCEPTION_IF_NULL(value); - const auto value_tuple = value->cast(); + auto value_tuple = value->cast_ptr(); MS_EXCEPTION_IF_NULL(value_tuple); - const auto tuple_values = value_tuple->value(); + const auto &tuple_values = value_tuple->value(); for (size_t i = 0; i < tuple_values.size(); ++i) { - const auto input_node = NewValueNode(tuple_values[i]); + auto input_node = NewValueNode(tuple_values[i]); ValidateOperation(input_node); ValidateValueNode(input_node); } @@ -167,7 +167,7 @@ void Validate(const FuncGraphPtr &func_graph) { for (auto node : all_nodes) { TraceGuard guard(std::make_shared(node->debug_info())); while (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) { - node = node->cast()->input(1); + node = node->cast_ptr()->input(1); } if (IsValueNode(node)) { CheckValueTuple(node); diff --git a/mindspore/core/abstract/abstract_function.cc b/mindspore/core/abstract/abstract_function.cc index d4ff55d01ad..84ccbc89d78 100644 --- a/mindspore/core/abstract/abstract_function.cc +++ b/mindspore/core/abstract/abstract_function.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) { } return std::make_shared(this_func, other); } - auto other_union = dyn_cast(other); + auto other_union = dyn_cast_ptr(other); MS_EXCEPTION_IF_NULL(other_union); if (other_union->IsSuperSet(this_func)) { return other; @@ -122,7 +122,7 @@ AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { } return std::make_shared(this_func, other); } - auto other_union = dyn_cast(other); + auto other_union = dyn_cast_ptr(other); MS_EXCEPTION_IF_NULL(other_union); if (other_union->IsSuperSet(this_func)) { return other; diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index aa71ed1aaa7..5c2b8380e58 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -173,14 +173,14 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { if (*this == *other) { return shared_from_base(); } - auto type_self = GetTypeTrack(); - auto type_other = other->GetTypeTrack(); + const auto &type_self = GetTypeTrack(); + const auto &type_other = other->GetTypeTrack(); TypePtr res_type = TypeJoin(type_self, type_other); if (res_type == kAnyType) { TypeJoinLogging(type_self, type_other, shared_from_base(), other); } - auto value_self = GetValueTrack(); - auto value_other = other->GetValueTrack(); + const auto &value_self = GetValueTrack(); + const auto &value_other = other->GetValueTrack(); ValuePtr res_value = ValueJoin(value_self, value_other); if (res_value == value_self) { return shared_from_base(); @@ -193,7 +193,7 @@ AbstractBasePtr AbstractType::Clone() const { if (value_self == nullptr || !value_self->isa()) { return nullptr; } - TypePtr type_self = value_self->cast(); + auto type_self = value_self->cast_ptr(); return std::make_shared(type_self->Clone()); } @@ -201,7 +201,8 @@ bool AbstractType::operator==(const AbstractBase &other) const { if (this == &other) { return true; } - return tid() == other.tid() && IsEqual(dyn_cast(GetValueTrack()), dyn_cast(other.GetValueTrack())); + return tid() == other.tid() && + IsEqual(dyn_cast_ptr(GetValueTrack()), dyn_cast_ptr(other.GetValueTrack())); } std::string AbstractType::ToString() const { @@ -215,7 +216,7 @@ std::string AbstractType::ToString() const { buffer << type_name() << "(Value: nullptr)"; return buffer.str(); } - TypePtr type_self = value_self->cast(); + auto type_self = value_self->cast_ptr(); buffer << type_name() << "(" << "Value: " << type_self->ToString() << ")"; return buffer.str(); @@ -496,7 +497,7 @@ AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &ot if (!enable_eliminate_unused_element || this->sequence_nodes() == nullptr) { return sequence_nodes; } - auto other_sequence = dyn_cast(other); + auto other_sequence = dyn_cast_ptr(other); if (other_sequence == nullptr) { return sequence_nodes; } @@ -714,7 +715,7 @@ template MS_CORE_API ValuePtr AbstractSequence::ElementsBuildValue() template AbstractBasePtr AbstractSequence::ElementsJoin(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); - auto other_sequeue = dyn_cast(other); + auto other_sequeue = dyn_cast_ptr(other); if (other_sequeue == nullptr) { AbstractTypeJoinLogging(shared_from_base(), other); } @@ -761,7 +762,7 @@ bool AbstractSequence::operator==(const AbstractSequence &other) const { } void AbstractTuple::set_shape(const BaseShapePtr &shape) { - auto tuple_shape = dyn_cast(shape); + auto tuple_shape = dyn_cast_ptr(shape); MS_EXCEPTION_IF_NULL(tuple_shape); if (tuple_shape->shape().size() != elements_.size()) { MS_LOG(EXCEPTION) << "Size mismatch: " << tuple_shape->shape().size() << " vs " << elements_.size(); @@ -789,7 +790,7 @@ bool AbstractTuple::ContainsAllBroadenTensors() const { for (size_t i = 0; i < elements_.size(); ++i) { if (!(elements_[i]->isa() && elements_[i]->IsBroaden()) && !(elements_[i]->isa() && - elements_[i]->cast()->ContainsAllBroadenTensors())) { + elements_[i]->cast_ptr()->ContainsAllBroadenTensors())) { return false; } } @@ -927,7 +928,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { // AbstractTensor join with AbstractUndetermined if (other_type->type_id() == kObjectTypeUndeterminedType) { - auto other_undetermined_tensor = dyn_cast(other); + auto other_undetermined_tensor = dyn_cast_ptr(other); MS_EXCEPTION_IF_NULL(other_undetermined_tensor); // Check shape auto res_shape = ShapeJoin(shape(), other_undetermined_tensor->shape()); @@ -941,7 +942,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { } // AbstractTensor join with AbstractTensor - auto other_tensor = dyn_cast(other); + auto other_tensor = dyn_cast_ptr(other); if (other_tensor == nullptr) { AbstractTypeJoinLogging(shared_from_base(), other); } @@ -964,8 +965,8 @@ bool AbstractTensor::equal_to(const AbstractTensor &other) const { return true; } // Check value. for AbstractTensor, both value should be AnyValue. - auto v1 = GetValueTrack(); - auto v2 = other.GetValueTrack(); + const auto &v1 = GetValueTrack(); + const auto &v2 = other.GetValueTrack(); if (v1 != v2 && (v1 == nullptr || !v1->isa() || v2 == nullptr || !v2->isa())) { return false; } @@ -1142,7 +1143,7 @@ TypePtr AbstractJTagged::BuildType() const { AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); - auto other_jtagged = dyn_cast(other); + auto other_jtagged = dyn_cast_ptr(other); if (other_jtagged == nullptr) { AbstractTypeJoinLogging(shared_from_base(), other); } @@ -1182,8 +1183,8 @@ AbstractRefTensor::AbstractRefTensor(const AbstractTensorPtr &ref_value, const V TypePtr AbstractRefTensor::BuildType() const { auto type = AbstractTensor::BuildType(); - MS_EXCEPTION_IF_NULL(type); - auto subtype = type->cast(); + auto subtype = dyn_cast_ptr(type); + MS_EXCEPTION_IF_NULL(subtype); return std::make_shared(subtype); } @@ -1430,7 +1431,7 @@ BaseShapePtrList AbstractSparseTensor::ElementsShapeTupleRecursive() const { BaseShapePtrList element_shape_list; for (const auto &element : elements()) { MS_EXCEPTION_IF_NULL(element); - auto abs_tuple = element->cast(); + auto abs_tuple = element->cast_ptr(); if (abs_tuple == nullptr) { element_shape_list.push_back(element->BuildShape()); } else { diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index 3e4d388dbb5..1c4e3579dd7 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -120,17 +120,17 @@ class MS_CORE_API AbstractBase : public Base { /// \brief Get the abstract value, which is tracked. /// /// \return A pointer to the Value. - ValuePtr GetValueTrack() const { return value_; } + const ValuePtr &GetValueTrack() const { return value_; } /// \brief Get the abstract type, which is tracked. /// /// \return A pointer to the Type. - TypePtr GetTypeTrack() const { return type_; } + const TypePtr &GetTypeTrack() const { return type_; } /// \brief Get the abstract shape, which is tracked. /// /// \return A pointer to the BaseShape. - BaseShapePtr GetShapeTrack() const { return shape_; } + const BaseShapePtr &GetShapeTrack() const { return shape_; } /// \brief Try to build a real value from an abstract value. /// diff --git a/mindspore/core/abstract/analysis_context.cc b/mindspore/core/abstract/analysis_context.cc index 8f63ea40a65..5333bcb77d0 100644 --- a/mindspore/core/abstract/analysis_context.cc +++ b/mindspore/core/abstract/analysis_context.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -141,8 +141,8 @@ bool AnalysisContext::operator==(const AnalysisContext &other) const { if (func_graph_->has_flag(GRAPH_FLAG_IS_WHILE_HEADER) && args_spec_list_[i]->isa() && other.args_spec_list_[i]->isa()) { - auto temp_this = args_spec_list_[i]->cast()->Copy(); - auto temp_other = other.args_spec_list_[i]->cast()->Copy(); + auto temp_this = args_spec_list_[i]->cast_ptr()->Copy(); + auto temp_other = other.args_spec_list_[i]->cast_ptr()->Copy(); temp_this->set_tracking_id(nullptr); temp_other->set_tracking_id(nullptr); if (!(*temp_this == *temp_other)) { diff --git a/mindspore/core/abstract/param_validator.cc b/mindspore/core/abstract/param_validator.cc index abd4907a75d..f182dbb5d50 100644 --- a/mindspore/core/abstract/param_validator.cc +++ b/mindspore/core/abstract/param_validator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ ABSTRACT_REPORT_NAME_DEC(KeywordArg) TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) { auto ori_type = type; if (type->isa()) { - auto tensor = type->cast(); + auto tensor = type->cast_ptr(); type = tensor->element(); MS_EXCEPTION_IF_NULL(type); } diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc index 2801afb935a..112926f2121 100644 --- a/mindspore/core/abstract/utils.cc +++ b/mindspore/core/abstract/utils.cc @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -212,7 +212,7 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac } AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) { - AbstractFunctionPtr f_spec = dyn_cast(spec); + auto f_spec = dyn_cast_ptr(spec); if (f_spec != nullptr) { return std::make_shared(kAnyValue, std::make_shared()); } @@ -285,7 +285,7 @@ AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) { auto ret_shape = std::make_shared(ret_vec, min_shape_vec, max_shape_vec); if (type->isa()) { - auto tensor_type = type->cast(); + auto tensor_type = type->cast_ptr(); MS_EXCEPTION_IF_NULL(tensor_type); auto element = std::make_shared(kAnyValue, tensor_type->element()); tensor = std::make_shared(element, ret_shape); @@ -319,8 +319,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type } return MakeAbstractTensor(shape, type); } else if (base_shape->isa() && type->isa()) { - auto shape_tuple = base_shape->cast(); - auto type_tuple = type->cast(); + auto shape_tuple = base_shape->cast_ptr(); + auto type_tuple = type->cast_ptr(); AbstractBasePtrList ptr_list; for (size_t it = 0; it < shape_tuple->size(); ++it) { auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]); @@ -329,8 +329,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type auto tuple = std::make_shared(ptr_list); return tuple; } else if (base_shape->isa() && type->isa()) { - auto shape_list = base_shape->cast(); - auto type_list = type->cast(); + auto shape_list = base_shape->cast_ptr(); + auto type_list = type->cast_ptr(); AbstractBasePtrList ptr_list; for (size_t it = 0; it < shape_list->size(); ++it) { auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]); diff --git a/mindspore/core/base/base_ref.cc b/mindspore/core/base/base_ref.cc index 0a9aed3e80b..a1074f66507 100644 --- a/mindspore/core/base/base_ref.cc +++ b/mindspore/core/base/base_ref.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,12 +42,12 @@ bool BaseRef::operator==(const BaseRef &other) const { return false; } if (m_ptr->isa()) { - return *(m_ptr->cast()) == *(other.m_ptr->cast()); + return *(m_ptr->cast_ptr()) == *(other.m_ptr->cast_ptr()); } // for noderef equal if (m_ptr->isa()) { - return *std::static_pointer_cast(m_ptr) == *std::static_pointer_cast(other.m_ptr); + return *(m_ptr->cast_ptr()) == *(other.m_ptr->cast_ptr()); } // for node equal diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 405dd9c212a..4714eea2913 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -1173,7 +1173,11 @@ inline S GetValueNode(const AnfNodePtr &node) { template ::value, S>::type * = nullptr> inline S *GetValuePtr(const AnfNodePtr &node) { - auto value = GetValuePtr(node); + auto value_node = dyn_cast_ptr(node); + if (value_node == nullptr) { + return nullptr; + } + const auto &value = value_node->value(); return (value == nullptr) ? nullptr : value->cast_ptr(); } diff --git a/mindspore/core/ir/dtype/ref.h b/mindspore/core/ir/dtype/ref.h index 084cdb23fd5..bce3283140a 100644 --- a/mindspore/core/ir/dtype/ref.h +++ b/mindspore/core/ir/dtype/ref.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,6 +55,11 @@ class MS_CORE_API RefType final : public TensorType { /// \param[in] subtype Define the TensorType for RefType object to refer to. explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {} + /// \brief Constructor for RefType. + /// + /// \param[in] subtype Define the TensorType for RefType object to refer to. + explicit RefType(const TensorType *subtype) : TensorType(subtype->element()) {} + /// \brief Destructor of RefType. ~RefType() override {} MS_DECLARE_PARENT(RefType, TensorType) diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index 13c4022a8e7..c2a26b84a00 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -592,7 +592,7 @@ std::string FuncGraph::GetVariableArgName() { const auto ¶m_node = GetVariableArgParameter(); MS_EXCEPTION_IF_NULL(param_node); - const auto ¶meter = param_node->cast(); + auto parameter = param_node->cast_ptr(); MS_EXCEPTION_IF_NULL(parameter); return parameter->name(); } @@ -614,7 +614,7 @@ std::string FuncGraph::GetVariableKwargName() { MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_ << ", parameters is less than 1 + fv_param_count"; } - const auto ¶meter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast(); + auto parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast_ptr(); MS_EXCEPTION_IF_NULL(parameter); return parameter->name(); } @@ -666,7 +666,7 @@ int FuncGraph::GetPositionalArgsCount() const { AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { for (size_t i = 0; i < parameters_.size(); ++i) { MS_EXCEPTION_IF_NULL(parameters_[i]); - auto param_cast = parameters_[i]->cast(); + auto param_cast = parameters_[i]->cast_ptr(); MS_EXCEPTION_IF_NULL(param_cast); if (param_cast->name() == name) { return parameters_[i]; diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 1eef5f17538..7a051a5ad64 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -87,7 +87,7 @@ void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); - auto old_param = node->cast(); + auto old_param = node->cast_ptr(); MS_EXCEPTION_IF_NULL(old_param); auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_); auto new_param = (is_add ? target->add_parameter(std::move(debug_info)) @@ -136,7 +136,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) { ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope(); new_const->set_scope(scope); new_const->set_abstract(node->abstract()); - new_const->set_has_new_value(node->cast()->has_new_value()); + new_const->set_has_new_value(node->cast_ptr()->has_new_value()); repl_node_[node] = std::move(new_const); } @@ -148,7 +148,7 @@ void Cloner::CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope(); new_const->set_scope(scope); new_const->set_abstract(node->abstract()); - new_const->set_has_new_value(node->cast()->has_new_value()); + new_const->set_has_new_value(node->cast_ptr()->has_new_value()); repl_node_[node] = std::move(new_const); } @@ -229,7 +229,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func auto &cnodes = func_graph->func_graph_cnodes_index(); for (auto &cnode : cnodes) { - auto parent = cnode.first->first->cast(); + auto parent = cnode.first->first->cast_ptr(); MS_EXCEPTION_IF_NULL(parent); const auto &valuenode = parent->input(IntToSize(cnode.first->second)); CloneFuncGraphValueNode(valuenode, target_func_graph); @@ -291,7 +291,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) { auto free_var_node = utils::cast(free_var); // Don't lift weight parameter to top func_graph. if (IsLiftTopFuncGraph(func_graph) && free_var_node->isa()) { - auto free_var_param = free_var_node->cast(); + auto free_var_param = free_var_node->cast_ptr(); if (free_var_param->has_default()) { MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->DebugString() << " for top_func_graph: " << lift_top_func_graph->ToString(); @@ -306,7 +306,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) { } MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString(); - auto fv_parameter = AddParameter(func_graph, utils::cast(free_var)); + auto fv_parameter = AddParameter(func_graph, free_var_node); fv_parameter->set_user_data("lifted_from_fv", std::make_shared(true)); auto &fg_params = repl_func_graph_params_[func_graph]; (void)fg_params.emplace_back(fv_parameter); @@ -316,7 +316,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) { void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) const { param->set_abstract(node->abstract()); if (node->isa()) { - ParameterPtr old_param = node->cast(); + auto old_param = node->cast_ptr(); if (old_param->has_default()) { // Default parameter can be shared since it is readonly. param->set_default_param(old_param->default_param()); @@ -728,11 +728,11 @@ void Cloner::CloneNodes() { void Cloner::LinkEdges() { for (auto &repl : repl_node_) { - CNodePtr old_node = dyn_cast(repl.first); + auto old_node = dyn_cast_ptr(repl.first); if (old_node == nullptr) { continue; } - CNodePtr new_node = repl.second->cast(); + auto new_node = repl.second->cast_ptr(); MS_EXCEPTION_IF_NULL(new_node); for (auto &input : old_node->inputs()) { auto iter = repl_node_.find(input); diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index 9467a23a29c..8ca8eb1ca1e 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -131,7 +131,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), [param_name](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto param = node->cast(); + auto param = node->cast_ptr(); return param != nullptr && param->name() == param_name; }); if (find_kw_arg_in_list) { diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc index 522825c9490..179ad1f7e84 100644 --- a/mindspore/core/ir/graph_utils.cc +++ b/mindspore/core/ir/graph_utils.cc @@ -186,8 +186,8 @@ std::vector SuccDeeper(const AnfNodePtr &node) { return vecs; } - if (IsValueNode(node)) { - auto graph = GetValueNode(node); + auto graph = GetValuePtr(node); + if (graph != nullptr) { auto &ret = graph->return_node(); if (ret != nullptr) { vecs.push_back(ret); @@ -209,19 +209,16 @@ std::vector SuccDeeperSimple(const AnfNodePtr &node) { return vecs; } - if (IsValueNode(node)) { - auto graph = GetValueNode(node); + auto graph = GetValuePtr(node); + if (graph != nullptr) { auto &ret = graph->return_node(); if (ret != nullptr) { vecs.push_back(ret); } - return vecs; - } else { - if (node->isa()) { - FetchCNodeSuccessors(node->cast(), &vecs); - } - return vecs; + } else if (node->isa()) { + FetchCNodeSuccessors(node->cast(), &vecs); } + return vecs; } std::vector SuccIncoming(const AnfNodePtr &node) { @@ -234,26 +231,24 @@ std::vector SuccIncoming(const AnfNodePtr &node) { } std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) { - std::vector vecs; - if (node == nullptr) { - return vecs; - } auto cnode = dyn_cast(node); - if (cnode != nullptr) { - auto &inputs = cnode->inputs(); - // Check if free variables used. - for (const auto &input : inputs) { - auto input_fg = GetValueNode(input); - if (input_fg) { - for (auto &fv : input_fg->free_variables_nodes()) { - if (fv->func_graph() == fg && fg->nodes().contains(fv)) { - vecs.push_back(fv); - } + if (cnode == nullptr) { + return {}; + } + std::vector vecs; + const auto &inputs = cnode->inputs(); + // Check if free variables used. + for (const auto &input : inputs) { + auto input_fg = GetValuePtr(input); + if (input_fg != nullptr) { + for (auto &fv : input_fg->free_variables_nodes()) { + if (fv->func_graph() == fg && fg->nodes().contains(fv)) { + vecs.push_back(fv); } } } - FetchCNodeSuccessors(cnode, &vecs); } + FetchCNodeSuccessors(cnode, &vecs); return vecs; } @@ -263,28 +258,24 @@ std::vector SuccWithFilter(const GraphFilterFunc &graph_filter, cons return vecs; } - if (IsValueNode(node)) { - auto graph = GetValueNode(node); + auto graph = GetValueNode(node); + if (graph != nullptr) { if (graph_filter != nullptr && graph_filter(graph)) { return vecs; } - auto &ret = graph->return_node(); if (ret != nullptr) { vecs.push_back(ret); } - return vecs; - } else { - if (node->isa()) { - FetchCNodeSuccessors(node->cast(), &vecs); - } - return vecs; + } else if (node->isa()) { + FetchCNodeSuccessors(node->cast(), &vecs); } + return vecs; } const std::vector &GetInputs(const AnfNodePtr &node) { static std::vector empty_inputs; - auto cnode = dyn_cast(node); + auto cnode = dyn_cast_ptr(node); if (cnode != nullptr) { return cnode->inputs(); } diff --git a/mindspore/core/ir/graph_utils_extends.cc b/mindspore/core/ir/graph_utils_extends.cc index b1c65347fd8..8955bd561fd 100644 --- a/mindspore/core/ir/graph_utils_extends.cc +++ b/mindspore/core/ir/graph_utils_extends.cc @@ -92,8 +92,8 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { if (!IsValueNode(vnode)) { return; } - auto fg = GetValueNode(vnode); - AnfNodePtr ret = fg->return_node(); + auto fg = GetValuePtr(vnode); + const auto &ret = fg->return_node(); DeepFirstSearcher::Visit(ret); } diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc index 852ca3eb1ea..37708161290 100644 --- a/mindspore/core/ir/manager.cc +++ b/mindspore/core/ir/manager.cc @@ -661,7 +661,7 @@ void FuncGraphManager::MoveAllCNodeDropGraph(const FuncGraphPtr &source, const F const ScopePtr &scope) { AnfNodePtr source_return = source->get_return(); AnfNodePtr source_output = source->output(); - AnfNodePtr source_prim = source_return->cast()->input(0); + AnfNodePtr source_prim = source_return->cast_ptr()->input(0); int index = 0; (void)node_users_[source_prim].erase(make_pair(source_return, index)); @@ -1105,7 +1105,7 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se std::find_if(meta_fg_prim_values.begin(), meta_fg_prim_values.end(), [seen_num](const auto &iter) { // Check g1->MetaFgPrim(fg)->g2->g cycle. if (IsValueNode(iter.first)) { - auto func_graph = GetValueNode(iter.first); + auto func_graph = GetValuePtr(iter.first); return func_graph->seen_ != seen_num; } if (IsValueNode(iter.first)) { diff --git a/mindspore/core/ir/value.h b/mindspore/core/ir/value.h index 275b18269e0..2a9cf52f4bb 100644 --- a/mindspore/core/ir/value.h +++ b/mindspore/core/ir/value.h @@ -435,7 +435,7 @@ IMM_TRAITS(StringImmPtr, const char *) /// \brief RefKey defines a class whose real type is String. /// \brief Notice: RefKey is keep for compatible only, we use RefKey just as StringImm. -class MS_CORE_API RefKey : public StringImm { +class MS_CORE_API RefKey final : public StringImm { public: /// \brief Constructor of RefKey. /// diff --git a/mindspore/core/ir/visitor.cc b/mindspore/core/ir/visitor.cc index 3866052e970..a97a7c4fac1 100644 --- a/mindspore/core/ir/visitor.cc +++ b/mindspore/core/ir/visitor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,8 +27,8 @@ void AnfIrVisitor::Visit(const CNodePtr &cnode) { } void AnfIrVisitor::Visit(const ValueNodePtr &vnode) { - if (IsValueNode(vnode)) { - auto func_graph = GetValueNode(vnode); + auto func_graph = GetValuePtr(vnode); + if (func_graph != nullptr) { Visit(func_graph->output()); } } @@ -36,36 +36,34 @@ void AnfIrVisitor::Visit(const ValueNodePtr &vnode) { void AnfIrVisitor::Visit(const ParameterPtr &) {} VisitFuncType AnfIrVisitor::Match(const PrimitivePtr &prim, const std::vector &funcs) { - auto fn = [prim, funcs, this](const AnfNodePtr &node) { + return [prim, funcs, this](const AnfNodePtr &node) { if (!IsPrimitiveCNode(node, prim)) { return; } - auto &inputs = node->cast()->inputs(); + auto &inputs = node->cast_ptr()->inputs(); auto funcs_size = funcs.size(); auto inputs_size = inputs.size(); - // check the inputs are matched with the predicate functions + // Check the inputs are matched with the predicate functions. if (funcs_size > 0) { - // use the predicate function list to check the number of inputs + // Use the predicate function list to check the number of inputs. if (funcs_size != (inputs_size - 1)) { return; } - // check per input - for (size_t i = 0; i < funcs_size; i++) { + // Check inputs. + for (size_t i = 0; i < funcs_size; ++i) { if (!funcs[i](inputs[i + 1])) { return; } } } - // visit the inputs - for (size_t i = 1; i < inputs_size; i++) { + // Visit argument inputs. + for (size_t i = 1; i < inputs_size; ++i) { this->Visit(inputs[i]); } }; - - return fn; } } // namespace mindspore diff --git a/mindspore/core/utils/ms_utils.h b/mindspore/core/utils/ms_utils.h index 94b24c11ad9..03584903c7b 100644 --- a/mindspore/core/utils/ms_utils.h +++ b/mindspore/core/utils/ms_utils.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -103,7 +103,7 @@ static inline bool UseMPI() { } template -inline bool IsEqual(const std::shared_ptr &a, const std::shared_ptr &b) { +bool IsEqual(const T *a, const T *b) { if (a == b) { return true; } @@ -113,6 +113,11 @@ inline bool IsEqual(const std::shared_ptr &a, const std::shared_ptr &b) { return *a == *b; } +template +bool IsEqual(const std::shared_ptr &a, const std::shared_ptr &b) { + return IsEqual(a.get(), b.get()); +} + template bool IsAttrsEqual(const T &a, const T &b) { if (&a == &b) {