!39471 Optimize pointer casting for compile framework

Merge pull request !39471 from hewei/opt_perf1
This commit is contained in:
i-robot 2022-08-03 07:33:29 +00:00 committed by Gitee
commit 6685a89540
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
39 changed files with 430 additions and 433 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -143,7 +143,7 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode
MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
}
mindspore::HashMap<AnfNodePtr, FuncGraphPtr> node_to_fg;
auto tuple_graphs = input->cast<CNodePtr>();
auto tuple_graphs = input->cast_ptr<CNode>();
for (size_t i = 1; i < tuple_graphs->size(); ++i) {
auto graph = tuple_graphs->input(i);
if (!IsValueNode<FuncGraph>(graph)) {
@ -191,7 +191,7 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
auto output_cnode = node->cast<CNodePtr>();
auto output_cnode = node->cast_ptr<CNode>();
if (output_cnode->size() - 1 == 1) {
return output_cnode->input(1);
}
@ -217,16 +217,16 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
return make_tuple;
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
auto tuple_get_item = node->cast<CNodePtr>();
auto tuple_get_item = node->cast_ptr<CNode>();
auto inp = tuple_get_item->input(1);
if (IsPrimitiveCNode(inp, prim::kPrimHookBackward) || IsPrimitiveCNode(inp, prim::kPrimCellBackwardHook)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
constexpr size_t idx = 2;
auto v_node = tuple_get_item->input(idx)->cast<ValueNodePtr>();
auto v_node = tuple_get_item->input(idx)->cast_ptr<ValueNode>();
MS_EXCEPTION_IF_NULL(v_node);
auto out_idx = GetValue<int64_t>(v_node->value());
return inp->cast<CNodePtr>()->input(LongToSize(out_idx) + 1);
return inp->cast_ptr<CNode>()->input(LongToSize(out_idx) + 1);
}
}
return node;
@ -238,7 +238,7 @@ AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, con
if (input_type == nullptr || !input_type->isa<TensorType>()) {
return din;
}
input_type = input_type->cast<TensorTypePtr>()->element();
input_type = input_type->cast_ptr<TensorType>()->element();
MS_EXCEPTION_IF_NULL(input_type);
if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
return din;
@ -256,7 +256,7 @@ AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, con
if (din_type == nullptr || !din_type->isa<TensorType>()) {
return din;
}
din_type = din_type->cast<TensorTypePtr>()->element();
din_type = din_type->cast_ptr<TensorType>()->element();
MS_EXCEPTION_IF_NULL(din_type);
if (din_type->type_id() != kNumberTypeComplex64 && din_type->type_id() != kNumberTypeComplex128) {
return din;
@ -689,8 +689,8 @@ void DFunctor::MapValueObject() {
AdjointPtr adjoint = nullptr;
if (IsValueNode<Primitive>(node)) { // Primitive.
auto prim = GetValueNode<PrimitivePtr>(node);
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
auto prim = GetValuePtr<Primitive>(node);
if ((prim->Hash() == prim::kPrimReturn->hash() && prim->name() == prim::kPrimReturn->name()) ||
(prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) ||
(prim->Hash() == prim::kPrimCellBackwardHook->Hash() &&
prim->name() == prim::kPrimCellBackwardHook->name())) {
@ -784,17 +784,15 @@ void DFunctor::BroadCastStopFlag() {
while (need_cut_) {
need_cut_ = false;
for (auto &node : primal_graph_->nodes()) {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!cnode->stop_gradient()) {
// Cut off the cnode only when it's not referred any more
if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
AllReferencesStopped(cnode)) {
MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
cnode->set_stop_gradient(true);
// The stop set changed, more cut required
need_cut_ = true;
}
auto cnode = dyn_cast<CNode>(node);
if (cnode != nullptr && !cnode->stop_gradient()) {
// Cut off the cnode only when it's not referred any more
if (cnode->IsApply(prim::kPrimStopGradient) || cnode->IsApply(prim::kPrimUpdateState) ||
AllReferencesStopped(cnode)) {
MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
cnode->set_stop_gradient(true);
// The stop set changed, more cut required
need_cut_ = true;
}
}
}
@ -809,7 +807,7 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
}
for (auto &kv : users) {
auto &user = kv.first;
if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
if (!user->isa<CNode>() || !user->cast_ptr<CNode>()->stop_gradient()) {
return false;
}
}

View File

@ -225,7 +225,7 @@ AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, c
} else {
python_ops_value = prim::GetPythonOps(iter->second.first);
}
auto origin_cnode = origin_node->cast<CNodePtr>();
auto origin_cnode = origin_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(origin_cnode);
auto &origin_inputs = origin_cnode->inputs();
std::vector<AnfNodePtr> new_inputs{NewValueNode(python_ops_value)};
@ -243,7 +243,7 @@ void ReplacePythonOps(const FuncGraphPtr &fg) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto cnode = node->cast_ptr<CNode>();
for (size_t i = 0; i < cnode->size(); ++i) {
auto prim = GetCNodePrimitive(cnode->input(i));
if (prim == nullptr) {
@ -388,7 +388,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
if (prim->is_base()) {
fn = GetBpropFunction(prim->name());
} else {
fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
fn = prim->cast_ptr<PrimitivePy>()->GetBpropFunction();
if (py::isinstance<py::none>(fn)) {
fn = GetBpropFunction(prim->name());
}
@ -491,8 +491,8 @@ static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &mo
auto output_cnode = output->cast<CNodePtr>();
if (output_cnode != nullptr) {
// If output_cnode has the form like (make_tuple, x, y).
while (IsPrimitiveCNode(output_cnode, prim::kPrimDepend)) {
auto real_input = output_cnode->input(kRealInputIndexInDepend);
while (output_cnode->IsApply(prim::kPrimDepend)) {
const auto &real_input = output_cnode->input(kRealInputIndexInDepend);
MS_EXCEPTION_IF_NULL(real_input);
output_cnode = real_input->cast<CNodePtr>();
}
@ -555,7 +555,7 @@ void SetDumpFlag(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
return;
}
auto attr = prim->GetAttr(kAttrDump);
if (attr != nullptr && attr->isa<StringImm>() && attr->cast<StringImmPtr>()->value() == kValueTrue) {
if (attr != nullptr && attr->isa<StringImm>() && attr->cast_ptr<StringImm>()->value() == kValueTrue) {
bprop_fg->set_flag(FUNC_GRAPH_FLAG_DUMP, true);
}
}
@ -660,7 +660,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
// bprop_fg has been checked in caller
if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
// Set bprop output as (env, dx, dy, dz, ...)
auto cbprop = bprop_fg->output()->cast<CNodePtr>();
auto cbprop = bprop_fg->output()->cast_ptr<CNode>();
auto &inputs = cbprop->inputs();
std::vector<AnfNodePtr> args;
@ -742,7 +742,7 @@ void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const Func
const auto &current_primal_fg_params = current_primal_fg->parameters();
// The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
for (size_t i = 0; i < current_primal_fg_params.size(); ++i) {
auto primal_parameter = dyn_cast<Parameter>(current_primal_fg_params[i]);
auto primal_parameter = dyn_cast_ptr<Parameter>(current_primal_fg_params[i]);
MS_EXCEPTION_IF_NULL(primal_parameter);
auto lifted = primal_parameter->template user_data<bool>(kLiftedUserDataKey);
if (lifted == nullptr || !*lifted) {
@ -845,7 +845,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
if (cnode == users.end()) {
MS_LOG(EXCEPTION) << "Fail to find cnode.";
}
auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
auto inputs_num = cnode->first->cast_ptr<CNode>()->size() - 1;
auto func_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
@ -886,7 +886,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
if (cnode == users.end()) {
MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
}
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
auto inputs_num = cnode->first->cast_ptr<CNode>()->inputs().size() - 1;
auto effect_info = GetPrimEffectInfo(prim);
// Don't add U or IO monad parameters as it will be added later.
size_t monad_params_size = 0;

View File

@ -33,22 +33,25 @@ namespace mindspore {
/* namespace to support opt */
namespace opt {
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
static const std::map<std::string, std::vector<std::string>> op2attrs = {
{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
auto todos = TopoSort(graph->get_return());
for (const auto &node : todos) {
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto primitive = GetCNodePrimitive(node);
if (!primitive || dyn_cast<PrimitivePy>(primitive)) {
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
continue;
}
parallel::OperatorAttrs attrs;
std::map<std::string, std::vector<std::string>> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
if (op2attrs.count(primitive->name()) != 0) {
for (auto &attr : op2attrs[primitive->name()]) {
auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) {
for (auto &attr : iter->second) {
if (primitive->HasAttr(attr)) {
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
} else {
@ -57,10 +60,10 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
}
}
}
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "")->cast<PrimitivePtr>();
(void)new_prim->SetAttrs(primitive->attrs());
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
AnfNodePtrList inputs = {NewValueNode(new_prim)};
auto cnode = dyn_cast<CNode>(node);
auto cnode = dyn_cast_ptr<CNode>(node);
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
cnode->set_inputs(inputs);
}

View File

@ -28,7 +28,7 @@ bool ContainSparseTensor(const abstract::AbstractBasePtr &abs) {
return true;
}
if (abs->isa<abstract::AbstractTuple>()) {
auto vec = abs->cast<abstract::AbstractTuplePtr>()->elements();
auto vec = abs->cast_ptr<abstract::AbstractTuple>()->elements();
return std::any_of(vec.begin(), vec.end(), ContainSparseTensor);
}
return false;

View File

@ -40,27 +40,20 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
bool has_priority_pattern) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode == nullptr) {
return false;
}
auto cnode = node->cast<CNodePtr>();
auto inp0 = cnode->input(0);
auto prim0 = GetValueNode<PrimitivePtr>(inp0);
if (prim0 == nullptr) {
auto cnode_prim = GetValuePtr<Primitive>(cnode->input(0));
if (cnode_prim == nullptr) {
return false;
}
auto hash = prim0->Hash();
auto const &name = prim0->name();
for (auto &prim : prims) {
if (hash == prim->Hash() && name == prim->name()) {
return true;
}
}
return false;
auto hash = cnode_prim->Hash();
const auto &name = cnode_prim->name();
return std::any_of(prims.begin(), prims.end(), [&hash, &name](const PrimitivePtr &prim) {
return (prim->Hash() == hash) && (prim->name() == name);
});
};
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
}
@ -93,17 +86,15 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
return result;
}
static bool isTraversable(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
static inline bool isTraversable(const AnfNodePtr &node) {
if (node->isa<CNode>() || node->isa<Parameter>()) {
return true;
}
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
return true;
}
return false;
// FuncGraph or RefKey value node is traversable.
auto value_node = dyn_cast_ptr<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node);
const auto &value = value_node->value();
return (value != nullptr) && (value->isa<FuncGraph>() || value->isa<RefKey>());
}
static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
@ -131,34 +122,38 @@ static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &n
}
static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change) {
if (IsValueNode<FuncGraph>(node)) {
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
auto fg = GetValuePtr<FuncGraph>(node);
if (fg != nullptr) {
(void)todo->emplace_back(fg->output());
}
if (change) {
(*todo).emplace_back(node);
(void)todo->emplace_back(node);
} else {
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode != nullptr) {
const auto &inputs = cnode->inputs();
(void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
}
}
}
static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change,
const SubstitutionPtr &substitution) {
if (IsValueNode<FuncGraph>(node)) {
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
auto fg = GetValuePtr<FuncGraph>(node);
if (fg != nullptr) {
(void)todo->emplace_back(fg->output());
}
// If there is a priority pattern in substitution, don't transform the new node,
// otherwise some nodes may match the wrong patterns.
if (change && substitution != nullptr && !substitution->has_priority_pattern_) {
(*todo).emplace_back(node);
(void)todo->emplace_back(node);
} else {
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode != nullptr) {
const auto &inputs = cnode->inputs();
(void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
}
}
}
@ -194,12 +189,11 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
FuncGraphManagerPtr manager = optimizer->manager();
auto seen = NewSeenGeneration();
std::deque<AnfNodePtr> todo;
todo.emplace_back(func_graph->output());
(void)todo.emplace_back(func_graph->output());
bool changes = false;
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) {
AnfNodePtr node = todo.front();
AnfNodePtr node = std::move(todo.front());
todo.pop_front();
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
@ -373,13 +367,13 @@ bool SimpleRewriter::Run() {
continue;
}
node->seen_ = seen;
auto cnode = node->cast<CNodePtr>();
auto cnode = node->cast_ptr<CNode>();
if (cnode != nullptr) {
for (auto &input : cnode->inputs()) {
add_todo(input);
}
} else {
auto fg = GetValueNode<FuncGraphPtr>(node);
auto fg = GetValuePtr<FuncGraph>(node);
if (fg != nullptr) {
add_todo(fg->output());
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -43,7 +43,7 @@ MatchResultPtr Call::match(const AnfNodePtr &node) {
}
MatchResultPtr res = std::make_shared<MatchResult>();
// IsPrimitiveCNode
auto cnode = node->cast<CNodePtr>();
auto cnode = node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(cnode);
// Check Primitive ValueNode
if (prim_pattern_ != nullptr) {
@ -120,20 +120,16 @@ MatchResultPtr Any::match(const AnfNodePtr &node) {
}
MatchResultPtr Imm::match(const AnfNodePtr &node) {
if (!IsValueNode<Int32Imm>(node)) {
auto value_ptr = GetValuePtr<Int32Imm>(node);
if (value_ptr == nullptr) {
return nullptr;
}
// Check value
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
MS_EXCEPTION_IF_NULL(value_ptr);
if (value_ptr->value() == value_) {
MatchResultPtr res = std::make_shared<MatchResult>();
res->add_entry(shared_from_base<Imm>(), node);
return res;
if (value_ptr->value() != value_) {
return nullptr;
}
return nullptr;
MatchResultPtr res = std::make_shared<MatchResult>();
res->add_entry(shared_from_base<Imm>(), node);
return res;
}
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -57,7 +57,7 @@ bool IsTraversable(const AnfNodePtr &node) {
AnfNodePtr BuildPrimitive(const PatternPtr &pattern) {
// Build up AnfNode from primitive
auto prim_pattern = pattern->cast<PrimPtr>();
auto prim_pattern = pattern->cast_ptr<Prim>();
MS_EXCEPTION_IF_NULL(prim_pattern);
PrimitivePyPtr prim = prim_pattern->matched_primitive();
MS_EXCEPTION_IF_NULL(prim);
@ -67,7 +67,7 @@ AnfNodePtr BuildPrimitive(const PatternPtr &pattern) {
AnfNodePtr BuildNewTensor(const PatternPtr &pattern) {
// Build a ValueNode from TensorPtr
auto new_tensor_pattern = pattern->cast<NewTensorPtr>();
auto new_tensor_pattern = pattern->cast_ptr<NewTensor>();
MS_EXCEPTION_IF_NULL(new_tensor_pattern);
auto input_tensor = new_tensor_pattern->input_tensor();
MS_EXCEPTION_IF_NULL(input_tensor);
@ -76,7 +76,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern) {
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
const FuncGraphPtr &top_graph) {
auto call_pattern = pattern->cast<CallPtr>();
auto call_pattern = pattern->cast_ptr<Call>();
MS_EXCEPTION_IF_NULL(call_pattern);
auto prim = call_pattern->prim_value();
if (prim != nullptr) {
@ -88,7 +88,7 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
}
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
auto new_para_pattern = pattern->cast<NewParameterPtr>();
auto new_para_pattern = pattern->cast_ptr<NewParameter>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
if (!new_para_pattern->built()) {
static int64_t parameter_id = 0;
@ -122,7 +122,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
}
AnfNodePtr BuildImmNode(const PatternPtr &pattern) {
auto imm_pattern = pattern->cast<ImmPtr>();
auto imm_pattern = pattern->cast_ptr<Imm>();
MS_EXCEPTION_IF_NULL(imm_pattern);
auto value = imm_pattern->value();
auto scalar_value_ptr = std::make_shared<Int64Imm>(value);
@ -134,7 +134,7 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
auto target_node = res->get_node(pattern);
if (target_node != nullptr) {
// If pattern is NewParameter, check whether it shouldn't last and is not built
auto new_para = pattern->cast<NewParameterPtr>();
auto new_para = pattern->cast_ptr<NewParameter>();
if (new_para == nullptr || new_para->should_last() || new_para->built()) {
return target_node;
}
@ -218,20 +218,20 @@ void ReflectParamBackToPython(const AnfNodePtr &param, const string &param_name,
MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr.";
}
MS_EXCEPTION_IF_NULL(param);
auto param_node = param->cast<ParameterPtr>();
auto param_node = param->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(param_node);
param_node->set_default_param(param_value);
}
void Reset(const PatternPtr &pattern) {
if (pattern->isa<Prim>()) {
auto prim_pattern = pattern->cast<PrimPtr>();
auto prim_pattern = pattern->cast_ptr<Prim>();
prim_pattern->reset();
} else if (pattern->isa<NewParameter>()) {
auto new_param_pattern = pattern->cast<NewParameterPtr>();
auto new_param_pattern = pattern->cast_ptr<NewParameter>();
new_param_pattern->reset();
} else if (pattern->isa<Call>()) {
auto call_with_pattern = pattern->cast<CallPtr>();
auto call_with_pattern = pattern->cast_ptr<Call>();
for (const auto &sub_pattern : call_with_pattern->inputs()) {
Reset(sub_pattern);
}
@ -257,7 +257,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
MS_EXCEPTION_IF_NULL(dst_pattern_);
if (src_pattern_ == nullptr) {
// Add NewParameter
auto new_para_pattern = dst_pattern_->cast<NewParameterPtr>();
auto new_para_pattern = dst_pattern_->cast_ptr<NewParameter>();
if (new_para_pattern == nullptr) {
MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -81,7 +81,7 @@ void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
auto cur_pg = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(true);
auto new_para_pattern = parameter->cast<NewParameterPtr>();
auto new_para_pattern = parameter->cast_ptr<NewParameter>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
auto pass_name = new_para_pattern->para_name();
new_para_pattern->set_last(true);

View File

@ -59,7 +59,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) {
ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
auto cnode = node->cast_ptr<CNode>();
if (cnode == nullptr) {
return nullptr;
}
@ -222,7 +222,7 @@ bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap<AnfNodePtr, bool>
if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) {
return has_grad_inputs_map->find(node)->second;
}
auto cnode = node->cast<CNodePtr>();
auto cnode = node->cast_ptr<CNode>();
if (cnode == nullptr) {
(void)has_grad_inputs_map->emplace(node, false);
return false;
@ -230,7 +230,7 @@ bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap<AnfNodePtr, bool>
const auto &inputs = cnode->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
// For the pipeline split case, the forward pass may depend on the backward pass.
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && i == kDependAttachNodeIndex) {
if (cnode->IsApply(prim::kPrimDepend) && i == kDependAttachNodeIndex) {
continue;
}
if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) {
@ -320,7 +320,7 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
std::vector<AnfNodePtr> tuple_getitem_output_nodes;
GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes);
for (const auto &output_node : tuple_getitem_output_nodes) {
auto output_cnode = output_node->cast<CNodePtr>();
auto output_cnode = output_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(output_cnode);
output_cnode->AddAttr(kAttrRecompute, MakeValue(true));
}
@ -362,7 +362,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
for (size_t i = 0; i < origin_node->size(); ++i) {
auto input = origin_node->input(i);
if (i == 0 && IsPrimitive(input, prim::kPrimAllGather)) {
auto prim = GetValueNode<PrimitivePtr>(input);
auto prim = GetValuePtr<Primitive>(input);
auto instance_name = prim->instance_name();
bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos;
int64_t fusion_id = prim->HasAttr(kAttrFusion) ? GetValue<int64_t>(prim->GetAttr(kAttrFusion)) : 0;
@ -426,7 +426,7 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSe
for (const auto &target_node : target_nodes) {
MS_EXCEPTION_IF_NULL(target_node);
MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
auto target_cnode = target_node->cast<CNodePtr>();
auto target_cnode = target_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(target_cnode);
std::vector<AnfNodePtr> new_target_inputs;
for (const auto &input : target_cnode->inputs()) {

View File

@ -278,7 +278,7 @@ FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr &
// The value of attr "shared_name" will changed every time.
auto cnodes = fg->GetOrderedCnodes();
for (const auto &cnode : cnodes) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
auto prim = GetValuePtr<Primitive>(cnode->input(0));
if (prim != nullptr && prim->HasAttr("shared_name")) {
prim->set_attr("shared_name", MakeValue(queue_name));
break;

View File

@ -22,6 +22,7 @@
#include <vector>
#include "ir/param_info.h"
#include "ir/value.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/parse.h"
#include "include/common/utils/python_adapter.h"
@ -31,6 +32,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/symbol_resolver.h"
#include "include/common/debug/anf_dump_utils.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace parse {
@ -452,7 +454,7 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
const std::string &attr_as_string = GetValuePtr<StringImm>(attr)->value();
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
@ -489,7 +491,9 @@ AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py:
} else if (count_msclass == sequence_size) {
// Resolve MsClass instances.
for (size_t i = 0; i < sequence_size; ++i) {
auto attr_str = GetValue<std::string>(GetValueNode(attr));
auto attr_str_ptr = GetValuePtr<StringImm>(attr);
MS_EXCEPTION_IF_NULL(attr_str_ptr);
const auto &attr_str = attr_str_ptr->value();
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, get_attr_node);
(void)inputs.emplace_back(res);
}

View File

@ -214,7 +214,7 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) {
return;
}
auto abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
auto abs_sequence = abs->cast_ptr<abstract::AbstractSequence>();
if (abs_sequence != nullptr) {
const auto &elements = abs_sequence->elements();
for (auto &ele : elements) {
@ -223,7 +223,7 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) {
return;
}
auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
auto abs_dict = abs->cast_ptr<abstract::AbstractDictionary>();
if (abs_dict != nullptr) {
const auto &elements = abs_dict->elements();
for (auto &ele : elements) {
@ -590,16 +590,15 @@ void GraphExecutorPy::GetWeightInfo(
auto x = root_node->input(1);
MS_EXCEPTION_IF_NULL(x);
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
weight_name = weight_node->cast_ptr<CNode>()->input(1)->cast_ptr<Parameter>()->name();
} else {
auto para = weight_node->cast<ParameterPtr>();
auto para = weight_node->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(para);
weight_name = para->name();
}
// find the fakequant from input
int64_t count = 0;
const int64_t max_depth = 5;
CNodePtr cnode = nullptr;
auto is_quant_cnode = [](const AnfNodePtr &node) {
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) ||
@ -610,7 +609,7 @@ void GraphExecutorPy::GetWeightInfo(
if (count >= max_depth) {
break;
}
cnode = x->cast<CNodePtr>();
auto cnode = x->cast_ptr<CNode>();
if (cnode == nullptr || cnode->size() <= 1) {
break;
}
@ -624,9 +623,9 @@ void GraphExecutorPy::GetWeightInfo(
if (!is_quant_cnode(x)) {
return;
}
cnode = x->cast<CNodePtr>();
auto cnode = x->cast_ptr<CNode>();
constexpr size_t expect_input_size = 4;
if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != expect_input_size) {
if (cnode == nullptr || cnode->IsApply(prim::kPrimLoad) || cnode->size() != expect_input_size) {
return;
}
const size_t fakequant_index = 2;
@ -636,18 +635,18 @@ void GraphExecutorPy::GetWeightInfo(
}
std::string fakequant_min_node_name;
if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
fakequant_min_node_name = fakequant_min_node->cast_ptr<CNode>()->input(1)->cast_ptr<Parameter>()->name();
} else {
auto param = fakequant_min_node->cast<ParameterPtr>();
auto param = fakequant_min_node->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(param);
fakequant_min_node_name = param->name();
}
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
const auto &quant_op_value = cnode->input(0)->cast_ptr<ValueNode>()->value();
MS_EXCEPTION_IF_NULL(quant_op_value);
if (!quant_op_value->isa<PrimitivePy>()) {
return;
}
auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
auto quant_op = quant_op_value->cast_ptr<PrimitivePy>();
(*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name);
}
@ -677,7 +676,7 @@ std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> GraphExecut
}
auto weight = root_node->input(weight_index);
if (!is_quant_cnode(weight)) {
auto tuple_node = weight->cast<CNodePtr>();
auto tuple_node = weight->cast_ptr<CNode>();
if (tuple_node != nullptr) {
auto fake_node = tuple_node->input(1);
if (!is_quant_cnode(fake_node)) {
@ -688,7 +687,7 @@ std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> GraphExecut
}
}
// get parameter weight's name
auto cnode = weight->cast<CNodePtr>();
auto cnode = weight->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(cnode);
auto weight_node = cnode->input(weight_index);
if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
@ -1207,7 +1206,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
// Maybe some default parameter
for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
MS_EXCEPTION_IF_NULL(graph_params[i]);
auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
auto param_ptr = (graph_params[i])->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (!param_ptr->has_default()) {
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
@ -1346,7 +1345,7 @@ void GraphExecutorPy::UpdataParamNodeDefaultInput(
auto &params = func_graph->parameters();
for (const auto &param : params) {
MS_EXCEPTION_IF_NULL(param);
auto param_cast = param->cast<ParameterPtr>();
auto param_cast = param->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(param_cast);
auto iter = params_value.find(param_cast->name());
if (iter != params_value.end()) {

View File

@ -77,8 +77,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
static bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
return true;
}
}
@ -96,7 +95,7 @@ static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const F
CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_get_item);
tuple_get_item->set_scope(node->scope());
auto input_abstract_tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
auto input_abstract_tuple = node->abstract()->cast_ptr<abstract::AbstractTuple>();
MS_EXCEPTION_IF_NULL(input_abstract_tuple);
auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
@ -134,7 +133,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
}
std::set<AnfNodePtr> input_parameters;
for (auto &anf_param : root->parameters()) {
auto param = anf_param->cast<ParameterPtr>();
auto param = anf_param->cast_ptr<Parameter>();
if (!param->has_default()) {
(void)input_parameters.insert(anf_param);
}
@ -143,7 +142,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
auto node_users_map = root->manager()->node_users();
auto node_users = node_users_map[input_parameter];
for (auto node_user : node_users) {
auto cnode = node_user.first->cast<CNodePtr>();
auto cnode = node_user.first->cast_ptr<CNode>();
if (IsValueNode<Primitive>(cnode->inputs()[0]) ||
(IsValueNode<FuncGraph>(cnode->inputs()[0]) && !root->has_flag(parallel::kTraining))) {
(void)graph_sets.insert(cnode->func_graph());
@ -155,7 +154,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto cnode = node->cast_ptr<CNode>();
if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,13 @@
namespace mindspore {
namespace pipeline {
static inline bool IsSameValue(Value *v1, Value *v2) {
if (v1->isa<tensor::Tensor>() && v2->isa<tensor::Tensor>()) {
return static_cast<tensor::Tensor *>(v1)->ValueEqual(*(static_cast<tensor::Tensor *>(v2)));
}
return *v1 == *v2;
}
void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache,
HashValue *const hash_value) {
MS_EXCEPTION_IF_NULL(manager);
@ -35,7 +42,7 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
if (IsValueNode<FuncGraph>(node)) {
return;
}
const auto &to_check_value = GetValueNode(node);
auto to_check_value = GetValuePtr(node);
MS_EXCEPTION_IF_NULL(to_check_value);
// Calculate hash value.
@ -62,20 +69,13 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
if (v == node) {
return;
}
const auto &existed_value = GetValueNode(v);
auto existed_value = GetValuePtr(v);
MS_EXCEPTION_IF_NULL(existed_value);
auto equal = [&]() -> bool {
if (existed_value->isa<tensor::Tensor>() && to_check_value->isa<tensor::Tensor>()) {
return existed_value->cast<tensor::TensorPtr>()->ValueEqual(*(to_check_value->cast<tensor::TensorPtr>()));
}
return *existed_value == *to_check_value;
};
if (equal()) {
if (IsSameValue(existed_value, to_check_value)) {
(void)manager->Replace(node, v);
return;
}
}
// Meet for the first time, append node to bucket.
bucket.emplace_back(node);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -179,8 +179,9 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const
const std::size_t offset) {
if (abs->isa<AbstractFuncAtom>()) {
return abs->cast<AbstractFuncAtomPtr>();
} else if (abs->isa<AbstractSequence>()) {
const auto &abs_seq = abs->cast<AbstractSequencePtr>();
}
if (abs->isa<AbstractSequence>()) {
auto abs_seq = abs->cast_ptr<AbstractSequence>();
MS_EXCEPTION_IF_NULL(abs_seq);
const auto &elements = abs_seq->elements();
if (offset >= index.size()) {
@ -190,12 +191,12 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const
MS_LOG(EXCEPTION) << "At offset" << offset << ", elements size of AsyncAbstract result: " << abs->ToString()
<< " is less than or equal to index: " << index[offset];
}
const auto &resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1);
auto resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1);
if (!resolved->isa<AbstractFuncAtom>()) {
MS_LOG(EXCEPTION) << "AsyncAbstract result cannot be resolved to AbstractFuncAtom, but: " << resolved->ToString();
}
MS_LOG(DEBUG) << "Return abstract: " << resolved->ToString();
return resolved->cast<AbstractFuncAtomPtr>();
return resolved;
}
MS_LOG(EXCEPTION) << "AsyncAbstract cannot resolved to AbstractFuncAtom or AbstractSeqeunce, but: "
<< abs->ToString();

View File

@ -61,7 +61,7 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
} // namespace
bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
auto new_sequence = dyn_cast<AbstractSequence>(arg);
auto new_sequence = dyn_cast_ptr<AbstractSequence>(arg);
if (new_sequence != nullptr && new_sequence->sequence_nodes() != nullptr && new_sequence->size() != 0) {
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
auto prev_result = cache_mgr.GetValue(conf);
@ -69,7 +69,7 @@ bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg)
return false;
}
auto prev_abs = prev_result->abstract();
auto old_sequence = dyn_cast<AbstractSequence>(prev_abs);
auto old_sequence = dyn_cast_ptr<AbstractSequence>(prev_abs);
if (old_sequence != nullptr &&
(old_sequence->sequence_nodes() == nullptr || old_sequence->sequence_nodes()->empty()) && *arg == *prev_abs) {
MS_LOG(DEBUG) << "Always eval";
@ -255,7 +255,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< fg->ToString() << "();";
}
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
auto func_graph_evaluator = mindspore::cast<FuncGraphEvaluator>(this);
if (func_graph_evaluator != nullptr) {
if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
engine->set_root_context(context);
@ -551,7 +551,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
AbstractBasePtrList bparams;
bparams.push_back(SensitivityTransform(primal_func_));
// Check if primal func graph has the primitive returned sparse result in its bprop().
auto real_primal_func = dyn_cast<FuncGraphAbstractClosure>(primal_func_);
auto real_primal_func = dyn_cast_ptr<FuncGraphAbstractClosure>(primal_func_);
MS_EXCEPTION_IF_NULL(real_primal_func);
FuncGraphPtr primal_func_graph = real_primal_func->func_graph();
MS_EXCEPTION_IF_NULL(primal_func_graph);
@ -625,7 +625,7 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
MS_LOG(EXCEPTION) << "The orig_abs should be AbstractTensor when axis is " << *axis << ", but got a "
<< orig_abs->ToString() << ".";
}
ShapeVector orig_shape = dyn_cast<abstract::Shape>(orig_abs->BuildShape())->shape();
ShapeVector orig_shape = dyn_cast_ptr<abstract::Shape>(orig_abs->BuildShape())->shape();
int shape_len = SizeToInt(orig_shape.size());
if (*axis < -shape_len || *axis >= shape_len) {
MS_LOG(EXCEPTION) << "The axis: " << *axis << " in 'in_axes' is out of bounds for array of dimension ["
@ -649,22 +649,21 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, const ValuePtr &in_axes, int *axis_size) {
MS_EXCEPTION_IF_NULL(physical_view_abs);
MS_EXCEPTION_IF_NULL(in_axes);
auto physical_view_abs_sequence = dyn_cast<abstract::AbstractSequence>(physical_view_abs);
auto physical_view_abs_sequence = dyn_cast_ptr<abstract::AbstractSequence>(physical_view_abs);
if (physical_view_abs_sequence != nullptr) {
AbstractBasePtrList abs_list = physical_view_abs_sequence->elements();
AbstractBasePtrList logical_view_abs_list;
auto in_axes_seq = dyn_cast<ValueSequeue>(in_axes);
auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
int index = 0;
(void)std::transform(
abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list),
[&axis_size, &index, &in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr {
ValuePtr sub_in_axes = in_axes;
if (in_axes->isa<ValueSequeue>()) {
sub_in_axes = (*in_axes_seq)[index];
index++;
}
return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
});
(void)std::transform(abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list),
[&axis_size, &index, in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr {
ValuePtr sub_in_axes = in_axes;
if (in_axes->isa<ValueSequeue>()) {
sub_in_axes = (*in_axes_seq)[index];
index++;
}
return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
});
if (physical_view_abs->isa<AbstractList>()) {
return std::make_shared<AbstractList>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
}
@ -672,7 +671,7 @@ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, cons
}
ValuePtr in_axis = in_axes;
if (in_axis->isa<Int64Imm>()) {
int axis = dyn_cast<Int64Imm>(in_axis)->value();
int axis = dyn_cast_ptr<Int64Imm>(in_axis)->value();
auto logical_view_abs = ReduceDim(&axis, physical_view_abs, axis_size);
return logical_view_abs;
}
@ -689,7 +688,7 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
AbstractBasePtr out_abs = nullptr;
ShapeVector orig_shape;
if (orig_abs->isa<abstract::AbstractTensor>()) {
orig_shape = dyn_cast<abstract::Shape>(orig_abs->BuildShape())->shape();
orig_shape = dyn_cast_ptr<abstract::Shape>(orig_abs->BuildShape())->shape();
}
int shape_len = SizeToInt(orig_shape.size() + 1);
if (*axis < -shape_len || *axis >= shape_len) {
@ -713,11 +712,11 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, const ValuePtr &out_axes, int axis_size) {
MS_EXCEPTION_IF_NULL(logical_view_abs);
auto logical_view_abs_sequence = dyn_cast<abstract::AbstractSequence>(logical_view_abs);
auto logical_view_abs_sequence = dyn_cast_ptr<abstract::AbstractSequence>(logical_view_abs);
if (logical_view_abs_sequence != nullptr) {
AbstractBasePtrList logical_view_abs_list = logical_view_abs_sequence->elements();
AbstractBasePtrList physical_view_abs_list;
auto out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
auto out_axes_seq = dyn_cast_ptr<ValueSequeue>(out_axes);
if (out_axes_seq != nullptr) {
if (logical_view_abs_list.size() != out_axes_seq->size()) {
MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': "
@ -727,7 +726,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
int index = 0;
(void)std::transform(
logical_view_abs_list.begin(), logical_view_abs_list.end(), std::back_inserter(physical_view_abs_list),
[&axis_size, &index, &out_axes_seq, out_axes](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
[&axis_size, &index, out_axes_seq, out_axes](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
ValuePtr sub_out_axes = out_axes;
if (out_axes->isa<ValueSequeue>()) {
sub_out_axes = (*out_axes_seq)[index];
@ -737,7 +736,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
return GetPhysicalViewAbs(arg_spec, sub_out_axes, axis_size);
}
if (sub_out_axes->isa<Int64Imm>()) {
int axis = dyn_cast<Int64Imm>(sub_out_axes)->value();
int axis = dyn_cast_ptr<Int64Imm>(sub_out_axes)->value();
return ExtendDim(&axis, arg_spec, axis_size);
} else if (sub_out_axes->isa<None>()) {
return arg_spec;
@ -763,7 +762,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
}
int axis = 0;
auto axis_int_ptr = dyn_cast<Int64Imm>(sub_out_axes);
auto axis_int_ptr = dyn_cast_ptr<Int64Imm>(sub_out_axes);
if (axis_int_ptr != nullptr) {
axis = LongToInt(axis_int_ptr->value());
} else {
@ -787,9 +786,9 @@ EvalResultPtr VmapEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &
int axis_size = -1;
int index = 0;
auto in_axes = in_axes_;
auto in_axes_seq = dyn_cast<ValueSequeue>(in_axes);
auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
[&axis_size, &index, &in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr {
[&axis_size, &index, in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
AbstractBasePtr abs = conf->ObtainEvalResult()->abstract();
// Drop the side effect tag parameters, because it has no mapping axis.

View File

@ -61,7 +61,7 @@ std::pair<bool, bool> InterpretAbstractBoolChecker(const AbstractBasePtr &cond)
auto value = cond->BuildValue();
if (value->isa<parse::InterpretedObject>()) {
is_interpret = true;
auto interpreted_obj = value->cast<parse::InterpretedObjectPtr>();
auto interpreted_obj = value->cast_ptr<parse::InterpretedObject>();
py::object obj = interpreted_obj->obj();
constexpr char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse";
constexpr char PYTHON_MOD_CHECK_OBJ_BOOL[] = "check_obj_bool";
@ -114,10 +114,10 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
});
// Do undetermined infer firstly.
auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
auto do_signature = prim_->cast_ptr<prim::DoSignaturePrimitive>();
MS_EXCEPTION_IF_NULL(do_signature);
auto &func = do_signature->function();
auto do_signature_func = dyn_cast<Primitive>(func);
auto do_signature_func = dyn_cast_ptr<Primitive>(func);
if (do_signature_func != nullptr) {
if (prims_to_skip_undetermined_infer.find(do_signature_func->name()) == prims_to_skip_undetermined_infer.end()) {
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
@ -168,11 +168,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
auto arg_tuple = specialize_args_before_unpack[index]->cast_ptr<AbstractTuple>();
std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
} else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
auto arg_dict = specialize_args_before_unpack[index]->cast_ptr<AbstractDictionary>();
auto dict_elems = arg_dict->elements();
(void)std::transform(
dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
@ -197,9 +197,9 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
auto unpack_graph = prim_->cast_ptr<prim::UnpackGraphPrimitive>();
MS_EXCEPTION_IF_NULL(unpack_graph);
auto out_node = out_conf->node()->cast<CNodePtr>();
auto out_node = out_conf->node()->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(out_node);
const auto &out_node_inputs = out_node->inputs();
if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
@ -219,11 +219,11 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
auto fn = args_spec_list[0]->cast_ptr<AbstractFunction>();
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
}
auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
auto real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph);
@ -255,14 +255,14 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtr target_node = source_node;
if (node_type->isa<AbstractTensor>()) {
auto x = node_type->cast<AbstractTensorPtr>();
auto x = node_type->cast_ptr<AbstractTensor>();
if (x->element()->BuildType()->isa<Float>()) {
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
MS_EXCEPTION_IF_NULL(cast);
target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
}
} else if (node_type->isa<AbstractTuple>()) {
auto x = node_type->cast<AbstractTuplePtr>();
auto x = node_type->cast_ptr<AbstractTuple>();
auto &items = x->elements();
std::vector<AnfNodePtr> nodes;
nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
@ -276,7 +276,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
}
target_node = func_graph->NewCNode(nodes);
} else if (node_type->isa<AbstractDictionary>()) {
auto x = node_type->cast<AbstractDictionaryPtr>();
auto x = node_type->cast_ptr<AbstractDictionary>();
auto &items = x->elements();
std::vector<AnfNodePtr> dict_key_nodes;
std::vector<AnfNodePtr> dict_value_nodes;
@ -293,7 +293,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(dict_key_nodes)),
func_graph->NewCNode(std::move(dict_value_nodes))});
} else if (node_type->isa<AbstractKeywordArg>()) {
auto x = node_type->cast<AbstractKeywordArgPtr>();
auto x = node_type->cast_ptr<AbstractKeywordArg>();
std::string kwarg_key = x->get_key();
AnfNodePtr kwarg_value_node =
func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
@ -336,7 +336,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
if (new_node->isa<CNode>()) {
auto new_cnode = new_node->cast<CNodePtr>();
auto new_cnode = new_node->cast_ptr<CNode>();
new_cnode->CloneCNodeInfo(out_node);
}
return engine->ForwardConfig(out_conf, fn_conf);
@ -439,7 +439,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_conver
if (dyn_shape_value) {
for (size_t i = 0; i < len; i++) {
if (!res[i].contains(py::str(ATTR_SHAPE_VALUE))) {
auto const_abstract_value = arg_tuple->elements()[i]->cast<abstract::AbstractTensorPtr>();
auto const_abstract_value = arg_tuple->elements()[i]->cast_ptr<abstract::AbstractTensor>();
MS_EXCEPTION_IF_NULL(const_abstract_value);
auto const_value = const_abstract_value->BuildValue();
MS_EXCEPTION_IF_NULL(const_value);
@ -547,7 +547,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert
if (shape_value) {
for (size_t i = 0; i < len; i++) {
if (!res[i].contains(py::str(ATTR_SHAPE_VALUE))) {
auto const_abstract_value = arg_list->elements()[i]->cast<abstract::AbstractTensorPtr>();
auto const_abstract_value = arg_list->elements()[i]->cast_ptr<abstract::AbstractTensor>();
MS_EXCEPTION_IF_NULL(const_abstract_value);
auto const_value = const_abstract_value->BuildValue();
MS_EXCEPTION_IF_NULL(const_value);
@ -565,7 +565,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert
}
void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
auto arg_tensor = dyn_cast_ptr<AbstractTensor>(abs_base);
MS_EXCEPTION_IF_NULL(dic);
MS_EXCEPTION_IF_NULL(arg_tensor);
if (only_convert_value) {
@ -601,16 +601,16 @@ py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_
return py::none();
}
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto do_sig_prim = prim->cast<prim::DoSignaturePrimitivePtr>();
auto do_sig_prim = prim->cast_ptr<prim::DoSignaturePrimitive>();
auto value = do_sig_prim->function();
if (!value->isa<PrimitivePy>()) {
return py::none();
}
auto prim_py = value->cast<PrimitivePyPtr>();
auto prim_py = value->cast_ptr<PrimitivePy>();
return prim_py->GetPyObj();
}
if (prim->isa<PrimitivePy>()) {
auto prim_py = prim->cast<PrimitivePyPtr>();
auto prim_py = prim->cast_ptr<PrimitivePy>();
return prim_py->GetPyObj();
}
return py::none();
@ -621,7 +621,7 @@ bool IsCallInstance(const PartialAbstractClosurePtr &partial_abs) {
if (!fn->isa<PrimitiveAbstractClosure>()) {
return false;
}
auto abs_prim = fn->cast<PrimitiveAbstractClosurePtr>();
auto abs_prim = fn->cast_ptr<PrimitiveAbstractClosure>();
auto prim = abs_prim->prim();
if (prim->name() == prim::kPrimCallInstance->name()) {
return true;
@ -643,14 +643,14 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *
auto value = args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(value);
if (IsCallInstance(partial_abs)) {
auto value_obj = value->cast<parse::MsClassObjectPtr>();
auto value_obj = value->cast_ptr<parse::MsClassObject>();
if (value_obj != nullptr) {
(*dic)[ATTR_DTYPE] = std::make_shared<MsClassType>();
(*dic)[ATTR_VALUE] = value_obj->obj();
return;
}
}
auto value_obj = value->cast<parse::ClassTypePtr>();
auto value_obj = value->cast_ptr<parse::ClassType>();
if (value_obj != nullptr) {
(*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
(*dic)[ATTR_VALUE] = value_obj->obj();
@ -708,35 +708,35 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv
} else if (abs_base->isa<AbstractList>()) {
return AbstractListToPython(abs_base, only_convert_value);
} else if (abs_base->isa<AbstractSlice>()) {
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
auto arg_slice = dyn_cast_ptr<AbstractSlice>(abs_base);
ShapeVector shape;
dic[ATTR_SHAPE] = shape;
dic[ATTR_DTYPE] = arg_slice->BuildType();
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractRowTensor>()) {
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
auto arg = dyn_cast_ptr<AbstractRowTensor>(abs_base);
dic[ATTR_SHAPE] = arg->shape()->shape();
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractCOOTensor>()) {
auto arg = dyn_cast<AbstractCOOTensor>(abs_base);
auto arg = dyn_cast_ptr<AbstractCOOTensor>(abs_base);
AbstractBasePtrList sparse_shape = arg->shape()->elements();
ShapeVector sparse_shape_vector;
(void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
[](const AbstractBasePtr &e) -> int64_t {
ValuePtr value = e->cast<AbstractScalarPtr>()->BuildValue();
ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
return GetValue<int64_t>(value);
});
dic[ATTR_SHAPE] = sparse_shape_vector;
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractCSRTensor>()) {
auto arg = dyn_cast<AbstractCSRTensor>(abs_base);
auto arg = dyn_cast_ptr<AbstractCSRTensor>(abs_base);
AbstractBasePtrList sparse_shape = arg->shape()->elements();
ShapeVector sparse_shape_vector;
(void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
[](const AbstractBasePtr &e) -> int64_t {
ValuePtr value = e->cast<AbstractScalarPtr>()->BuildValue();
ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
return GetValue<int64_t>(value);
});
dic[ATTR_SHAPE] = sparse_shape_vector;
@ -753,7 +753,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv
} else if (abs_base->isa<AbstractFunction>()) {
ConvertAbstractFunctionToPython(abs_base, &dic);
} else if (abs_base->isa<AbstractUndetermined>()) {
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
auto arg = dyn_cast_ptr<AbstractUndetermined>(abs_base);
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = py::none();
@ -804,7 +804,7 @@ void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBa
<< "]'s attribute[output_num]:" << output_num << " not matches the infer result "
<< res_spec->ToString();
} else if (res_spec->isa<AbstractTuple>() &&
(res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
(res_spec->cast_ptr<AbstractTuple>()->size() != LongToSize(output_num))) {
MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
<< "]'s attribute[output_num]:" << output_num << " not matches the infer result "
<< res_spec->ToString();
@ -833,7 +833,7 @@ void SetShapeValue(const AbstractBasePtr &tensor, const py::object &output) {
if (!converted) {
MS_LOG(EXCEPTION) << "Convert shape max value data failed";
}
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
auto abs_tensor = dyn_cast_ptr<abstract::AbstractTensor>(tensor);
abs_tensor->set_value_range(min_value, max_value);
}
if (!output.contains(py::str(ATTR_SHAPE_VALUE))) {
@ -849,7 +849,7 @@ void SetShapeValue(const AbstractBasePtr &tensor, const py::object &output) {
if (!converted) {
MS_LOG(EXCEPTION) << "Convert shape value data failed";
}
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
auto abs_tensor = dyn_cast_ptr<abstract::AbstractTensor>(tensor);
abs_tensor->set_shape_value(shape_value);
}
@ -1023,7 +1023,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &,
MS_EXCEPTION_IF_NULL(res_spec);
if (res_spec->isa<AbstractTensor>()) {
// Replace to tensor constant node in specialize
auto res_tensor = res_spec->cast<AbstractTensorPtr>();
auto res_tensor = res_spec->cast_ptr<AbstractTensor>();
res_tensor->set_value(converted_ret);
}
return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
@ -1248,8 +1248,8 @@ EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Ab
// Check if all arguments are scalar type.
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<AbstractScalar>()) {
auto arg_scalar = dyn_cast<AbstractScalar>(arg);
auto arg_value = arg_scalar->GetValueTrack();
auto arg_scalar = dyn_cast_ptr<AbstractScalar>(arg);
const auto &arg_value = arg_scalar->GetValueTrack();
value_list.push_back(arg_value);
} else {
// Raise TypeError Expected Scalar.
@ -1335,18 +1335,18 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
// Create new cnode
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abstract);
auto func_graph_func = dyn_cast_ptr<abstract::FuncGraphAbstractClosure>(abstract);
if (func_graph_func != nullptr) {
FuncGraphPtr fg = func_graph_func->func_graph();
input.push_back(NewValueNode(fg));
} else {
auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abstract);
auto prim_func = dyn_cast_ptr<abstract::PrimitiveAbstractClosure>(abstract);
MS_EXCEPTION_IF_NULL(prim_func);
PrimitivePtr prim = prim_func->prim();
input.push_back(NewValueNode(prim));
}
AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
auto conf = dyn_cast_ptr<abstract::AnfNodeConfig>(data_conf);
MS_EXCEPTION_IF_NULL(conf);
input.push_back(conf->node());
MS_EXCEPTION_IF_NULL(old_conf);
@ -1367,15 +1367,15 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
MS_EXCEPTION_IF_NULL(attr_value);
ValuePtr item_value = attr_value;
if (item_value->isa<StringImm>()) {
item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
item_value = std::make_shared<parse::Symbol>(item_value->cast_ptr<StringImm>()->value());
}
if (!item_value->isa<parse::Symbol>()) {
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
}
// item_name to func addr from obj_map
parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
parse::NameSpacePtr name_space = data_value->cast<parse::NameSpacePtr>();
auto symbol = item_value->cast<parse::SymbolPtr>();
auto name_space = data_value->cast<parse::NameSpacePtr>();
MS_EXCEPTION_IF_NULL(out_conf);
auto out_node = out_conf->node();
FuncGraphPtr func_graph = out_node->func_graph();
@ -1429,12 +1429,12 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &,
if (!item_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
}
std::string item_name = item_value->cast<StringImmPtr>()->value();
const auto &item_name = item_value->cast_ptr<StringImm>()->value();
// Get ms_class object.
if (!data_value->isa<parse::MsClassObject>()) {
MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString();
}
auto ms_class = data_value->cast<parse::MsClassObjectPtr>();
auto ms_class = data_value->cast_ptr<parse::MsClassObject>();
MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << ".";
// Get the attr/method of ms_class object.
@ -1461,7 +1461,7 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AnalysisEnginePtr &engi
if (python_obj == nullptr) {
return nullptr;
}
auto wrapper_obj = dyn_cast<parse::PyObjectWrapper>(python_obj);
auto wrapper_obj = dyn_cast_ptr<parse::PyObjectWrapper>(python_obj);
MS_EXCEPTION_IF_NULL(wrapper_obj);
py::object real_python_obj = wrapper_obj->obj();
MS_LOG(DEBUG) << "item_value: " << item_value->ToString() << ", func_value: " << func_value->ToString()
@ -1485,7 +1485,7 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
if (!item_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
}
std::string item_name = item_value->cast<StringImmPtr>()->value();
std::string item_name = item_value->cast_ptr<StringImm>()->value();
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
if (require.empty()) {
@ -1520,7 +1520,7 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
if (!abs->isa<abstract::PartialAbstractClosure>()) {
return nullptr;
}
auto partial_abs = abs->cast<abstract::PartialAbstractClosurePtr>();
auto partial_abs = abs->cast_ptr<abstract::PartialAbstractClosure>();
auto fn = partial_abs->fn();
if (!fn->isa<abstract::PrimitiveAbstractClosure>()) {
return nullptr;
@ -1568,7 +1568,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
}
auto data_func_graph = dyn_cast<FuncGraphAbstractClosure>(data_args);
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
if (data_func_graph != nullptr) {
auto res =
GetEvaluatedValueForCellAttrOrMethod(engine, item_value, data_func_graph->func_graph(), data_conf, out_conf);
@ -1642,7 +1642,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
if (args_conf_list.size() != 1) {
MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
}
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
MS_EXCEPTION_IF_NULL(node_conf);
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
@ -1662,7 +1662,7 @@ static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager,
const FuncGraphPtr &root_g = root_g_set.back();
for (auto &param_node : root_g->parameters()) {
auto param = param_node->cast<ParameterPtr>();
if (param && name == param->name()) {
if (param != nullptr && param->name() == name) {
return param;
}
}
@ -1680,7 +1680,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
return nullptr;
}
static TypePtr type = std::make_shared<SymbolicKeyType>();
auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
if (node_conf == nullptr) {
MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
return nullptr;
@ -1688,7 +1688,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
MS_EXCEPTION_IF_NULL(abs);
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
auto ref_abs = abs->cast_ptr<AbstractRefTensor>();
if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
return nullptr;
@ -1701,11 +1701,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
// Only after that funcgrpah is inlined, the RefToEmbed CNode should be evaluated to specific SymbolicKey.
bool ifEmbedIsWeight = false;
if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
auto param = node_conf->node()->cast<ParameterPtr>();
auto param = node_conf->node()->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(param);
ifEmbedIsWeight = param->has_default();
}
auto refkey = ref_abs->ref_key_value()->cast<StringImmPtr>();
auto refkey = ref_abs->ref_key_value()->cast_ptr<StringImm>();
if (refkey == nullptr || !ifEmbedIsWeight) {
auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();
@ -1788,12 +1788,12 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
bool IsContainUndetermined(const AbstractBasePtr &arg) {
if (arg->isa<AbstractSequence>()) {
auto seq_arg = arg->cast<AbstractSequencePtr>();
auto seq_arg = arg->cast_ptr<AbstractSequence>();
return std::any_of(seq_arg->elements().begin(), seq_arg->elements().end(), IsContainUndetermined);
}
if (arg->isa<AbstractKeywordArg>()) {
auto kw_arg = arg->cast<AbstractKeywordArgPtr>();
auto kw_arg = arg->cast_ptr<AbstractKeywordArg>();
return IsContainUndetermined(kw_arg->get_arg());
}
@ -1824,7 +1824,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
ValuePtr value_track = arg_class_type->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
parse::PyObjectWrapperPtr type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
auto type_obj = dyn_cast_ptr<parse::PyObjectWrapper>(value_track);
if (type_obj == nullptr) {
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
}
@ -1911,7 +1911,7 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator {
}
ValuePtr value_track = arg_cls->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
parse::MsClassObjectPtr ms_class = dyn_cast<parse::MsClassObject>(value_track);
auto ms_class = dyn_cast_ptr<parse::MsClassObject>(value_track);
if (ms_class == nullptr) {
MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject.";
}
@ -1933,7 +1933,7 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator {
// Replace net with net.__call__
AnfNodePtr old_node = out_conf->node();
MS_EXCEPTION_IF_NULL(old_node);
CNodePtr old_cnode = dyn_cast<CNode>(old_node);
auto old_cnode = dyn_cast_ptr<CNode>(old_node);
MS_EXCEPTION_IF_NULL(old_cnode);
std::vector<AnfNodePtr> inputs = {NewValueNode(call_func_graph)};
for (size_t i = 1; i < old_cnode->size(); i++) {
@ -1966,7 +1966,7 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
std::shared_ptr<parse::Script> script_obj = dyn_cast<parse::Script>(value_track);
auto script_obj = dyn_cast_ptr<parse::Script>(value_track);
if (script_obj == nullptr) {
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
}
@ -2027,12 +2027,12 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
const auto &element_name = element.first;
const auto &element_abs = element.second;
if (element_abs->isa<abstract::FuncGraphAbstractClosure>()) {
const auto &element_abs_fn = element_abs->cast<abstract::FuncGraphAbstractClosurePtr>();
const auto &fg = element_abs_fn->func_graph();
auto element_abs_fn = element_abs->cast_ptr<abstract::FuncGraphAbstractClosure>();
auto fg = element_abs_fn->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto wrapper_obj = fg->python_obj();
if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
auto fn_py_obj = wrapper_obj->cast<parse::PyObjectWrapperPtr>()->obj();
auto fn_py_obj = wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj();
(*global_params_dict)[py::str(element_name)] = fn_py_obj;
MS_LOG(DEBUG) << "Found global python function object for " << element_name << ", add it to global dict.";
}
@ -2138,10 +2138,10 @@ class PartialEvaluator : public Evaluator {
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
// Sometimes, node[0] in out_conf becomes phi0;
if (func->isa<PrimitiveAbstractClosure>()) {
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func);
MS_EXCEPTION_IF_NULL(prim_func->prim());
if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
auto do_signature_prim = dyn_cast_ptr<prim::DoSignaturePrimitive>(prim_func->prim());
return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
}
}
@ -2179,7 +2179,7 @@ class PartialEvaluator : public Evaluator {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto cnode = out_conf->node()->cast<CNodePtr>();
auto cnode = out_conf->node()->cast_ptr<CNode>();
if (cnode == nullptr) {
MS_LOG(EXCEPTION) << "Cnode is nullptr";
}
@ -2234,7 +2234,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
std::string exception_string;
// Processed in units of nodes. Raise ValueError(xxxx)
size_t index_begin = 2;
auto cnode = node->cast<CNodePtr>();
auto cnode = node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs();
bool need_out_symbol = inputs.size() > 3;
@ -2268,17 +2268,17 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
bool CheckNeedSymbol(const AnfNodePtr &, const AbstractBasePtr &abs) const {
bool need_symbol = false;
if (abs->isa<abstract::AbstractScalar>()) {
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
auto scalar_value = scalar->BuildValue();
if (scalar_value->isa<StringImm>()) {
need_symbol = true;
}
} else if (abs->isa<abstract::AbstractSequence>()) {
auto abs_list = abs->cast<abstract::AbstractSequencePtr>();
auto abs_list = abs->cast_ptr<abstract::AbstractSequence>();
const auto &elements = abs_list->elements();
for (auto &element : elements) {
if (element->isa<abstract::AbstractScalar>()) {
auto scalar = element->cast<abstract::AbstractScalarPtr>();
auto scalar = element->cast_ptr<abstract::AbstractScalar>();
auto scalar_value = scalar->BuildValue();
if (scalar_value->isa<StringImm>()) {
need_symbol = true;
@ -2310,7 +2310,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
std::string GetTupleString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
std::string exception_str;
// Process raise ValueError("str")
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
auto arg_tuple = arg->cast_ptr<abstract::AbstractTuple>();
const auto &arg_tuple_elements = arg_tuple->elements();
if (arg_tuple_elements.size() == 0) {
MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty.";
@ -2334,7 +2334,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
std::string GetListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
std::string exception_str;
// Process raise ValueError("str")
auto arg_list = arg->cast<abstract::AbstractListPtr>();
auto arg_list = arg->cast_ptr<abstract::AbstractList>();
const auto &arg_list_elements = arg_list->elements();
if (arg_list_elements.size() == 0) {
MS_LOG(EXCEPTION) << "The arg_list_elements can't be empty.";
@ -2358,7 +2358,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
std::string GetExceptionType(const AbstractBasePtr &abs) const {
std::string str;
if (abs->isa<abstract::AbstractScalar>()) {
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
auto scalar_value = scalar->BuildValue();
if (scalar_value->isa<StringImm>()) {
str = GetValue<std::string>(scalar_value);
@ -2491,8 +2491,8 @@ namespace {
bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(model);
auto x_tuple = dyn_cast<AbstractTuple>(x);
auto model_tuple = dyn_cast<Tuple>(model);
auto x_tuple = dyn_cast_ptr<AbstractTuple>(x);
auto model_tuple = dyn_cast_ptr<Tuple>(model);
if (x_tuple == nullptr || model_tuple == nullptr) {
return false;
@ -2518,8 +2518,8 @@ bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(model);
auto x_tensor = dyn_cast<AbstractTensor>(x);
auto model_tensor = dyn_cast<TensorType>(model);
auto x_tensor = dyn_cast_ptr<AbstractTensor>(x);
auto model_tensor = dyn_cast_ptr<TensorType>(model);
if (x_tensor == nullptr || model_tensor == nullptr) {
return false;
@ -2535,8 +2535,8 @@ bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(model);
auto x_list = dyn_cast<AbstractList>(x);
auto model_list = dyn_cast<List>(model);
auto x_list = dyn_cast_ptr<AbstractList>(x);
auto model_list = dyn_cast_ptr<List>(model);
if (x_list == nullptr || model_list == nullptr) {
return false;
@ -2563,10 +2563,10 @@ bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(model);
if (dyn_cast<AbstractScalar>(x) == nullptr) {
if (dyn_cast_ptr<AbstractScalar>(x) == nullptr) {
return false;
}
TypePtr x_type = x->GetTypeTrack();
auto &x_type = x->GetTypeTrack();
return IsSubType(x_type, model);
}
} // namespace

View File

@ -63,7 +63,7 @@ bool CanSpecializeValueNode(const AnfNodePtr &node) {
}
if (IsValueNode<FuncGraph>(node)) {
if (node->abstract() != nullptr) {
auto abs_func = node->abstract()->cast<FuncGraphAbstractClosurePtr>();
auto abs_func = node->abstract()->cast_ptr<FuncGraphAbstractClosure>();
// If this funcgraph had specialized in ProcessCNode of FirstPass,
// then ignore it.
if (abs_func != nullptr && abs_func->specialized()) {
@ -110,12 +110,12 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
}
// Handle MakeTuple/MakeList CNode.
auto cnode = dyn_cast<CNode>(node);
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode != nullptr) {
if (pos + 1 >= cnode->inputs().size()) {
continue;
}
auto input_value = GetValueNode<StringImmPtr>(cnode->input(pos + 1));
auto input_value = GetValuePtr<StringImm>(cnode->input(pos + 1));
if (input_value == nullptr || input_value->value() != kDeadNodeName) {
continue;
}
@ -129,7 +129,7 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
// Change the abstract.
(*flags)[pos] = false; // Change the use flag as 0.
auto sequence_abs = dyn_cast<AbstractSequence>(node->abstract());
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
<< ", node: " << node->DebugString(recursive_level);
@ -138,13 +138,13 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
}
// Handle ValueTuple/ValueList.
if (IsValueNode<ValueTuple>(node) || IsValueNode<ValueList>(node)) {
auto sequence_value = GetValueNode<ValueSequencePtr>(node);
auto sequence_value = GetValuePtr<ValueSequence>(node);
MS_EXCEPTION_IF_NULL(sequence_value);
if (pos >= sequence_value->value().size()) {
continue;
}
ValuePtr element_value = sequence_value->value()[pos];
auto element_str_value = element_value->cast<StringImmPtr>();
auto element_str_value = element_value->cast_ptr<StringImm>();
if (element_str_value == nullptr || element_str_value->value() != kDeadNodeName) {
continue;
}
@ -157,7 +157,7 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
// Change the abstract.
(*flags)[pos] = false; // Change the use flag as 0.
auto sequence_abs = dyn_cast<AbstractSequence>(node->abstract());
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
constexpr int recursive_level = 2;
MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
@ -250,7 +250,7 @@ AbstractBasePtr ProgramSpecializer::SpecializeAbstractFuncRecursively(const Abst
auto build_new_abs = [this, &func_atoms](const AbstractFuncAtomPtr &poss) {
auto resolved_atom = poss;
if (poss->isa<AsyncAbstractFuncAtom>()) {
const auto &async_abs_func = poss->cast<AsyncAbstractFuncAtomPtr>();
auto async_abs_func = poss->cast_ptr<AsyncAbstractFuncAtom>();
const auto &resolved_func = async_abs_func->GetUnique();
resolved_atom = resolved_func->cast<AbstractFuncAtomPtr>();
MS_EXCEPTION_IF_NULL(resolved_atom);
@ -306,7 +306,7 @@ void ProgramSpecializer::SpecializeCNodeInput0FuncGraph() {
if (!node->isa<CNode>()) {
continue;
}
auto &input0 = node->cast<CNodePtr>()->input(0);
auto &input0 = node->cast_ptr<CNode>()->input(0);
MS_EXCEPTION_IF_NULL(input0);
if (IsValueNode<FuncGraph>(input0)) {
continue;
@ -383,7 +383,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(node);
auto c_node = node->cast<CNodePtr>();
auto c_node = node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(c_node);
auto inputs = c_node->inputs();
std::vector<AnfNodePtr> new_inputs;
@ -403,7 +403,7 @@ void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const An
return new_inp;
});
auto c_new_node = new_node->cast<CNodePtr>();
auto c_new_node = new_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(c_new_node);
c_new_node->set_inputs(new_inputs);
}
@ -534,7 +534,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
if (new_node == old_node) {
return;
}
AbstractSequencePtr old_sequence_abs = dyn_cast<AbstractSequence>(old_abs);
auto old_sequence_abs = dyn_cast_ptr<AbstractSequence>(old_abs);
if (old_sequence_abs == nullptr || old_sequence_abs->sequence_nodes() == nullptr ||
old_sequence_abs->sequence_nodes()->empty()) {
MS_LOG(DEBUG) << "No sequence node in old abs, " << old_node->DebugString() << " --> " << new_node->DebugString();
@ -590,7 +590,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
}
MS_LOG(ERROR) << "New abstract, " << old_node->DebugString() << " --> " << new_node->DebugString()
<< ", elements_use_flags: " << (*flags);
AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_abs);
auto new_sequence_abs = dyn_cast_ptr<AbstractSequence>(new_abs);
if (new_sequence_abs == nullptr) {
MS_LOG(EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString();
}
@ -608,7 +608,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
template <typename T>
void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer *const specializer) {
const auto &old_input = cnode->input(index);
auto sequence_value = GetValueNode<std::shared_ptr<T>>(old_input);
auto sequence_value = GetValuePtr<T>(old_input);
if (sequence_value == nullptr) {
return;
}
@ -625,7 +625,7 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial
}
for (size_t i = 0; i < sequence_value_size; ++i) {
ValuePtr old_sequence_value = sequence_value->value()[i];
auto old_sequence_str_value = old_sequence_value->cast<StringImmPtr>();
auto old_sequence_str_value = old_sequence_value->cast_ptr<StringImm>();
if (!(*flags)[i]) {
auto zero = MakeValue(0);
(void)elements.emplace_back(zero);
@ -643,7 +643,7 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial
auto new_sequence_value = std::make_shared<T>(elements);
auto new_input = NewValueNode(new_sequence_value);
auto new_input_abs = new_sequence_value->ToAbstract();
AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_input_abs);
auto new_sequence_abs = dyn_cast<AbstractSequence>(new_input_abs);
MS_EXCEPTION_IF_NULL(new_sequence_abs);
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
(void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_input));
@ -704,7 +704,7 @@ void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) co
(void)inputs.emplace_back(cnode->input(0));
for (size_t i = 0; i < (*flags).size(); ++i) {
auto old_input = cnode->input(i + 1);
auto old_input_value = GetValueNode<StringImmPtr>(old_input);
auto old_input_value = GetValuePtr<StringImm>(old_input);
if (!(*flags)[i]) {
auto zero_value = NewValueNode(MakeValue(0));
zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
@ -754,7 +754,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "Fail to get abstract value with " << conf->ToString() << ", for " << new_node->DebugString();
}
if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
auto partial_abstract = dyn_cast_ptr<PartialAbstractClosure>(new_node->abstract());
if (partial_abstract->node() == node) {
partial_abstract->set_node(new_node);
}
@ -768,8 +768,8 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
}
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
auto attrs = conf->ObtainEvalResult()->attribute();
auto c_old = node->cast<CNodePtr>();
auto c_new = new_node->cast<CNodePtr>();
auto c_old = node->cast_ptr<CNode>();
auto c_new = new_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(c_new);
auto new_inputs = c_new->inputs();
auto old_inputs = c_old->inputs();
@ -786,7 +786,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
bool ignore_build_value = false;
AnfNodePtr replace_node = nullptr;
if (specializer_->engine()->check_isolated_side_effect()) {
auto cnode_input = dyn_cast<CNode>(node_input);
auto cnode_input = dyn_cast_ptr<CNode>(node_input);
ignore_build_value = (cnode_input != nullptr && cnode_input->has_isolated_side_effect_node());
if (ignore_build_value) {
MS_LOG(INFO) << "Don't build value node for CNode which contains isolated side-effect inputs, node: "
@ -878,7 +878,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con
const AbstractBasePtr &abs, const AbstractBasePtrList &argvals) {
MS_EXCEPTION_IF_NULL(abs);
MS_EXCEPTION_IF_NULL(func);
AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
auto real_a = dyn_cast_ptr<AbstractFunction>(abs);
MS_EXCEPTION_IF_NULL(real_a);
AbstractFunctionPtr func_abs = real_a->GetUnique();
@ -905,7 +905,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con
// Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
MS_EXCEPTION_IF_NULL(func_abs);
if (func_abs->isa<MetaFuncGraphAbstractClosure>()) {
auto specialized_fg = GetValueNode<FuncGraphPtr>(repl);
auto specialized_fg = GetValuePtr<FuncGraph>(repl);
if (specialized_fg != nullptr && (argvals.size() > 1) && argvals.back() != nullptr &&
argvals.back()->isa<AbstractUMonad>()) {
specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
@ -924,7 +924,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode
MS_EXCEPTION_IF_NULL(errcode);
*errcode = kSpecializeSuccess;
auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func_abs);
auto real_func = dyn_cast_ptr<TypedPrimitiveAbstractClosure>(func_abs);
if (real_func != nullptr) {
return BuildValueNode(real_func->prim(), abs);
}
@ -942,7 +942,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode
argvals = result.first;
AbstractBasePtr unique_output = result.second;
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func_abs);
auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func_abs);
if (prim_func != nullptr) {
auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
return BuildValueNode(prim_func->prim(), type_func);
@ -993,7 +993,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &c
AbstractBasePtrList args;
auto backed_fnval = fnval;
if (fnval->isa<PartialAbstractClosure>()) {
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
auto partial_closure = dyn_cast_ptr<PartialAbstractClosure>(fnval);
backed_fnval = partial_closure->fn();
args = partial_closure->args();
}
@ -1133,7 +1133,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
// CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
const size_t arg_start_index = 2;
while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
std::vector<AnfNodePtr> inputs = func->cast_ptr<CNode>()->inputs();
// First element is partial, second is func so arg is start from 2
(void)args.insert(args.cbegin(), inputs.cbegin() + SizeToInt(arg_start_index), inputs.cend());
func = inputs[1];
@ -1222,10 +1222,10 @@ bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argv
return true;
}
if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
auto meta_func_graph_wrapper = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
auto do_signature = dyn_cast_ptr<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
return true;
@ -1331,7 +1331,7 @@ AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNode
const AbstractFunctionPtr &abs) {
ValuePtr value = nullptr;
if (abs->isa<PrimitiveAbstractClosure>()) {
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
auto real_fn = dyn_cast_ptr<PrimitiveAbstractClosure>(abs);
// For primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
if (attrs != nullptr) {
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
@ -1339,20 +1339,20 @@ AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNode
value = real_fn->prim();
}
} else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
auto real_fn = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(abs);
value = real_fn->meta_func_graph();
} else if (abs->isa<FuncGraphAbstractClosure>()) {
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
auto real_fn = dyn_cast_ptr<FuncGraphAbstractClosure>(abs);
value = real_fn->func_graph();
} else {
return nullptr;
}
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
(IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
if (!value->isa<FuncGraph>() || value->cast_ptr<FuncGraph>()->parent() == nullptr ||
(IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast_ptr<FuncGraph>()->parent()))) {
return BuildValueNode(value, ival);
} else if (IsPrimitiveCNode(cnode, prim::kPrimJ) && origin_node->isa<Parameter>() &&
!value->cast<FuncGraphPtr>()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
!value->cast_ptr<FuncGraph>()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
// Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph.
MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString();
return BuildValueNode(value, ival);

View File

@ -43,7 +43,7 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr
MS_EXCEPTION_IF_NULL(graph_func);
MS_EXCEPTION_IF_NULL(fg_evaluator);
AnalysisContextPtr parent_context = nullptr;
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
auto func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(graph_func);
if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
parent_context = func_graph_abs->context();
} else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure.
@ -164,7 +164,7 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
EvalResultPtr node_eval_result = nullptr;
const auto &fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator());
auto fg_evaluator = dyn_cast_ptr<BaseFuncGraphEvaluator>(evaluator());
if (fg_evaluator == nullptr) {
MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator()->ToString();
}
@ -194,7 +194,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
// Check if child func graph contains isolated side-effect.
if (engine->check_isolated_side_effect()) {
if (last_stack_frame->func_graph()->has_isolated_side_effect_node()) {
auto cnode = dyn_cast<CNode>(CurrentNode());
auto cnode = dyn_cast_ptr<CNode>(CurrentNode());
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_has_isolated_side_effect_node(true);
cnode->func_graph()->set_has_isolated_side_effect_node(true);
@ -205,7 +205,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
auto evaluator = last_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
evaluator->evaluator_cache_mgr()->SetValue(last_stack_frame->args_abs_list(), result);
const auto &fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);
auto fg_evaluator = dyn_cast_ptr<BaseFuncGraphEvaluator>(evaluator);
if (fg_evaluator == nullptr) {
MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator->ToString();
}

View File

@ -315,7 +315,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
}
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
auto func = dyn_cast_ptr<AbstractFunction>(possible_func);
if (func == nullptr) {
CheckInterpretedObject(possible_func);
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
@ -357,11 +357,11 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
// Check if func graph contains isolated side-effect, and sync.
if (check_isolated_side_effect()) {
FuncGraphAbstractClosurePtr func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(func);
auto func_graph_abs = mindspore::cast<FuncGraphAbstractClosure>(func);
if (func_graph_abs != nullptr) {
contains_isolated_side_effect |= func_graph_abs->func_graph()->has_isolated_side_effect_node();
}
MetaFuncGraphAbstractClosurePtr meta_func_graph_abs = dyn_cast<MetaFuncGraphAbstractClosure>(func);
auto meta_func_graph_abs = mindspore::cast<MetaFuncGraphAbstractClosure>(func);
if (meta_func_graph_abs != nullptr) {
contains_isolated_side_effect |= meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node();
}
@ -652,7 +652,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c
// again, so update the config_map with new_conf;
anfnode_config_map_[orig_conf] = new_conf;
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->ToString() << ", to new_conf: " << new_conf->ToString();
auto old_cnode = orig_conf->node()->cast<CNodePtr>();
auto old_cnode = orig_conf->node()->cast_ptr<CNode>();
auto new_cnode = new_conf->node()->cast<CNodePtr>();
if (old_cnode != nullptr && new_cnode != nullptr) {
if (old_cnode->func_graph() == new_cnode->func_graph()) {
@ -757,20 +757,20 @@ std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBa
<< spec->ToString() << ", and that of the previous branch is " << last_spec->ToString() << ".\n"
<< "The node is " << node->DebugString(recursive_level);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>()->input(0);
auto cnode = node->cast_ptr<CNode>()->input(0);
if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
// {prim::kPrimSwitch, cond, true_branch, false_branch}
constexpr int true_index = 2;
constexpr int false_index = 3;
auto inputs = cnode->cast<CNodePtr>()->inputs();
const auto &inputs = cnode->cast_ptr<CNode>()->inputs();
buffer << ", true branch: " << inputs.at(true_index)->ToString()
<< ", false branch: " << inputs.at(false_index)->ToString();
} else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
// {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}}
constexpr int branch_index = 2;
auto tuple_node = cnode->cast<CNodePtr>()->input(branch_index);
const auto &tuple_node = cnode->cast_ptr<CNode>()->input(branch_index);
if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
auto tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
const auto &tuple_inputs = tuple_node->cast_ptr<CNode>()->inputs();
for (size_t i = 1; i < tuple_inputs.size(); i++) {
buffer << ", branch" << i << ": " << tuple_inputs.at(i);
}
@ -827,7 +827,7 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
return true;
}
if (abstract->isa<AbstractSequence>()) {
auto elements = abstract->cast<AbstractSequencePtr>()->elements();
auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
return true;
@ -888,7 +888,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
const std::vector<std::size_t> &index) {
const auto sequence_abs = dyn_cast<AbstractSequence>(orig_abs);
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
if (sequence_abs != nullptr) {
const auto &orig_elements = sequence_abs->elements();
AbstractBasePtrList new_elements;
@ -1153,7 +1153,7 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
if (enable_eliminate_unused_element && value->isa<ValueSequence>()) {
auto abs = value->ToAbstract();
auto sequence_abs = dyn_cast<AbstractSequence>(abs);
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(abs);
MS_EXCEPTION_IF_NULL(sequence_abs);
if (anf_node != nullptr) {
SetSequenceNodeElementsUseFlags(anf_node, std::make_shared<std::vector<bool>>(sequence_abs->elements().size()));
@ -1179,12 +1179,12 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi
if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
}
auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
auto trivial_evaluator = dyn_cast_ptr<TrivialPrimEvaluator>(evaluator);
if (trivial_evaluator != nullptr) {
return trivial_evaluator->EvalPrim(nullptr, arg_specs);
}
// Support MakeTuple/MakeList ops in PyNative mode.
auto transition_evaluator = dyn_cast<TransitionPrimEvaluator>(evaluator);
auto transition_evaluator = dyn_cast_ptr<TransitionPrimEvaluator>(evaluator);
if (transition_evaluator != nullptr &&
(transition_evaluator->isa<MakeTupleEvaluator>() || transition_evaluator->isa<MakeListEvaluator>())) {
return transition_evaluator->EvalPrim(nullptr, arg_specs, nullptr, nullptr);

View File

@ -146,15 +146,15 @@ void ValidateValueNode(const AnfNodePtr &node) {
}
void CheckValueTuple(const AnfNodePtr &node) {
const auto &value_node = node->cast<ValueNodePtr>();
auto value_node = node->cast_ptr<ValueNode>();
MS_EXCEPTION_IF_NULL(value_node);
const auto value = value_node->value();
const auto &value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
const auto value_tuple = value->cast<ValueTuplePtr>();
auto value_tuple = value->cast_ptr<ValueTuple>();
MS_EXCEPTION_IF_NULL(value_tuple);
const auto tuple_values = value_tuple->value();
const auto &tuple_values = value_tuple->value();
for (size_t i = 0; i < tuple_values.size(); ++i) {
const auto input_node = NewValueNode(tuple_values[i]);
auto input_node = NewValueNode(tuple_values[i]);
ValidateOperation(input_node);
ValidateValueNode(input_node);
}
@ -167,7 +167,7 @@ void Validate(const FuncGraphPtr &func_graph) {
for (auto node : all_nodes) {
TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
while (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
node = node->cast<CNodePtr>()->input(1);
node = node->cast_ptr<CNode>()->input(1);
}
if (IsValueNode<ValueTuple>(node)) {
CheckValueTuple(node);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -41,7 +41,7 @@ AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
}
return std::make_shared<AbstractFuncUnion>(this_func, other);
}
auto other_union = dyn_cast<AbstractFuncUnion>(other);
auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
MS_EXCEPTION_IF_NULL(other_union);
if (other_union->IsSuperSet(this_func)) {
return other;
@ -122,7 +122,7 @@ AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
}
return std::make_shared<AbstractFuncUnion>(this_func, other);
}
auto other_union = dyn_cast<AbstractFuncUnion>(other);
auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
MS_EXCEPTION_IF_NULL(other_union);
if (other_union->IsSuperSet(this_func)) {
return other;

View File

@ -173,14 +173,14 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
if (*this == *other) {
return shared_from_base<AbstractBase>();
}
auto type_self = GetTypeTrack();
auto type_other = other->GetTypeTrack();
const auto &type_self = GetTypeTrack();
const auto &type_other = other->GetTypeTrack();
TypePtr res_type = TypeJoin(type_self, type_other);
if (res_type == kAnyType) {
TypeJoinLogging(type_self, type_other, shared_from_base<AbstractBase>(), other);
}
auto value_self = GetValueTrack();
auto value_other = other->GetValueTrack();
const auto &value_self = GetValueTrack();
const auto &value_other = other->GetValueTrack();
ValuePtr res_value = ValueJoin(value_self, value_other);
if (res_value == value_self) {
return shared_from_base<AbstractBase>();
@ -193,7 +193,7 @@ AbstractBasePtr AbstractType::Clone() const {
if (value_self == nullptr || !value_self->isa<Type>()) {
return nullptr;
}
TypePtr type_self = value_self->cast<TypePtr>();
auto type_self = value_self->cast_ptr<Type>();
return std::make_shared<AbstractType>(type_self->Clone());
}
@ -201,7 +201,8 @@ bool AbstractType::operator==(const AbstractBase &other) const {
if (this == &other) {
return true;
}
return tid() == other.tid() && IsEqual(dyn_cast<Type>(GetValueTrack()), dyn_cast<Type>(other.GetValueTrack()));
return tid() == other.tid() &&
IsEqual(dyn_cast_ptr<Type>(GetValueTrack()), dyn_cast_ptr<Type>(other.GetValueTrack()));
}
std::string AbstractType::ToString() const {
@ -215,7 +216,7 @@ std::string AbstractType::ToString() const {
buffer << type_name() << "(Value: nullptr)";
return buffer.str();
}
TypePtr type_self = value_self->cast<TypePtr>();
auto type_self = value_self->cast_ptr<Type>();
buffer << type_name() << "("
<< "Value: " << type_self->ToString() << ")";
return buffer.str();
@ -496,7 +497,7 @@ AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &ot
if (!enable_eliminate_unused_element || this->sequence_nodes() == nullptr) {
return sequence_nodes;
}
auto other_sequence = dyn_cast<AbstractSequence>(other);
auto other_sequence = dyn_cast_ptr<AbstractSequence>(other);
if (other_sequence == nullptr) {
return sequence_nodes;
}
@ -714,7 +715,7 @@ template MS_CORE_API ValuePtr AbstractSequence::ElementsBuildValue<ValueList>()
template <typename T>
AbstractBasePtr AbstractSequence::ElementsJoin(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
auto other_sequeue = dyn_cast<T>(other);
auto other_sequeue = dyn_cast_ptr<T>(other);
if (other_sequeue == nullptr) {
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
}
@ -761,7 +762,7 @@ bool AbstractSequence::operator==(const AbstractSequence &other) const {
}
void AbstractTuple::set_shape(const BaseShapePtr &shape) {
auto tuple_shape = dyn_cast<TupleShape>(shape);
auto tuple_shape = dyn_cast_ptr<TupleShape>(shape);
MS_EXCEPTION_IF_NULL(tuple_shape);
if (tuple_shape->shape().size() != elements_.size()) {
MS_LOG(EXCEPTION) << "Size mismatch: " << tuple_shape->shape().size() << " vs " << elements_.size();
@ -789,7 +790,7 @@ bool AbstractTuple::ContainsAllBroadenTensors() const {
for (size_t i = 0; i < elements_.size(); ++i) {
if (!(elements_[i]->isa<abstract::AbstractUndetermined>() && elements_[i]->IsBroaden()) &&
!(elements_[i]->isa<abstract::AbstractTuple>() &&
elements_[i]->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors())) {
elements_[i]->cast_ptr<abstract::AbstractTuple>()->ContainsAllBroadenTensors())) {
return false;
}
}
@ -927,7 +928,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
// AbstractTensor join with AbstractUndetermined
if (other_type->type_id() == kObjectTypeUndeterminedType) {
auto other_undetermined_tensor = dyn_cast<AbstractUndetermined>(other);
auto other_undetermined_tensor = dyn_cast_ptr<AbstractUndetermined>(other);
MS_EXCEPTION_IF_NULL(other_undetermined_tensor);
// Check shape
auto res_shape = ShapeJoin(shape(), other_undetermined_tensor->shape());
@ -941,7 +942,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}
// AbstractTensor join with AbstractTensor
auto other_tensor = dyn_cast<AbstractTensor>(other);
auto other_tensor = dyn_cast_ptr<AbstractTensor>(other);
if (other_tensor == nullptr) {
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
}
@ -964,8 +965,8 @@ bool AbstractTensor::equal_to(const AbstractTensor &other) const {
return true;
}
// Check value. for AbstractTensor, both value should be AnyValue.
auto v1 = GetValueTrack();
auto v2 = other.GetValueTrack();
const auto &v1 = GetValueTrack();
const auto &v2 = other.GetValueTrack();
if (v1 != v2 && (v1 == nullptr || !v1->isa<AnyValue>() || v2 == nullptr || !v2->isa<AnyValue>())) {
return false;
}
@ -1142,7 +1143,7 @@ TypePtr AbstractJTagged::BuildType() const {
AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
auto other_jtagged = dyn_cast<AbstractJTagged>(other);
auto other_jtagged = dyn_cast_ptr<AbstractJTagged>(other);
if (other_jtagged == nullptr) {
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
}
@ -1182,8 +1183,8 @@ AbstractRefTensor::AbstractRefTensor(const AbstractTensorPtr &ref_value, const V
TypePtr AbstractRefTensor::BuildType() const {
auto type = AbstractTensor::BuildType();
MS_EXCEPTION_IF_NULL(type);
auto subtype = type->cast<TensorTypePtr>();
auto subtype = dyn_cast_ptr<TensorType>(type);
MS_EXCEPTION_IF_NULL(subtype);
return std::make_shared<RefType>(subtype);
}
@ -1430,7 +1431,7 @@ BaseShapePtrList AbstractSparseTensor::ElementsShapeTupleRecursive() const {
BaseShapePtrList element_shape_list;
for (const auto &element : elements()) {
MS_EXCEPTION_IF_NULL(element);
auto abs_tuple = element->cast<AbstractTuplePtr>();
auto abs_tuple = element->cast_ptr<AbstractTuple>();
if (abs_tuple == nullptr) {
element_shape_list.push_back(element->BuildShape());
} else {

View File

@ -120,17 +120,17 @@ class MS_CORE_API AbstractBase : public Base {
/// \brief Get the abstract value, which is tracked.
///
/// \return A pointer to the Value.
ValuePtr GetValueTrack() const { return value_; }
const ValuePtr &GetValueTrack() const { return value_; }
/// \brief Get the abstract type, which is tracked.
///
/// \return A pointer to the Type.
TypePtr GetTypeTrack() const { return type_; }
const TypePtr &GetTypeTrack() const { return type_; }
/// \brief Get the abstract shape, which is tracked.
///
/// \return A pointer to the BaseShape.
BaseShapePtr GetShapeTrack() const { return shape_; }
const BaseShapePtr &GetShapeTrack() const { return shape_; }
/// \brief Try to build a real value from an abstract value.
///

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -141,8 +141,8 @@ bool AnalysisContext::operator==(const AnalysisContext &other) const {
if (func_graph_->has_flag(GRAPH_FLAG_IS_WHILE_HEADER) &&
args_spec_list_[i]->isa<abstract::FuncGraphAbstractClosure>() &&
other.args_spec_list_[i]->isa<abstract::FuncGraphAbstractClosure>()) {
auto temp_this = args_spec_list_[i]->cast<abstract::FuncGraphAbstractClosurePtr>()->Copy();
auto temp_other = other.args_spec_list_[i]->cast<abstract::FuncGraphAbstractClosurePtr>()->Copy();
auto temp_this = args_spec_list_[i]->cast_ptr<abstract::FuncGraphAbstractClosure>()->Copy();
auto temp_other = other.args_spec_list_[i]->cast_ptr<abstract::FuncGraphAbstractClosure>()->Copy();
temp_this->set_tracking_id(nullptr);
temp_other->set_tracking_id(nullptr);
if (!(*temp_this == *temp_other)) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -41,7 +41,7 @@ ABSTRACT_REPORT_NAME_DEC(KeywordArg)
TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) {
auto ori_type = type;
if (type->isa<TensorType>()) {
auto tensor = type->cast<TensorTypePtr>();
auto tensor = type->cast_ptr<TensorType>();
type = tensor->element();
MS_EXCEPTION_IF_NULL(type);
}

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -212,7 +212,7 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac
}
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
AbstractFunctionPtr f_spec = dyn_cast<AbstractFunction>(spec);
auto f_spec = dyn_cast_ptr<AbstractFunction>(spec);
if (f_spec != nullptr) {
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
}
@ -285,7 +285,7 @@ AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
if (type->isa<TensorType>()) {
auto tensor_type = type->cast<TensorTypePtr>();
auto tensor_type = type->cast_ptr<TensorType>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
@ -319,8 +319,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
}
return MakeAbstractTensor(shape, type);
} else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
auto shape_tuple = base_shape->cast<TupleShapePtr>();
auto type_tuple = type->cast<TuplePtr>();
auto shape_tuple = base_shape->cast_ptr<TupleShape>();
auto type_tuple = type->cast_ptr<Tuple>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_tuple->size(); ++it) {
auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
@ -329,8 +329,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
return tuple;
} else if (base_shape->isa<ListShape>() && type->isa<List>()) {
auto shape_list = base_shape->cast<ListShapePtr>();
auto type_list = type->cast<ListPtr>();
auto shape_list = base_shape->cast_ptr<ListShape>();
auto type_list = type->cast_ptr<List>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_list->size(); ++it) {
auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,12 +42,12 @@ bool BaseRef::operator==(const BaseRef &other) const {
return false;
}
if (m_ptr->isa<Value>()) {
return *(m_ptr->cast<ValuePtr>()) == *(other.m_ptr->cast<ValuePtr>());
return *(m_ptr->cast_ptr<Value>()) == *(other.m_ptr->cast_ptr<Value>());
}
// for noderef equal
if (m_ptr->isa<BaseRef>()) {
return *std::static_pointer_cast<BaseRef>(m_ptr) == *std::static_pointer_cast<BaseRef>(other.m_ptr);
return *(m_ptr->cast_ptr<BaseRef>()) == *(other.m_ptr->cast_ptr<BaseRef>());
}
// for node equal

View File

@ -1173,7 +1173,11 @@ inline S GetValueNode(const AnfNodePtr &node) {
template <typename S, typename std::enable_if<std::is_base_of<Value, S>::value, S>::type * = nullptr>
inline S *GetValuePtr(const AnfNodePtr &node) {
auto value = GetValuePtr(node);
auto value_node = dyn_cast_ptr<ValueNode>(node);
if (value_node == nullptr) {
return nullptr;
}
const auto &value = value_node->value();
return (value == nullptr) ? nullptr : value->cast_ptr<S>();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -55,6 +55,11 @@ class MS_CORE_API RefType final : public TensorType {
/// \param[in] subtype Define the TensorType for RefType object to refer to.
explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {}
/// \brief Constructor for RefType.
///
/// \param[in] subtype Define the TensorType for RefType object to refer to.
explicit RefType(const TensorType *subtype) : TensorType(subtype->element()) {}
/// \brief Destructor of RefType.
~RefType() override {}
MS_DECLARE_PARENT(RefType, TensorType)

View File

@ -592,7 +592,7 @@ std::string FuncGraph::GetVariableArgName() {
const auto &param_node = GetVariableArgParameter();
MS_EXCEPTION_IF_NULL(param_node);
const auto &parameter = param_node->cast<ParameterPtr>();
auto parameter = param_node->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
@ -614,7 +614,7 @@ std::string FuncGraph::GetVariableKwargName() {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
<< ", parameters is less than 1 + fv_param_count";
}
const auto &parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast<ParameterPtr>();
auto parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
@ -666,7 +666,7 @@ int FuncGraph::GetPositionalArgsCount() const {
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
for (size_t i = 0; i < parameters_.size(); ++i) {
MS_EXCEPTION_IF_NULL(parameters_[i]);
auto param_cast = parameters_[i]->cast<ParameterPtr>();
auto param_cast = parameters_[i]->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(param_cast);
if (param_cast->name() == name) {
return parameters_[i];

View File

@ -87,7 +87,7 @@ void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
auto old_param = node->cast<ParameterPtr>();
auto old_param = node->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(old_param);
auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
auto new_param = (is_add ? target->add_parameter(std::move(debug_info))
@ -136,7 +136,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) {
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
new_const->set_has_new_value(node->cast_ptr<ValueNode>()->has_new_value());
repl_node_[node] = std::move(new_const);
}
@ -148,7 +148,7 @@ void Cloner::CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
new_const->set_has_new_value(node->cast_ptr<ValueNode>()->has_new_value());
repl_node_[node] = std::move(new_const);
}
@ -229,7 +229,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>();
auto parent = cnode.first->first->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(parent);
const auto &valuenode = parent->input(IntToSize(cnode.first->second));
CloneFuncGraphValueNode(valuenode, target_func_graph);
@ -291,7 +291,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
auto free_var_node = utils::cast<AnfNodePtr>(free_var);
// Don't lift weight parameter to top func_graph.
if (IsLiftTopFuncGraph(func_graph) && free_var_node->isa<Parameter>()) {
auto free_var_param = free_var_node->cast<ParameterPtr>();
auto free_var_param = free_var_node->cast_ptr<Parameter>();
if (free_var_param->has_default()) {
MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->DebugString()
<< " for top_func_graph: " << lift_top_func_graph->ToString();
@ -306,7 +306,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
}
MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
auto fv_parameter = AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var));
auto fv_parameter = AddParameter(func_graph, free_var_node);
fv_parameter->set_user_data<bool>("lifted_from_fv", std::make_shared<bool>(true));
auto &fg_params = repl_func_graph_params_[func_graph];
(void)fg_params.emplace_back(fv_parameter);
@ -316,7 +316,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) const {
param->set_abstract(node->abstract());
if (node->isa<Parameter>()) {
ParameterPtr old_param = node->cast<ParameterPtr>();
auto old_param = node->cast_ptr<Parameter>();
if (old_param->has_default()) {
// Default parameter can be shared since it is readonly.
param->set_default_param(old_param->default_param());
@ -728,11 +728,11 @@ void Cloner::CloneNodes() {
void Cloner::LinkEdges() {
for (auto &repl : repl_node_) {
CNodePtr old_node = dyn_cast<CNode>(repl.first);
auto old_node = dyn_cast_ptr<CNode>(repl.first);
if (old_node == nullptr) {
continue;
}
CNodePtr new_node = repl.second->cast<CNodePtr>();
auto new_node = repl.second->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(new_node);
for (auto &input : old_node->inputs()) {
auto iter = repl_node_.find(input);

View File

@ -131,7 +131,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
[param_name](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto param = node->cast<ParameterPtr>();
auto param = node->cast_ptr<Parameter>();
return param != nullptr && param->name() == param_name;
});
if (find_kw_arg_in_list) {

View File

@ -186,8 +186,8 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
return vecs;
}
if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto graph = GetValuePtr<FuncGraph>(node);
if (graph != nullptr) {
auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);
@ -209,19 +209,16 @@ std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) {
return vecs;
}
if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto graph = GetValuePtr<FuncGraph>(node);
if (graph != nullptr) {
auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);
}
return vecs;
} else {
if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
} else if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
}
std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {
@ -234,26 +231,24 @@ std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {
}
std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) {
std::vector<AnfNodePtr> vecs;
if (node == nullptr) {
return vecs;
}
auto cnode = dyn_cast<CNode>(node);
if (cnode != nullptr) {
auto &inputs = cnode->inputs();
// Check if free variables used.
for (const auto &input : inputs) {
auto input_fg = GetValueNode<FuncGraphPtr>(input);
if (input_fg) {
for (auto &fv : input_fg->free_variables_nodes()) {
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
vecs.push_back(fv);
}
if (cnode == nullptr) {
return {};
}
std::vector<AnfNodePtr> vecs;
const auto &inputs = cnode->inputs();
// Check if free variables used.
for (const auto &input : inputs) {
auto input_fg = GetValuePtr<FuncGraph>(input);
if (input_fg != nullptr) {
for (auto &fv : input_fg->free_variables_nodes()) {
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
vecs.push_back(fv);
}
}
}
FetchCNodeSuccessors(cnode, &vecs);
}
FetchCNodeSuccessors(cnode, &vecs);
return vecs;
}
@ -263,28 +258,24 @@ std::vector<AnfNodePtr> SuccWithFilter(const GraphFilterFunc &graph_filter, cons
return vecs;
}
if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto graph = GetValueNode<FuncGraphPtr>(node);
if (graph != nullptr) {
if (graph_filter != nullptr && graph_filter(graph)) {
return vecs;
}
auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);
}
return vecs;
} else {
if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
} else if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
}
const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node) {
static std::vector<AnfNodePtr> empty_inputs;
auto cnode = dyn_cast<CNode>(node);
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode != nullptr) {
return cnode->inputs();
}

View File

@ -92,8 +92,8 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher {
if (!IsValueNode<FuncGraph>(vnode)) {
return;
}
auto fg = GetValueNode<FuncGraphPtr>(vnode);
AnfNodePtr ret = fg->return_node();
auto fg = GetValuePtr<FuncGraph>(vnode);
const auto &ret = fg->return_node();
DeepFirstSearcher::Visit(ret);
}

View File

@ -661,7 +661,7 @@ void FuncGraphManager::MoveAllCNodeDropGraph(const FuncGraphPtr &source, const F
const ScopePtr &scope) {
AnfNodePtr source_return = source->get_return();
AnfNodePtr source_output = source->output();
AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
AnfNodePtr source_prim = source_return->cast_ptr<CNode>()->input(0);
int index = 0;
(void)node_users_[source_prim].erase(make_pair(source_return, index));
@ -1105,7 +1105,7 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
std::find_if(meta_fg_prim_values.begin(), meta_fg_prim_values.end(), [seen_num](const auto &iter) {
// Check g1->MetaFgPrim(fg)->g2->g cycle.
if (IsValueNode<FuncGraph>(iter.first)) {
auto func_graph = GetValueNode<FuncGraphPtr>(iter.first);
auto func_graph = GetValuePtr<FuncGraph>(iter.first);
return func_graph->seen_ != seen_num;
}
if (IsValueNode<Primitive>(iter.first)) {

View File

@ -435,7 +435,7 @@ IMM_TRAITS(StringImmPtr, const char *)
/// \brief RefKey defines a class whose real type is String.
/// \brief Notice: RefKey is keep for compatible only, we use RefKey just as StringImm.
class MS_CORE_API RefKey : public StringImm {
class MS_CORE_API RefKey final : public StringImm {
public:
/// \brief Constructor of RefKey.
///

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,8 +27,8 @@ void AnfIrVisitor::Visit(const CNodePtr &cnode) {
}
void AnfIrVisitor::Visit(const ValueNodePtr &vnode) {
if (IsValueNode<FuncGraph>(vnode)) {
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
auto func_graph = GetValuePtr<FuncGraph>(vnode);
if (func_graph != nullptr) {
Visit(func_graph->output());
}
}
@ -36,36 +36,34 @@ void AnfIrVisitor::Visit(const ValueNodePtr &vnode) {
void AnfIrVisitor::Visit(const ParameterPtr &) {}
VisitFuncType AnfIrVisitor::Match(const PrimitivePtr &prim, const std::vector<PredicateFuncType> &funcs) {
auto fn = [prim, funcs, this](const AnfNodePtr &node) {
return [prim, funcs, this](const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim)) {
return;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto &inputs = node->cast_ptr<CNode>()->inputs();
auto funcs_size = funcs.size();
auto inputs_size = inputs.size();
// check the inputs are matched with the predicate functions
// Check the inputs are matched with the predicate functions.
if (funcs_size > 0) {
// use the predicate function list to check the number of inputs
// Use the predicate function list to check the number of inputs.
if (funcs_size != (inputs_size - 1)) {
return;
}
// check per input
for (size_t i = 0; i < funcs_size; i++) {
// Check inputs.
for (size_t i = 0; i < funcs_size; ++i) {
if (!funcs[i](inputs[i + 1])) {
return;
}
}
}
// visit the inputs
for (size_t i = 1; i < inputs_size; i++) {
// Visit argument inputs.
for (size_t i = 1; i < inputs_size; ++i) {
this->Visit(inputs[i]);
}
};
return fn;
}
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -103,7 +103,7 @@ static inline bool UseMPI() {
}
template <typename T>
inline bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
bool IsEqual(const T *a, const T *b) {
if (a == b) {
return true;
}
@ -113,6 +113,11 @@ inline bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
return *a == *b;
}
template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
return IsEqual(a.get(), b.get());
}
template <typename T>
bool IsAttrsEqual(const T &a, const T &b) {
if (&a == &b) {