!39471 Optimize pointer casting for compile framework
Merge pull request !39471 from hewei/opt_perf1
commit 6685a89540
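Most hunks below apply one pattern: hot compile-time paths that only inspect a node now use the raw-pointer helpers (cast_ptr<T>(), dyn_cast_ptr<T>(), GetValuePtr<T>()) instead of the shared_ptr-returning ones (cast<TPtr>(), dyn_cast<T>(), GetValueNode<TPtr>()), so each cast no longer copies a std::shared_ptr and touches its atomic reference count; a few hunks additionally hoist lookup tables to static const or merge isa<T>()/cast<T>() pairs into one dyn_cast. A minimal sketch of why the raw-pointer cast is cheaper, using hypothetical Node/CNode stand-ins rather than MindSpore's real class hierarchy or exact template signatures:

#include <memory>

// Hypothetical stand-ins; MindSpore's real node hierarchy and the exact
// cast<T>/cast_ptr<T> signatures differ.
struct Node : std::enable_shared_from_this<Node> {
  virtual ~Node() = default;

  // cast<...>() style: constructs a new shared_ptr, so every call bumps and
  // later drops the atomic reference count.
  template <typename T>
  std::shared_ptr<T> cast() {
    return std::dynamic_pointer_cast<T>(shared_from_this());
  }

  // cast_ptr<...>() style: a plain dynamic_cast, no reference counting.
  template <typename T>
  T *cast_ptr() {
    return dynamic_cast<T *>(this);
  }
};

struct CNode : Node {
  bool stop_gradient = false;
};

// The caller only inspects the node, so a raw pointer is enough and cheaper.
bool StopsGradient(const std::shared_ptr<Node> &node) {
  auto *cnode = node->cast_ptr<CNode>();  // no shared_ptr copy here
  return cnode != nullptr && cnode->stop_gradient;
}

In the call sites changed below, the owning shared_ptr stays alive for the duration of the call, so returning a raw pointer does not change object lifetimes; it only removes the reference-count traffic.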
@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -143,7 +143,7 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode
MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
}
mindspore::HashMap<AnfNodePtr, FuncGraphPtr> node_to_fg;
auto tuple_graphs = input->cast<CNodePtr>();
auto tuple_graphs = input->cast_ptr<CNode>();
for (size_t i = 1; i < tuple_graphs->size(); ++i) {
auto graph = tuple_graphs->input(i);
if (!IsValueNode<FuncGraph>(graph)) {
@ -191,7 +191,7 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
auto output_cnode = node->cast<CNodePtr>();
auto output_cnode = node->cast_ptr<CNode>();
if (output_cnode->size() - 1 == 1) {
return output_cnode->input(1);
}
@ -217,16 +217,16 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
return make_tuple;
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
auto tuple_get_item = node->cast<CNodePtr>();
auto tuple_get_item = node->cast_ptr<CNode>();
auto inp = tuple_get_item->input(1);
if (IsPrimitiveCNode(inp, prim::kPrimHookBackward) || IsPrimitiveCNode(inp, prim::kPrimCellBackwardHook)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
constexpr size_t idx = 2;
auto v_node = tuple_get_item->input(idx)->cast<ValueNodePtr>();
auto v_node = tuple_get_item->input(idx)->cast_ptr<ValueNode>();
MS_EXCEPTION_IF_NULL(v_node);
auto out_idx = GetValue<int64_t>(v_node->value());
return inp->cast<CNodePtr>()->input(LongToSize(out_idx) + 1);
return inp->cast_ptr<CNode>()->input(LongToSize(out_idx) + 1);
}
}
return node;
@ -238,7 +238,7 @@ AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, con
if (input_type == nullptr || !input_type->isa<TensorType>()) {
return din;
}
input_type = input_type->cast<TensorTypePtr>()->element();
input_type = input_type->cast_ptr<TensorType>()->element();
MS_EXCEPTION_IF_NULL(input_type);
if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
return din;
@ -256,7 +256,7 @@ AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, con
if (din_type == nullptr || !din_type->isa<TensorType>()) {
return din;
}
din_type = din_type->cast<TensorTypePtr>()->element();
din_type = din_type->cast_ptr<TensorType>()->element();
MS_EXCEPTION_IF_NULL(din_type);
if (din_type->type_id() != kNumberTypeComplex64 && din_type->type_id() != kNumberTypeComplex128) {
return din;
@ -689,8 +689,8 @@ void DFunctor::MapValueObject() {

AdjointPtr adjoint = nullptr;
if (IsValueNode<Primitive>(node)) { // Primitive.
auto prim = GetValueNode<PrimitivePtr>(node);
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
auto prim = GetValuePtr<Primitive>(node);
if ((prim->Hash() == prim::kPrimReturn->Hash() && prim->name() == prim::kPrimReturn->name()) ||
(prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) ||
(prim->Hash() == prim::kPrimCellBackwardHook->Hash() &&
prim->name() == prim::kPrimCellBackwardHook->name())) {
@ -784,17 +784,15 @@ void DFunctor::BroadCastStopFlag() {
while (need_cut_) {
need_cut_ = false;
for (auto &node : primal_graph_->nodes()) {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!cnode->stop_gradient()) {
// Cut off the cnode only when it's not referred any more
if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
AllReferencesStopped(cnode)) {
MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
cnode->set_stop_gradient(true);
// The stop set changed, more cut required
need_cut_ = true;
}
auto cnode = dyn_cast<CNode>(node);
if (cnode != nullptr && !cnode->stop_gradient()) {
// Cut off the cnode only when it's not referred any more
if (cnode->IsApply(prim::kPrimStopGradient) || cnode->IsApply(prim::kPrimUpdateState) ||
AllReferencesStopped(cnode)) {
MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
cnode->set_stop_gradient(true);
// The stop set changed, more cut required
need_cut_ = true;
}
}
}
@ -809,7 +807,7 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
}
for (auto &kv : users) {
auto &user = kv.first;
if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
if (!user->isa<CNode>() || !user->cast_ptr<CNode>()->stop_gradient()) {
return false;
}
}
@ -225,7 +225,7 @@ AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, c
|
|||
} else {
|
||||
python_ops_value = prim::GetPythonOps(iter->second.first);
|
||||
}
|
||||
auto origin_cnode = origin_node->cast<CNodePtr>();
|
||||
auto origin_cnode = origin_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(origin_cnode);
|
||||
auto &origin_inputs = origin_cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(python_ops_value)};
|
||||
|
@ -243,7 +243,7 @@ void ReplacePythonOps(const FuncGraphPtr &fg) {
|
|||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
for (size_t i = 0; i < cnode->size(); ++i) {
|
||||
auto prim = GetCNodePrimitive(cnode->input(i));
|
||||
if (prim == nullptr) {
|
||||
|
@ -388,7 +388,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
|
|||
if (prim->is_base()) {
|
||||
fn = GetBpropFunction(prim->name());
|
||||
} else {
|
||||
fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
|
||||
fn = prim->cast_ptr<PrimitivePy>()->GetBpropFunction();
|
||||
if (py::isinstance<py::none>(fn)) {
|
||||
fn = GetBpropFunction(prim->name());
|
||||
}
|
||||
|
@ -491,8 +491,8 @@ static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &mo
|
|||
auto output_cnode = output->cast<CNodePtr>();
|
||||
if (output_cnode != nullptr) {
|
||||
// If output_cnode has the form like (make_tuple, x, y).
|
||||
while (IsPrimitiveCNode(output_cnode, prim::kPrimDepend)) {
|
||||
auto real_input = output_cnode->input(kRealInputIndexInDepend);
|
||||
while (output_cnode->IsApply(prim::kPrimDepend)) {
|
||||
const auto &real_input = output_cnode->input(kRealInputIndexInDepend);
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
output_cnode = real_input->cast<CNodePtr>();
|
||||
}
|
||||
|
@ -555,7 +555,7 @@ void SetDumpFlag(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
|
|||
return;
|
||||
}
|
||||
auto attr = prim->GetAttr(kAttrDump);
|
||||
if (attr != nullptr && attr->isa<StringImm>() && attr->cast<StringImmPtr>()->value() == kValueTrue) {
|
||||
if (attr != nullptr && attr->isa<StringImm>() && attr->cast_ptr<StringImm>()->value() == kValueTrue) {
|
||||
bprop_fg->set_flag(FUNC_GRAPH_FLAG_DUMP, true);
|
||||
}
|
||||
}
|
||||
|
@ -660,7 +660,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
|
|||
// bprop_fg has been checked in caller
|
||||
if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
|
||||
// Set bprop output as (env, dx, dy, dz, ...)
|
||||
auto cbprop = bprop_fg->output()->cast<CNodePtr>();
|
||||
auto cbprop = bprop_fg->output()->cast_ptr<CNode>();
|
||||
auto &inputs = cbprop->inputs();
|
||||
|
||||
std::vector<AnfNodePtr> args;
|
||||
|
@ -742,7 +742,7 @@ void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const Func
|
|||
const auto ¤t_primal_fg_params = current_primal_fg->parameters();
|
||||
// The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
|
||||
for (size_t i = 0; i < current_primal_fg_params.size(); ++i) {
|
||||
auto primal_parameter = dyn_cast<Parameter>(current_primal_fg_params[i]);
|
||||
auto primal_parameter = dyn_cast_ptr<Parameter>(current_primal_fg_params[i]);
|
||||
MS_EXCEPTION_IF_NULL(primal_parameter);
|
||||
auto lifted = primal_parameter->template user_data<bool>(kLiftedUserDataKey);
|
||||
if (lifted == nullptr || !*lifted) {
|
||||
|
@ -845,7 +845,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
|
|||
if (cnode == users.end()) {
|
||||
MS_LOG(EXCEPTION) << "Fail to find cnode.";
|
||||
}
|
||||
auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
|
||||
auto inputs_num = cnode->first->cast_ptr<CNode>()->size() - 1;
|
||||
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
|
@ -886,7 +886,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
|
|||
if (cnode == users.end()) {
|
||||
MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
|
||||
}
|
||||
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
|
||||
auto inputs_num = cnode->first->cast_ptr<CNode>()->inputs().size() - 1;
|
||||
auto effect_info = GetPrimEffectInfo(prim);
|
||||
// Don't add U or IO monad parameters as it will be added later.
|
||||
size_t monad_params_size = 0;
@ -33,22 +33,25 @@ namespace mindspore {
/* namespace to support opt */
namespace opt {
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
static const std::map<std::string, std::vector<std::string>> op2attrs = {
{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};

auto todos = TopoSort(graph->get_return());
for (const auto &node : todos) {
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto primitive = GetCNodePrimitive(node);
if (!primitive || dyn_cast<PrimitivePy>(primitive)) {
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
continue;
}
parallel::OperatorAttrs attrs;
std::map<std::string, std::vector<std::string>> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
if (op2attrs.count(primitive->name()) != 0) {
for (auto &attr : op2attrs[primitive->name()]) {
auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) {
for (auto &attr : iter->second) {
if (primitive->HasAttr(attr)) {
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
} else {
@ -57,10 +60,10 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
}
}
}
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "")->cast<PrimitivePtr>();
(void)new_prim->SetAttrs(primitive->attrs());
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
AnfNodePtrList inputs = {NewValueNode(new_prim)};
auto cnode = dyn_cast<CNode>(node);
auto cnode = dyn_cast_ptr<CNode>(node);
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
cnode->set_inputs(inputs);
}
@ -28,7 +28,7 @@ bool ContainSparseTensor(const abstract::AbstractBasePtr &abs) {
|
|||
return true;
|
||||
}
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
auto vec = abs->cast<abstract::AbstractTuplePtr>()->elements();
|
||||
auto vec = abs->cast_ptr<abstract::AbstractTuple>()->elements();
|
||||
return std::any_of(vec.begin(), vec.end(), ContainSparseTensor);
|
||||
}
|
||||
return false;
|
||||
@ -40,27 +40,20 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
bool has_priority_pattern) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode == nullptr) {
return false;
}

auto cnode = node->cast<CNodePtr>();
auto inp0 = cnode->input(0);
auto prim0 = GetValueNode<PrimitivePtr>(inp0);
if (prim0 == nullptr) {
auto cnode_prim = GetValuePtr<Primitive>(cnode->input(0));
if (cnode_prim == nullptr) {
return false;
}

auto hash = prim0->Hash();
auto const &name = prim0->name();
for (auto &prim : prims) {
if (hash == prim->Hash() && name == prim->name()) {
return true;
}
}
return false;
auto hash = cnode_prim->Hash();
const auto &name = cnode_prim->name();
return std::any_of(prims.begin(), prims.end(), [&hash, &name](const PrimitivePtr &prim) {
return (prim->Hash() == hash) && (prim->name() == name);
});
};

return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
}

@ -93,17 +86,15 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
return result;
}

static bool isTraversable(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
static inline bool isTraversable(const AnfNodePtr &node) {
if (node->isa<CNode>() || node->isa<Parameter>()) {
return true;
}
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
return true;
}
return false;
// FuncGraph or RefKey value node is traversable.
auto value_node = dyn_cast_ptr<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node);
const auto &value = value_node->value();
return (value != nullptr) && (value->isa<FuncGraph>() || value->isa<RefKey>());
}

static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
@ -131,34 +122,38 @@ static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &n
|
|||
}
|
||||
|
||||
static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
auto fg = GetValuePtr<FuncGraph>(node);
|
||||
if (fg != nullptr) {
|
||||
(void)todo->emplace_back(fg->output());
|
||||
}
|
||||
|
||||
if (change) {
|
||||
(*todo).emplace_back(node);
|
||||
(void)todo->emplace_back(node);
|
||||
} else {
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
|
||||
auto cnode = dyn_cast_ptr<CNode>(node);
|
||||
if (cnode != nullptr) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
(void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change,
|
||||
const SubstitutionPtr &substitution) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
auto fg = GetValuePtr<FuncGraph>(node);
|
||||
if (fg != nullptr) {
|
||||
(void)todo->emplace_back(fg->output());
|
||||
}
|
||||
|
||||
// If there is a priority pattern in substitution, don't transform the new node,
|
||||
// otherwise some nodes may match the wrong patterns.
|
||||
if (change && substitution != nullptr && !substitution->has_priority_pattern_) {
|
||||
(*todo).emplace_back(node);
|
||||
(void)todo->emplace_back(node);
|
||||
} else {
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
|
||||
auto cnode = dyn_cast_ptr<CNode>(node);
|
||||
if (cnode != nullptr) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
(void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -194,12 +189,11 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
|
|||
FuncGraphManagerPtr manager = optimizer->manager();
|
||||
auto seen = NewSeenGeneration();
|
||||
std::deque<AnfNodePtr> todo;
|
||||
todo.emplace_back(func_graph->output());
|
||||
(void)todo.emplace_back(func_graph->output());
|
||||
bool changes = false;
|
||||
|
||||
auto &all_nodes = manager->all_nodes();
|
||||
while (!todo.empty()) {
|
||||
AnfNodePtr node = todo.front();
|
||||
AnfNodePtr node = std::move(todo.front());
|
||||
todo.pop_front();
|
||||
|
||||
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
|
||||
|
@ -373,13 +367,13 @@ bool SimpleRewriter::Run() {
|
|||
continue;
|
||||
}
|
||||
node->seen_ = seen;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
if (cnode != nullptr) {
|
||||
for (auto &input : cnode->inputs()) {
|
||||
add_todo(input);
|
||||
}
|
||||
} else {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||
auto fg = GetValuePtr<FuncGraph>(node);
|
||||
if (fg != nullptr) {
|
||||
add_todo(fg->output());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -43,7 +43,7 @@ MatchResultPtr Call::match(const AnfNodePtr &node) {
|
|||
}
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
// IsPrimitiveCNode
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check Primitive ValueNode
|
||||
if (prim_pattern_ != nullptr) {
|
||||
|
@ -120,20 +120,16 @@ MatchResultPtr Any::match(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
MatchResultPtr Imm::match(const AnfNodePtr &node) {
|
||||
if (!IsValueNode<Int32Imm>(node)) {
|
||||
auto value_ptr = GetValuePtr<Int32Imm>(node);
|
||||
if (value_ptr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// Check value
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
if (value_ptr->value() == value_) {
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<Imm>(), node);
|
||||
return res;
|
||||
if (value_ptr->value() != value_) {
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<Imm>(), node);
|
||||
return res;
|
||||
}
|
||||
|
||||
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -57,7 +57,7 @@ bool IsTraversable(const AnfNodePtr &node) {
|
|||
|
||||
AnfNodePtr BuildPrimitive(const PatternPtr &pattern) {
|
||||
// Build up AnfNode from primitive
|
||||
auto prim_pattern = pattern->cast<PrimPtr>();
|
||||
auto prim_pattern = pattern->cast_ptr<Prim>();
|
||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||
PrimitivePyPtr prim = prim_pattern->matched_primitive();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -67,7 +67,7 @@ AnfNodePtr BuildPrimitive(const PatternPtr &pattern) {
|
|||
|
||||
AnfNodePtr BuildNewTensor(const PatternPtr &pattern) {
|
||||
// Build a ValueNode from TensorPtr
|
||||
auto new_tensor_pattern = pattern->cast<NewTensorPtr>();
|
||||
auto new_tensor_pattern = pattern->cast_ptr<NewTensor>();
|
||||
MS_EXCEPTION_IF_NULL(new_tensor_pattern);
|
||||
auto input_tensor = new_tensor_pattern->input_tensor();
|
||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||
|
@ -76,7 +76,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern) {
|
|||
|
||||
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
|
||||
const FuncGraphPtr &top_graph) {
|
||||
auto call_pattern = pattern->cast<CallPtr>();
|
||||
auto call_pattern = pattern->cast_ptr<Call>();
|
||||
MS_EXCEPTION_IF_NULL(call_pattern);
|
||||
auto prim = call_pattern->prim_value();
|
||||
if (prim != nullptr) {
|
||||
|
@ -88,7 +88,7 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
|
|||
}
|
||||
|
||||
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
|
||||
auto new_para_pattern = pattern->cast<NewParameterPtr>();
|
||||
auto new_para_pattern = pattern->cast_ptr<NewParameter>();
|
||||
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
||||
if (!new_para_pattern->built()) {
|
||||
static int64_t parameter_id = 0;
|
||||
|
@ -122,7 +122,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
|
|||
}
|
||||
|
||||
AnfNodePtr BuildImmNode(const PatternPtr &pattern) {
|
||||
auto imm_pattern = pattern->cast<ImmPtr>();
|
||||
auto imm_pattern = pattern->cast_ptr<Imm>();
|
||||
MS_EXCEPTION_IF_NULL(imm_pattern);
|
||||
auto value = imm_pattern->value();
|
||||
auto scalar_value_ptr = std::make_shared<Int64Imm>(value);
|
||||
|
@ -134,7 +134,7 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
|
|||
auto target_node = res->get_node(pattern);
|
||||
if (target_node != nullptr) {
|
||||
// If pattern is NewParameter, check whether it shouldn't last and is not built
|
||||
auto new_para = pattern->cast<NewParameterPtr>();
|
||||
auto new_para = pattern->cast_ptr<NewParameter>();
|
||||
if (new_para == nullptr || new_para->should_last() || new_para->built()) {
|
||||
return target_node;
|
||||
}
|
||||
|
@ -218,20 +218,20 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, const string ¶m_name,
|
|||
MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto param_node = param->cast<ParameterPtr>();
|
||||
auto param_node = param->cast_ptr<Parameter>();
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
param_node->set_default_param(param_value);
|
||||
}
|
||||
|
||||
void Reset(const PatternPtr &pattern) {
|
||||
if (pattern->isa<Prim>()) {
|
||||
auto prim_pattern = pattern->cast<PrimPtr>();
|
||||
auto prim_pattern = pattern->cast_ptr<Prim>();
|
||||
prim_pattern->reset();
|
||||
} else if (pattern->isa<NewParameter>()) {
|
||||
auto new_param_pattern = pattern->cast<NewParameterPtr>();
|
||||
auto new_param_pattern = pattern->cast_ptr<NewParameter>();
|
||||
new_param_pattern->reset();
|
||||
} else if (pattern->isa<Call>()) {
|
||||
auto call_with_pattern = pattern->cast<CallPtr>();
|
||||
auto call_with_pattern = pattern->cast_ptr<Call>();
|
||||
for (const auto &sub_pattern : call_with_pattern->inputs()) {
|
||||
Reset(sub_pattern);
|
||||
}
|
||||
|
@ -257,7 +257,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
|
|||
MS_EXCEPTION_IF_NULL(dst_pattern_);
|
||||
if (src_pattern_ == nullptr) {
|
||||
// Add NewParameter
|
||||
auto new_para_pattern = dst_pattern_->cast<NewParameterPtr>();
|
||||
auto new_para_pattern = dst_pattern_->cast_ptr<NewParameter>();
|
||||
if (new_para_pattern == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -81,7 +81,7 @@ void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
|
|||
auto cur_pg = GetPassGroup(Phase::OPT);
|
||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||
cur_pg->SetRunOnlyOnce(true);
|
||||
auto new_para_pattern = parameter->cast<NewParameterPtr>();
|
||||
auto new_para_pattern = parameter->cast_ptr<NewParameter>();
|
||||
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
||||
auto pass_name = new_para_pattern->para_name();
|
||||
new_para_pattern->set_last(true);
|
||||
|
|
|
@ -59,7 +59,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) {
|
|||
|
||||
ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -222,7 +222,7 @@ bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap<AnfNodePtr, bool>
|
|||
if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) {
|
||||
return has_grad_inputs_map->find(node)->second;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
if (cnode == nullptr) {
|
||||
(void)has_grad_inputs_map->emplace(node, false);
|
||||
return false;
|
||||
|
@ -230,7 +230,7 @@ bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap<AnfNodePtr, bool>
|
|||
const auto &inputs = cnode->inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// For the pipeline split case, the forward pass may depend on the backward pass.
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && i == kDependAttachNodeIndex) {
|
||||
if (cnode->IsApply(prim::kPrimDepend) && i == kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) {
|
||||
|
@ -320,7 +320,7 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
|
|||
std::vector<AnfNodePtr> tuple_getitem_output_nodes;
|
||||
GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes);
|
||||
for (const auto &output_node : tuple_getitem_output_nodes) {
|
||||
auto output_cnode = output_node->cast<CNodePtr>();
|
||||
auto output_cnode = output_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(output_cnode);
|
||||
output_cnode->AddAttr(kAttrRecompute, MakeValue(true));
|
||||
}
|
||||
|
@ -362,7 +362,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
|
|||
for (size_t i = 0; i < origin_node->size(); ++i) {
|
||||
auto input = origin_node->input(i);
|
||||
if (i == 0 && IsPrimitive(input, prim::kPrimAllGather)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(input);
|
||||
auto prim = GetValuePtr<Primitive>(input);
|
||||
auto instance_name = prim->instance_name();
|
||||
bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos;
|
||||
int64_t fusion_id = prim->HasAttr(kAttrFusion) ? GetValue<int64_t>(prim->GetAttr(kAttrFusion)) : 0;
|
||||
|
@ -426,7 +426,7 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSe
|
|||
for (const auto &target_node : target_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(target_node);
|
||||
MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
|
||||
auto target_cnode = target_node->cast<CNodePtr>();
|
||||
auto target_cnode = target_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(target_cnode);
|
||||
std::vector<AnfNodePtr> new_target_inputs;
|
||||
for (const auto &input : target_cnode->inputs()) {
|
||||
|
|
|
@ -278,7 +278,7 @@ FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr &
|
|||
// The value of attr "shared_name" will changed every time.
|
||||
auto cnodes = fg->GetOrderedCnodes();
|
||||
for (const auto &cnode : cnodes) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
auto prim = GetValuePtr<Primitive>(cnode->input(0));
|
||||
if (prim != nullptr && prim->HasAttr("shared_name")) {
|
||||
prim->set_attr("shared_name", MakeValue(queue_name));
|
||||
break;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/value.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
|
@ -31,6 +32,7 @@
|
|||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/irpass/symbol_resolver.h"
|
||||
#include "include/common/debug/anf_dump_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
@ -452,7 +454,7 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
|
|||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object namespace_obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
|
||||
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
|
||||
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
|
||||
const std::string &attr_as_string = GetValuePtr<StringImm>(attr)->value();
|
||||
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
|
||||
MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
|
||||
|
||||
|
@ -489,7 +491,9 @@ AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py:
|
|||
} else if (count_msclass == sequence_size) {
|
||||
// Resolve MsClass instances.
|
||||
for (size_t i = 0; i < sequence_size; ++i) {
|
||||
auto attr_str = GetValue<std::string>(GetValueNode(attr));
|
||||
auto attr_str_ptr = GetValuePtr<StringImm>(attr);
|
||||
MS_EXCEPTION_IF_NULL(attr_str_ptr);
|
||||
const auto &attr_str = attr_str_ptr->value();
|
||||
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, get_attr_node);
|
||||
(void)inputs.emplace_back(res);
|
||||
}
|
||||
|
|
|
@ -214,7 +214,7 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) {
|
|||
return;
|
||||
}
|
||||
|
||||
auto abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
|
||||
auto abs_sequence = abs->cast_ptr<abstract::AbstractSequence>();
|
||||
if (abs_sequence != nullptr) {
|
||||
const auto &elements = abs_sequence->elements();
|
||||
for (auto &ele : elements) {
|
||||
|
@ -223,7 +223,7 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) {
|
|||
return;
|
||||
}
|
||||
|
||||
auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
|
||||
auto abs_dict = abs->cast_ptr<abstract::AbstractDictionary>();
|
||||
if (abs_dict != nullptr) {
|
||||
const auto &elements = abs_dict->elements();
|
||||
for (auto &ele : elements) {
|
||||
|
@ -590,16 +590,15 @@ void GraphExecutorPy::GetWeightInfo(
|
|||
auto x = root_node->input(1);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
|
||||
weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
|
||||
weight_name = weight_node->cast_ptr<CNode>()->input(1)->cast_ptr<Parameter>()->name();
|
||||
} else {
|
||||
auto para = weight_node->cast<ParameterPtr>();
|
||||
auto para = weight_node->cast_ptr<Parameter>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
weight_name = para->name();
|
||||
}
|
||||
// find the fakequant from input
|
||||
int64_t count = 0;
|
||||
const int64_t max_depth = 5;
|
||||
CNodePtr cnode = nullptr;
|
||||
auto is_quant_cnode = [](const AnfNodePtr &node) {
|
||||
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) ||
|
||||
|
@ -610,7 +609,7 @@ void GraphExecutorPy::GetWeightInfo(
|
|||
if (count >= max_depth) {
|
||||
break;
|
||||
}
|
||||
cnode = x->cast<CNodePtr>();
|
||||
auto cnode = x->cast_ptr<CNode>();
|
||||
if (cnode == nullptr || cnode->size() <= 1) {
|
||||
break;
|
||||
}
|
||||
|
@ -624,9 +623,9 @@ void GraphExecutorPy::GetWeightInfo(
|
|||
if (!is_quant_cnode(x)) {
|
||||
return;
|
||||
}
|
||||
cnode = x->cast<CNodePtr>();
|
||||
auto cnode = x->cast_ptr<CNode>();
|
||||
constexpr size_t expect_input_size = 4;
|
||||
if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != expect_input_size) {
|
||||
if (cnode == nullptr || cnode->IsApply(prim::kPrimLoad) || cnode->size() != expect_input_size) {
|
||||
return;
|
||||
}
|
||||
const size_t fakequant_index = 2;
|
||||
|
@ -636,18 +635,18 @@ void GraphExecutorPy::GetWeightInfo(
|
|||
}
|
||||
std::string fakequant_min_node_name;
|
||||
if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
|
||||
fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
|
||||
fakequant_min_node_name = fakequant_min_node->cast_ptr<CNode>()->input(1)->cast_ptr<Parameter>()->name();
|
||||
} else {
|
||||
auto param = fakequant_min_node->cast<ParameterPtr>();
|
||||
auto param = fakequant_min_node->cast_ptr<Parameter>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
fakequant_min_node_name = param->name();
|
||||
}
|
||||
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
|
||||
const auto &quant_op_value = cnode->input(0)->cast_ptr<ValueNode>()->value();
|
||||
MS_EXCEPTION_IF_NULL(quant_op_value);
|
||||
if (!quant_op_value->isa<PrimitivePy>()) {
|
||||
return;
|
||||
}
|
||||
auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
|
||||
auto quant_op = quant_op_value->cast_ptr<PrimitivePy>();
|
||||
(*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name);
|
||||
}
|
||||
|
||||
|
@ -677,7 +676,7 @@ std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> GraphExecut
|
|||
}
|
||||
auto weight = root_node->input(weight_index);
|
||||
if (!is_quant_cnode(weight)) {
|
||||
auto tuple_node = weight->cast<CNodePtr>();
|
||||
auto tuple_node = weight->cast_ptr<CNode>();
|
||||
if (tuple_node != nullptr) {
|
||||
auto fake_node = tuple_node->input(1);
|
||||
if (!is_quant_cnode(fake_node)) {
|
||||
|
@ -688,7 +687,7 @@ std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> GraphExecut
|
|||
}
|
||||
}
|
||||
// get parameter weight's name
|
||||
auto cnode = weight->cast<CNodePtr>();
|
||||
auto cnode = weight->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto weight_node = cnode->input(weight_index);
|
||||
if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
|
||||
|
@ -1207,7 +1206,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
|
|||
// Maybe some default parameter
|
||||
for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
|
||||
MS_EXCEPTION_IF_NULL(graph_params[i]);
|
||||
auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
|
||||
auto param_ptr = (graph_params[i])->cast_ptr<Parameter>();
|
||||
MS_EXCEPTION_IF_NULL(param_ptr);
|
||||
if (!param_ptr->has_default()) {
|
||||
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
|
||||
|
@ -1346,7 +1345,7 @@ void GraphExecutorPy::UpdataParamNodeDefaultInput(
|
|||
auto ¶ms = func_graph->parameters();
|
||||
for (const auto ¶m : params) {
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto param_cast = param->cast<ParameterPtr>();
|
||||
auto param_cast = param->cast_ptr<Parameter>();
|
||||
MS_EXCEPTION_IF_NULL(param_cast);
|
||||
auto iter = params_value.find(param_cast->name());
|
||||
if (iter != params_value.end()) {
|
||||
|
|
|
@ -77,8 +77,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
|
|||
|
||||
static bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -96,7 +95,7 @@ static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const F
|
|||
CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
tuple_get_item->set_scope(node->scope());
|
||||
auto input_abstract_tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||
auto input_abstract_tuple = node->abstract()->cast_ptr<abstract::AbstractTuple>();
|
||||
MS_EXCEPTION_IF_NULL(input_abstract_tuple);
|
||||
auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
|
||||
|
@ -134,7 +133,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
|
|||
}
|
||||
std::set<AnfNodePtr> input_parameters;
|
||||
for (auto &anf_param : root->parameters()) {
|
||||
auto param = anf_param->cast<ParameterPtr>();
|
||||
auto param = anf_param->cast_ptr<Parameter>();
|
||||
if (!param->has_default()) {
|
||||
(void)input_parameters.insert(anf_param);
|
||||
}
|
||||
|
@ -143,7 +142,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
|
|||
auto node_users_map = root->manager()->node_users();
|
||||
auto node_users = node_users_map[input_parameter];
|
||||
for (auto node_user : node_users) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
auto cnode = node_user.first->cast_ptr<CNode>();
|
||||
if (IsValueNode<Primitive>(cnode->inputs()[0]) ||
|
||||
(IsValueNode<FuncGraph>(cnode->inputs()[0]) && !root->has_flag(parallel::kTraining))) {
|
||||
(void)graph_sets.insert(cnode->func_graph());
|
||||
|
@ -155,7 +154,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
|
|||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,6 +27,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace pipeline {
|
||||
static inline bool IsSameValue(Value *v1, Value *v2) {
|
||||
if (v1->isa<tensor::Tensor>() && v2->isa<tensor::Tensor>()) {
|
||||
return static_cast<tensor::Tensor *>(v1)->ValueEqual(*(static_cast<tensor::Tensor *>(v2)));
|
||||
}
|
||||
return *v1 == *v2;
|
||||
}
|
||||
|
||||
void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache,
|
||||
HashValue *const hash_value) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -35,7 +42,7 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
|
|||
if (IsValueNode<FuncGraph>(node)) {
|
||||
return;
|
||||
}
|
||||
const auto &to_check_value = GetValueNode(node);
|
||||
auto to_check_value = GetValuePtr(node);
|
||||
MS_EXCEPTION_IF_NULL(to_check_value);
|
||||
|
||||
// Calculate hash value.
|
||||
|
@ -62,20 +69,13 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
|
|||
if (v == node) {
|
||||
return;
|
||||
}
|
||||
const auto &existed_value = GetValueNode(v);
|
||||
auto existed_value = GetValuePtr(v);
|
||||
MS_EXCEPTION_IF_NULL(existed_value);
|
||||
auto equal = [&]() -> bool {
|
||||
if (existed_value->isa<tensor::Tensor>() && to_check_value->isa<tensor::Tensor>()) {
|
||||
return existed_value->cast<tensor::TensorPtr>()->ValueEqual(*(to_check_value->cast<tensor::TensorPtr>()));
|
||||
}
|
||||
return *existed_value == *to_check_value;
|
||||
};
|
||||
if (equal()) {
|
||||
if (IsSameValue(existed_value, to_check_value)) {
|
||||
(void)manager->Replace(node, v);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Meet for the first time, append node to bucket.
|
||||
bucket.emplace_back(node);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -179,8 +179,9 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const
|
|||
const std::size_t offset) {
|
||||
if (abs->isa<AbstractFuncAtom>()) {
|
||||
return abs->cast<AbstractFuncAtomPtr>();
|
||||
} else if (abs->isa<AbstractSequence>()) {
|
||||
const auto &abs_seq = abs->cast<AbstractSequencePtr>();
|
||||
}
|
||||
if (abs->isa<AbstractSequence>()) {
|
||||
auto abs_seq = abs->cast_ptr<AbstractSequence>();
|
||||
MS_EXCEPTION_IF_NULL(abs_seq);
|
||||
const auto &elements = abs_seq->elements();
|
||||
if (offset >= index.size()) {
|
||||
|
@ -190,12 +191,12 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const
|
|||
MS_LOG(EXCEPTION) << "At offset" << offset << ", elements size of AsyncAbstract result: " << abs->ToString()
|
||||
<< " is less than or equal to index: " << index[offset];
|
||||
}
|
||||
const auto &resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1);
|
||||
auto resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1);
|
||||
if (!resolved->isa<AbstractFuncAtom>()) {
|
||||
MS_LOG(EXCEPTION) << "AsyncAbstract result cannot be resolved to AbstractFuncAtom, but: " << resolved->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Return abstract: " << resolved->ToString();
|
||||
return resolved->cast<AbstractFuncAtomPtr>();
|
||||
return resolved;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "AsyncAbstract cannot resolved to AbstractFuncAtom or AbstractSeqeunce, but: "
|
||||
<< abs->ToString();
|
||||
|
|
|
@ -61,7 +61,7 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
|
|||
} // namespace
|
||||
|
||||
bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
|
||||
auto new_sequence = dyn_cast<AbstractSequence>(arg);
|
||||
auto new_sequence = dyn_cast_ptr<AbstractSequence>(arg);
|
||||
if (new_sequence != nullptr && new_sequence->sequence_nodes() != nullptr && new_sequence->size() != 0) {
|
||||
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
|
||||
auto prev_result = cache_mgr.GetValue(conf);
|
||||
|
@ -69,7 +69,7 @@ bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg)
|
|||
return false;
|
||||
}
|
||||
auto prev_abs = prev_result->abstract();
|
||||
auto old_sequence = dyn_cast<AbstractSequence>(prev_abs);
|
||||
auto old_sequence = dyn_cast_ptr<AbstractSequence>(prev_abs);
|
||||
if (old_sequence != nullptr &&
|
||||
(old_sequence->sequence_nodes() == nullptr || old_sequence->sequence_nodes()->empty()) && *arg == *prev_abs) {
|
||||
MS_LOG(DEBUG) << "Always eval";
|
||||
|
@ -255,7 +255,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
<< fg->ToString() << "();";
|
||||
}
|
||||
|
||||
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
|
||||
auto func_graph_evaluator = mindspore::cast<FuncGraphEvaluator>(this);
|
||||
if (func_graph_evaluator != nullptr) {
|
||||
if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
|
||||
engine->set_root_context(context);
|
||||
|
@ -551,7 +551,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
|||
AbstractBasePtrList bparams;
|
||||
bparams.push_back(SensitivityTransform(primal_func_));
|
||||
// Check if primal func graph has the primitive returned sparse result in its bprop().
|
||||
auto real_primal_func = dyn_cast<FuncGraphAbstractClosure>(primal_func_);
|
||||
auto real_primal_func = dyn_cast_ptr<FuncGraphAbstractClosure>(primal_func_);
|
||||
MS_EXCEPTION_IF_NULL(real_primal_func);
|
||||
FuncGraphPtr primal_func_graph = real_primal_func->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(primal_func_graph);
|
||||
|
@ -625,7 +625,7 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
|
|||
MS_LOG(EXCEPTION) << "The orig_abs should be AbstractTensor when axis is " << *axis << ", but got a "
|
||||
<< orig_abs->ToString() << ".";
|
||||
}
|
||||
ShapeVector orig_shape = dyn_cast<abstract::Shape>(orig_abs->BuildShape())->shape();
|
||||
ShapeVector orig_shape = dyn_cast_ptr<abstract::Shape>(orig_abs->BuildShape())->shape();
|
||||
int shape_len = SizeToInt(orig_shape.size());
|
||||
if (*axis < -shape_len || *axis >= shape_len) {
|
||||
MS_LOG(EXCEPTION) << "The axis: " << *axis << " in 'in_axes' is out of bounds for array of dimension ["
|
||||
|
@ -649,22 +649,21 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
|
|||
AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, const ValuePtr &in_axes, int *axis_size) {
|
||||
MS_EXCEPTION_IF_NULL(physical_view_abs);
|
||||
MS_EXCEPTION_IF_NULL(in_axes);
|
||||
auto physical_view_abs_sequence = dyn_cast<abstract::AbstractSequence>(physical_view_abs);
|
||||
auto physical_view_abs_sequence = dyn_cast_ptr<abstract::AbstractSequence>(physical_view_abs);
|
||||
if (physical_view_abs_sequence != nullptr) {
|
||||
AbstractBasePtrList abs_list = physical_view_abs_sequence->elements();
|
||||
AbstractBasePtrList logical_view_abs_list;
|
||||
auto in_axes_seq = dyn_cast<ValueSequeue>(in_axes);
|
||||
auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
|
||||
int index = 0;
|
||||
(void)std::transform(
|
||||
abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list),
|
||||
[&axis_size, &index, &in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr {
|
||||
ValuePtr sub_in_axes = in_axes;
|
||||
if (in_axes->isa<ValueSequeue>()) {
|
||||
sub_in_axes = (*in_axes_seq)[index];
|
||||
index++;
|
||||
}
|
||||
return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
|
||||
});
|
||||
(void)std::transform(abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list),
|
||||
[&axis_size, &index, in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr {
|
||||
ValuePtr sub_in_axes = in_axes;
|
||||
if (in_axes->isa<ValueSequeue>()) {
|
||||
sub_in_axes = (*in_axes_seq)[index];
|
||||
index++;
|
||||
}
|
||||
return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
|
||||
});
|
||||
if (physical_view_abs->isa<AbstractList>()) {
|
||||
return std::make_shared<AbstractList>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
|
||||
}
|
||||
|
@ -672,7 +671,7 @@ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, cons
|
|||
}
|
||||
ValuePtr in_axis = in_axes;
|
||||
if (in_axis->isa<Int64Imm>()) {
|
||||
int axis = dyn_cast<Int64Imm>(in_axis)->value();
|
||||
int axis = dyn_cast_ptr<Int64Imm>(in_axis)->value();
|
||||
auto logical_view_abs = ReduceDim(&axis, physical_view_abs, axis_size);
|
||||
return logical_view_abs;
|
||||
}
|
||||
|
@ -689,7 +688,7 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
|
|||
AbstractBasePtr out_abs = nullptr;
|
||||
ShapeVector orig_shape;
|
||||
if (orig_abs->isa<abstract::AbstractTensor>()) {
|
||||
orig_shape = dyn_cast<abstract::Shape>(orig_abs->BuildShape())->shape();
|
||||
orig_shape = dyn_cast_ptr<abstract::Shape>(orig_abs->BuildShape())->shape();
|
||||
}
|
||||
int shape_len = SizeToInt(orig_shape.size() + 1);
|
||||
if (*axis < -shape_len || *axis >= shape_len) {
|
||||
|
@ -713,11 +712,11 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
|
|||
|
||||
AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, const ValuePtr &out_axes, int axis_size) {
|
||||
MS_EXCEPTION_IF_NULL(logical_view_abs);
|
||||
auto logical_view_abs_sequence = dyn_cast<abstract::AbstractSequence>(logical_view_abs);
|
||||
auto logical_view_abs_sequence = dyn_cast_ptr<abstract::AbstractSequence>(logical_view_abs);
|
||||
if (logical_view_abs_sequence != nullptr) {
|
||||
AbstractBasePtrList logical_view_abs_list = logical_view_abs_sequence->elements();
|
||||
AbstractBasePtrList physical_view_abs_list;
|
||||
auto out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
|
||||
auto out_axes_seq = dyn_cast_ptr<ValueSequeue>(out_axes);
|
||||
if (out_axes_seq != nullptr) {
|
||||
if (logical_view_abs_list.size() != out_axes_seq->size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': "
|
||||
|
@ -727,7 +726,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
|
|||
int index = 0;
|
||||
(void)std::transform(
|
||||
logical_view_abs_list.begin(), logical_view_abs_list.end(), std::back_inserter(physical_view_abs_list),
|
||||
[&axis_size, &index, &out_axes_seq, out_axes](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
|
||||
[&axis_size, &index, out_axes_seq, out_axes](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
|
||||
ValuePtr sub_out_axes = out_axes;
|
||||
if (out_axes->isa<ValueSequeue>()) {
|
||||
sub_out_axes = (*out_axes_seq)[index];
|
||||
|
@ -737,7 +736,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
|
|||
return GetPhysicalViewAbs(arg_spec, sub_out_axes, axis_size);
|
||||
}
|
||||
if (sub_out_axes->isa<Int64Imm>()) {
|
||||
int axis = dyn_cast<Int64Imm>(sub_out_axes)->value();
|
||||
int axis = dyn_cast_ptr<Int64Imm>(sub_out_axes)->value();
|
||||
return ExtendDim(&axis, arg_spec, axis_size);
|
||||
} else if (sub_out_axes->isa<None>()) {
|
||||
return arg_spec;
|
||||
|
@ -763,7 +762,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
|
|||
}
|
||||
|
||||
int axis = 0;
|
||||
auto axis_int_ptr = dyn_cast<Int64Imm>(sub_out_axes);
|
||||
auto axis_int_ptr = dyn_cast_ptr<Int64Imm>(sub_out_axes);
|
||||
if (axis_int_ptr != nullptr) {
|
||||
axis = LongToInt(axis_int_ptr->value());
|
||||
} else {
|
||||
|
@ -787,9 +786,9 @@ EvalResultPtr VmapEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &
|
|||
int axis_size = -1;
|
||||
int index = 0;
|
||||
auto in_axes = in_axes_;
|
||||
auto in_axes_seq = dyn_cast<ValueSequeue>(in_axes);
|
||||
auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
|
||||
[&axis_size, &index, &in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
[&axis_size, &index, in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
AbstractBasePtr abs = conf->ObtainEvalResult()->abstract();
|
||||
// Drop the side effect tag parameters, because it has no mapping axis.
|
||||
|
|
|
@ -61,7 +61,7 @@ std::pair<bool, bool> InterpretAbstractBoolChecker(const AbstractBasePtr &cond)
|
|||
auto value = cond->BuildValue();
|
||||
if (value->isa<parse::InterpretedObject>()) {
|
||||
is_interpret = true;
|
||||
auto interpreted_obj = value->cast<parse::InterpretedObjectPtr>();
|
||||
auto interpreted_obj = value->cast_ptr<parse::InterpretedObject>();
|
||||
py::object obj = interpreted_obj->obj();
|
||||
constexpr char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse";
|
||||
constexpr char PYTHON_MOD_CHECK_OBJ_BOOL[] = "check_obj_bool";
|
||||
|
@ -114,10 +114,10 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
});
|
||||
|
||||
// Do undetermined infer firstly.
|
||||
auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
|
||||
auto do_signature = prim_->cast_ptr<prim::DoSignaturePrimitive>();
|
||||
MS_EXCEPTION_IF_NULL(do_signature);
|
||||
auto &func = do_signature->function();
|
||||
auto do_signature_func = dyn_cast<Primitive>(func);
|
||||
auto do_signature_func = dyn_cast_ptr<Primitive>(func);
|
||||
if (do_signature_func != nullptr) {
|
||||
if (prims_to_skip_undetermined_infer.find(do_signature_func->name()) == prims_to_skip_undetermined_infer.end()) {
|
||||
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
|
||||
|
@ -168,11 +168,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
|
|||
for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
|
||||
MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
|
||||
if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
|
||||
auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
|
||||
auto arg_tuple = specialize_args_before_unpack[index]->cast_ptr<AbstractTuple>();
|
||||
std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
|
||||
std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
|
||||
} else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
|
||||
auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
|
||||
auto arg_dict = specialize_args_before_unpack[index]->cast_ptr<AbstractDictionary>();
|
||||
auto dict_elems = arg_dict->elements();
|
||||
(void)std::transform(
|
||||
dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
|
||||
|
@ -197,9 +197,9 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
||||
auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
|
||||
auto unpack_graph = prim_->cast_ptr<prim::UnpackGraphPrimitive>();
|
||||
MS_EXCEPTION_IF_NULL(unpack_graph);
|
||||
auto out_node = out_conf->node()->cast<CNodePtr>();
|
||||
auto out_node = out_conf->node()->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
const auto &out_node_inputs = out_node->inputs();
|
||||
if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
|
||||
|
@ -219,11 +219,11 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
|
||||
auto fn = args_spec_list[0]->cast_ptr<AbstractFunction>();
|
||||
if (fn == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
|
||||
}
|
||||
auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
|
||||
auto real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
|
||||
MS_EXCEPTION_IF_NULL(real_fn);
|
||||
FuncGraphPtr forward_graph = real_fn->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
|
@ -255,14 +255,14 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtr target_node = source_node;
if (node_type->isa<AbstractTensor>()) {
auto x = node_type->cast<AbstractTensorPtr>();
auto x = node_type->cast_ptr<AbstractTensor>();
if (x->element()->BuildType()->isa<Float>()) {
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
MS_EXCEPTION_IF_NULL(cast);
target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
}
} else if (node_type->isa<AbstractTuple>()) {
auto x = node_type->cast<AbstractTuplePtr>();
auto x = node_type->cast_ptr<AbstractTuple>();
auto &items = x->elements();
std::vector<AnfNodePtr> nodes;
nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
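The helper above walks the abstract type of the argument: float tensors get a cast node inserted, while tuples are rebuilt element by element under a MakeTuple with the helper applied recursively. A rough stand-alone sketch of that traversal shape, heavily simplified (it omits the float check, the TupleGetItem extraction, and the dict/kwarg branches, and the types are illustrative stand-ins):

    #include <memory>
    #include <string>
    #include <vector>

    struct Abstract { virtual ~Abstract() = default; };
    struct TensorAbs : Abstract {};
    struct TupleAbs : Abstract { std::vector<std::shared_ptr<Abstract>> elements; };

    // Describe what the mixed-precision helper would build for a given
    // abstract: a cast for tensors, a rebuilt tuple for tuples, otherwise
    // the node unchanged.
    std::string BuildCastPlan(const std::shared_ptr<Abstract> &abs) {
      if (dynamic_cast<TensorAbs *>(abs.get()) != nullptr) {
        return "Cast(node, target_type)";
      }
      if (auto *tuple = dynamic_cast<TupleAbs *>(abs.get())) {
        std::string plan = "MakeTuple(";
        for (size_t i = 0; i < tuple->elements.size(); ++i) {
          if (i != 0) plan += ", ";
          plan += BuildCastPlan(tuple->elements[i]);  // recurse per element
        }
        return plan + ")";
      }
      return "node";  // leave non-tensor, non-structured inputs alone
    }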
@ -276,7 +276,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
}
target_node = func_graph->NewCNode(nodes);
} else if (node_type->isa<AbstractDictionary>()) {
auto x = node_type->cast<AbstractDictionaryPtr>();
auto x = node_type->cast_ptr<AbstractDictionary>();
auto &items = x->elements();
std::vector<AnfNodePtr> dict_key_nodes;
std::vector<AnfNodePtr> dict_value_nodes;
@ -293,7 +293,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(dict_key_nodes)),
func_graph->NewCNode(std::move(dict_value_nodes))});
} else if (node_type->isa<AbstractKeywordArg>()) {
auto x = node_type->cast<AbstractKeywordArgPtr>();
auto x = node_type->cast_ptr<AbstractKeywordArg>();
std::string kwarg_key = x->get_key();
AnfNodePtr kwarg_value_node =
func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
@ -336,7 +336,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());

if (new_node->isa<CNode>()) {
auto new_cnode = new_node->cast<CNodePtr>();
auto new_cnode = new_node->cast_ptr<CNode>();
new_cnode->CloneCNodeInfo(out_node);
}
return engine->ForwardConfig(out_conf, fn_conf);
@ -439,7 +439,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_conver
|
|||
if (dyn_shape_value) {
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
if (!res[i].contains(py::str(ATTR_SHAPE_VALUE))) {
|
||||
auto const_abstract_value = arg_tuple->elements()[i]->cast<abstract::AbstractTensorPtr>();
|
||||
auto const_abstract_value = arg_tuple->elements()[i]->cast_ptr<abstract::AbstractTensor>();
|
||||
MS_EXCEPTION_IF_NULL(const_abstract_value);
|
||||
auto const_value = const_abstract_value->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(const_value);
|
||||
|
@ -547,7 +547,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert
|
|||
if (shape_value) {
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
if (!res[i].contains(py::str(ATTR_SHAPE_VALUE))) {
|
||||
auto const_abstract_value = arg_list->elements()[i]->cast<abstract::AbstractTensorPtr>();
|
||||
auto const_abstract_value = arg_list->elements()[i]->cast_ptr<abstract::AbstractTensor>();
|
||||
MS_EXCEPTION_IF_NULL(const_abstract_value);
|
||||
auto const_value = const_abstract_value->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(const_value);
|
||||
|
@ -565,7 +565,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert
|
|||
}
|
||||
|
||||
void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
|
||||
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
||||
auto arg_tensor = dyn_cast_ptr<AbstractTensor>(abs_base);
|
||||
MS_EXCEPTION_IF_NULL(dic);
|
||||
MS_EXCEPTION_IF_NULL(arg_tensor);
|
||||
if (only_convert_value) {
|
||||
|
@ -601,16 +601,16 @@ py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_
return py::none();
}
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto do_sig_prim = prim->cast<prim::DoSignaturePrimitivePtr>();
auto do_sig_prim = prim->cast_ptr<prim::DoSignaturePrimitive>();
auto value = do_sig_prim->function();
if (!value->isa<PrimitivePy>()) {
return py::none();
}
auto prim_py = value->cast<PrimitivePyPtr>();
auto prim_py = value->cast_ptr<PrimitivePy>();
return prim_py->GetPyObj();
}
if (prim->isa<PrimitivePy>()) {
auto prim_py = prim->cast<PrimitivePyPtr>();
auto prim_py = prim->cast_ptr<PrimitivePy>();
return prim_py->GetPyObj();
}
return py::none();
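The lookup above is a two-step unwrap: a DoSignaturePrimitive is first unwrapped to the function it wraps, and only a PrimitivePy actually carries a Python object. A compact sketch of that dispatch shape, with illustrative classes rather than the real ones:

    #include <memory>
    #include <string>

    struct Primitive { virtual ~Primitive() = default; };
    struct PrimitivePy : Primitive { std::string py_obj = "<python primitive>"; };
    struct DoSignaturePrimitive : Primitive {
      std::shared_ptr<Primitive> function;  // the wrapped primitive
    };

    // Mirror the fall-through in GetPyObjForPrimitiveAbstract: unwrap the
    // signature wrapper first, then hand out the Python object if present.
    std::string GetPyObj(const std::shared_ptr<Primitive> &prim) {
      if (auto *do_sig = dynamic_cast<DoSignaturePrimitive *>(prim.get())) {
        if (auto *py = dynamic_cast<PrimitivePy *>(do_sig->function.get())) {
          return py->py_obj;
        }
        return "none";
      }
      if (auto *py = dynamic_cast<PrimitivePy *>(prim.get())) {
        return py->py_obj;
      }
      return "none";
    }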
@ -621,7 +621,7 @@ bool IsCallInstance(const PartialAbstractClosurePtr &partial_abs) {
|
|||
if (!fn->isa<PrimitiveAbstractClosure>()) {
|
||||
return false;
|
||||
}
|
||||
auto abs_prim = fn->cast<PrimitiveAbstractClosurePtr>();
|
||||
auto abs_prim = fn->cast_ptr<PrimitiveAbstractClosure>();
|
||||
auto prim = abs_prim->prim();
|
||||
if (prim->name() == prim::kPrimCallInstance->name()) {
|
||||
return true;
|
||||
|
@ -643,14 +643,14 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *
|
|||
auto value = args[0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (IsCallInstance(partial_abs)) {
|
||||
auto value_obj = value->cast<parse::MsClassObjectPtr>();
|
||||
auto value_obj = value->cast_ptr<parse::MsClassObject>();
|
||||
if (value_obj != nullptr) {
|
||||
(*dic)[ATTR_DTYPE] = std::make_shared<MsClassType>();
|
||||
(*dic)[ATTR_VALUE] = value_obj->obj();
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto value_obj = value->cast<parse::ClassTypePtr>();
|
||||
auto value_obj = value->cast_ptr<parse::ClassType>();
|
||||
if (value_obj != nullptr) {
|
||||
(*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
|
||||
(*dic)[ATTR_VALUE] = value_obj->obj();
|
||||
|
@ -708,35 +708,35 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv
|
|||
} else if (abs_base->isa<AbstractList>()) {
|
||||
return AbstractListToPython(abs_base, only_convert_value);
|
||||
} else if (abs_base->isa<AbstractSlice>()) {
|
||||
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
|
||||
auto arg_slice = dyn_cast_ptr<AbstractSlice>(abs_base);
|
||||
ShapeVector shape;
|
||||
dic[ATTR_SHAPE] = shape;
|
||||
dic[ATTR_DTYPE] = arg_slice->BuildType();
|
||||
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
|
||||
} else if (abs_base->isa<AbstractRowTensor>()) {
|
||||
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
|
||||
auto arg = dyn_cast_ptr<AbstractRowTensor>(abs_base);
|
||||
dic[ATTR_SHAPE] = arg->shape()->shape();
|
||||
dic[ATTR_DTYPE] = arg->BuildType();
|
||||
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
||||
} else if (abs_base->isa<AbstractCOOTensor>()) {
|
||||
auto arg = dyn_cast<AbstractCOOTensor>(abs_base);
|
||||
auto arg = dyn_cast_ptr<AbstractCOOTensor>(abs_base);
|
||||
AbstractBasePtrList sparse_shape = arg->shape()->elements();
|
||||
ShapeVector sparse_shape_vector;
|
||||
(void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
|
||||
[](const AbstractBasePtr &e) -> int64_t {
|
||||
ValuePtr value = e->cast<AbstractScalarPtr>()->BuildValue();
|
||||
ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
|
||||
return GetValue<int64_t>(value);
|
||||
});
|
||||
dic[ATTR_SHAPE] = sparse_shape_vector;
|
||||
dic[ATTR_DTYPE] = arg->BuildType();
|
||||
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
||||
} else if (abs_base->isa<AbstractCSRTensor>()) {
|
||||
auto arg = dyn_cast<AbstractCSRTensor>(abs_base);
|
||||
auto arg = dyn_cast_ptr<AbstractCSRTensor>(abs_base);
|
||||
AbstractBasePtrList sparse_shape = arg->shape()->elements();
|
||||
ShapeVector sparse_shape_vector;
|
||||
(void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
|
||||
[](const AbstractBasePtr &e) -> int64_t {
|
||||
ValuePtr value = e->cast<AbstractScalarPtr>()->BuildValue();
|
||||
ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
|
||||
return GetValue<int64_t>(value);
|
||||
});
|
||||
dic[ATTR_SHAPE] = sparse_shape_vector;
|
||||
|
@ -753,7 +753,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv
|
|||
} else if (abs_base->isa<AbstractFunction>()) {
|
||||
ConvertAbstractFunctionToPython(abs_base, &dic);
|
||||
} else if (abs_base->isa<AbstractUndetermined>()) {
|
||||
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
|
||||
auto arg = dyn_cast_ptr<AbstractUndetermined>(abs_base);
|
||||
dic[ATTR_SHAPE] = py::none();
|
||||
dic[ATTR_DTYPE] = arg->BuildType();
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
|
@ -804,7 +804,7 @@ void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBa
|
|||
<< "]'s attribute[output_num]:" << output_num << " not matches the infer result "
|
||||
<< res_spec->ToString();
|
||||
} else if (res_spec->isa<AbstractTuple>() &&
|
||||
(res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
|
||||
(res_spec->cast_ptr<AbstractTuple>()->size() != LongToSize(output_num))) {
|
||||
MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
|
||||
<< "]'s attribute[output_num]:" << output_num << " not matches the infer result "
|
||||
<< res_spec->ToString();
|
||||
|
@ -833,7 +833,7 @@ void SetShapeValue(const AbstractBasePtr &tensor, const py::object &output) {
|
|||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Convert shape max value data failed";
|
||||
}
|
||||
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
|
||||
auto abs_tensor = dyn_cast_ptr<abstract::AbstractTensor>(tensor);
|
||||
abs_tensor->set_value_range(min_value, max_value);
|
||||
}
|
||||
if (!output.contains(py::str(ATTR_SHAPE_VALUE))) {
|
||||
|
@ -849,7 +849,7 @@ void SetShapeValue(const AbstractBasePtr &tensor, const py::object &output) {
|
|||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Convert shape value data failed";
|
||||
}
|
||||
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
|
||||
auto abs_tensor = dyn_cast_ptr<abstract::AbstractTensor>(tensor);
|
||||
abs_tensor->set_shape_value(shape_value);
|
||||
}
|
||||
|
||||
|
@ -1023,7 +1023,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &,
|
|||
MS_EXCEPTION_IF_NULL(res_spec);
|
||||
if (res_spec->isa<AbstractTensor>()) {
|
||||
// Replace to tensor constant node in specialize
|
||||
auto res_tensor = res_spec->cast<AbstractTensorPtr>();
|
||||
auto res_tensor = res_spec->cast_ptr<AbstractTensor>();
|
||||
res_tensor->set_value(converted_ret);
|
||||
}
|
||||
return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
|
||||
|
@ -1248,8 +1248,8 @@ EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Ab
|
|||
// Check if all arguments are scalar type.
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->isa<AbstractScalar>()) {
|
||||
auto arg_scalar = dyn_cast<AbstractScalar>(arg);
|
||||
auto arg_value = arg_scalar->GetValueTrack();
|
||||
auto arg_scalar = dyn_cast_ptr<AbstractScalar>(arg);
|
||||
const auto &arg_value = arg_scalar->GetValueTrack();
|
||||
value_list.push_back(arg_value);
|
||||
} else {
|
||||
// Raise TypeError Expected Scalar.
|
||||
|
@ -1335,18 +1335,18 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
|
|||
AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
|
||||
// Create new cnode
|
||||
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
|
||||
auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abstract);
|
||||
auto func_graph_func = dyn_cast_ptr<abstract::FuncGraphAbstractClosure>(abstract);
|
||||
if (func_graph_func != nullptr) {
|
||||
FuncGraphPtr fg = func_graph_func->func_graph();
|
||||
input.push_back(NewValueNode(fg));
|
||||
} else {
|
||||
auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abstract);
|
||||
auto prim_func = dyn_cast_ptr<abstract::PrimitiveAbstractClosure>(abstract);
|
||||
MS_EXCEPTION_IF_NULL(prim_func);
|
||||
PrimitivePtr prim = prim_func->prim();
|
||||
input.push_back(NewValueNode(prim));
|
||||
}
|
||||
|
||||
AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
|
||||
auto conf = dyn_cast_ptr<abstract::AnfNodeConfig>(data_conf);
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
input.push_back(conf->node());
|
||||
MS_EXCEPTION_IF_NULL(old_conf);
|
||||
|
@ -1367,15 +1367,15 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
|
|||
MS_EXCEPTION_IF_NULL(attr_value);
|
||||
ValuePtr item_value = attr_value;
|
||||
if (item_value->isa<StringImm>()) {
|
||||
item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
|
||||
item_value = std::make_shared<parse::Symbol>(item_value->cast_ptr<StringImm>()->value());
|
||||
}
|
||||
if (!item_value->isa<parse::Symbol>()) {
|
||||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||
}
|
||||
|
||||
// item_name to func addr from obj_map
|
||||
parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
|
||||
parse::NameSpacePtr name_space = data_value->cast<parse::NameSpacePtr>();
|
||||
auto symbol = item_value->cast<parse::SymbolPtr>();
|
||||
auto name_space = data_value->cast<parse::NameSpacePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
auto out_node = out_conf->node();
|
||||
FuncGraphPtr func_graph = out_node->func_graph();
|
||||
|
@ -1429,12 +1429,12 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &,
|
|||
if (!item_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
|
||||
}
|
||||
std::string item_name = item_value->cast<StringImmPtr>()->value();
|
||||
const auto &item_name = item_value->cast_ptr<StringImm>()->value();
|
||||
// Get ms_class object.
|
||||
if (!data_value->isa<parse::MsClassObject>()) {
|
||||
MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString();
|
||||
}
|
||||
auto ms_class = data_value->cast<parse::MsClassObjectPtr>();
|
||||
auto ms_class = data_value->cast_ptr<parse::MsClassObject>();
|
||||
MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << ".";
|
||||
|
||||
// Get the attr/method of ms_class object.
|
||||
|
@ -1461,7 +1461,7 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AnalysisEnginePtr &engi
|
|||
if (python_obj == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto wrapper_obj = dyn_cast<parse::PyObjectWrapper>(python_obj);
|
||||
auto wrapper_obj = dyn_cast_ptr<parse::PyObjectWrapper>(python_obj);
|
||||
MS_EXCEPTION_IF_NULL(wrapper_obj);
|
||||
py::object real_python_obj = wrapper_obj->obj();
|
||||
MS_LOG(DEBUG) << "item_value: " << item_value->ToString() << ", func_value: " << func_value->ToString()
|
||||
|
@ -1485,7 +1485,7 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
|
|||
if (!item_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
|
||||
}
|
||||
std::string item_name = item_value->cast<StringImmPtr>()->value();
|
||||
std::string item_name = item_value->cast_ptr<StringImm>()->value();
|
||||
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
|
||||
Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
|
||||
if (require.empty()) {
|
||||
|
@ -1520,7 +1520,7 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
|
|||
if (!abs->isa<abstract::PartialAbstractClosure>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto partial_abs = abs->cast<abstract::PartialAbstractClosurePtr>();
|
||||
auto partial_abs = abs->cast_ptr<abstract::PartialAbstractClosure>();
|
||||
auto fn = partial_abs->fn();
|
||||
if (!fn->isa<abstract::PrimitiveAbstractClosure>()) {
|
||||
return nullptr;
|
||||
|
@ -1568,7 +1568,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
if (class_value != nullptr) {
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
|
||||
}
|
||||
auto data_func_graph = dyn_cast<FuncGraphAbstractClosure>(data_args);
|
||||
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
|
||||
if (data_func_graph != nullptr) {
|
||||
auto res =
|
||||
GetEvaluatedValueForCellAttrOrMethod(engine, item_value, data_func_graph->func_graph(), data_conf, out_conf);
|
||||
|
@ -1642,7 +1642,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
if (args_conf_list.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
|
||||
}
|
||||
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
|
||||
auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
|
||||
MS_EXCEPTION_IF_NULL(node_conf);
|
||||
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
|
||||
AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
|
||||
|
@ -1662,7 +1662,7 @@ static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager,
const FuncGraphPtr &root_g = root_g_set.back();
for (auto &param_node : root_g->parameters()) {
auto param = param_node->cast<ParameterPtr>();
if (param && name == param->name()) {
if (param != nullptr && param->name() == name) {
return param;
}
}
@ -1680,7 +1680,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
return nullptr;
|
||||
}
|
||||
static TypePtr type = std::make_shared<SymbolicKeyType>();
|
||||
auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
|
||||
auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
|
||||
if (node_conf == nullptr) {
|
||||
MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
|
||||
return nullptr;
|
||||
|
@ -1688,7 +1688,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
|
||||
AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
||||
auto ref_abs = abs->cast_ptr<AbstractRefTensor>();
|
||||
if (ref_abs == nullptr) {
|
||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
|
||||
return nullptr;
|
||||
|
@ -1701,11 +1701,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
// Only after that funcgrpah is inlined, the RefToEmbed CNode should be evaluated to specific SymbolicKey.
|
||||
bool ifEmbedIsWeight = false;
|
||||
if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
|
||||
auto param = node_conf->node()->cast<ParameterPtr>();
|
||||
auto param = node_conf->node()->cast_ptr<Parameter>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
ifEmbedIsWeight = param->has_default();
|
||||
}
|
||||
auto refkey = ref_abs->ref_key_value()->cast<StringImmPtr>();
|
||||
auto refkey = ref_abs->ref_key_value()->cast_ptr<StringImm>();
|
||||
if (refkey == nullptr || !ifEmbedIsWeight) {
|
||||
auto ret = std::make_shared<AbstractScalar>(type);
|
||||
auto ref_value = ref_abs->ref();
|
||||
|
@ -1788,12 +1788,12 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
|
|||
|
||||
bool IsContainUndetermined(const AbstractBasePtr &arg) {
|
||||
if (arg->isa<AbstractSequence>()) {
|
||||
auto seq_arg = arg->cast<AbstractSequencePtr>();
|
||||
auto seq_arg = arg->cast_ptr<AbstractSequence>();
|
||||
return std::any_of(seq_arg->elements().begin(), seq_arg->elements().end(), IsContainUndetermined);
|
||||
}
|
||||
|
||||
if (arg->isa<AbstractKeywordArg>()) {
|
||||
auto kw_arg = arg->cast<AbstractKeywordArgPtr>();
|
||||
auto kw_arg = arg->cast_ptr<AbstractKeywordArg>();
|
||||
return IsContainUndetermined(kw_arg->get_arg());
|
||||
}
|
||||
|
||||
|
@ -1824,7 +1824,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
|
||||
ValuePtr value_track = arg_class_type->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
parse::PyObjectWrapperPtr type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
|
||||
auto type_obj = dyn_cast_ptr<parse::PyObjectWrapper>(value_track);
|
||||
if (type_obj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
|
||||
}
|
||||
|
@ -1911,7 +1911,7 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
}
|
||||
ValuePtr value_track = arg_cls->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
parse::MsClassObjectPtr ms_class = dyn_cast<parse::MsClassObject>(value_track);
|
||||
auto ms_class = dyn_cast_ptr<parse::MsClassObject>(value_track);
|
||||
if (ms_class == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject.";
|
||||
}
|
||||
|
@ -1933,7 +1933,7 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
// Replace net with net.__call__
|
||||
AnfNodePtr old_node = out_conf->node();
|
||||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
CNodePtr old_cnode = dyn_cast<CNode>(old_node);
|
||||
auto old_cnode = dyn_cast_ptr<CNode>(old_node);
|
||||
MS_EXCEPTION_IF_NULL(old_cnode);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(call_func_graph)};
|
||||
for (size_t i = 1; i < old_cnode->size(); i++) {
|
||||
|
@ -1966,7 +1966,7 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
|
||||
std::shared_ptr<parse::Script> script_obj = dyn_cast<parse::Script>(value_track);
|
||||
auto script_obj = dyn_cast_ptr<parse::Script>(value_track);
|
||||
if (script_obj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
|
||||
}
|
||||
|
@ -2027,12 +2027,12 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
const auto &element_name = element.first;
|
||||
const auto &element_abs = element.second;
|
||||
if (element_abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
const auto &element_abs_fn = element_abs->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||
const auto &fg = element_abs_fn->func_graph();
|
||||
auto element_abs_fn = element_abs->cast_ptr<abstract::FuncGraphAbstractClosure>();
|
||||
auto fg = element_abs_fn->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto wrapper_obj = fg->python_obj();
|
||||
if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
|
||||
auto fn_py_obj = wrapper_obj->cast<parse::PyObjectWrapperPtr>()->obj();
|
||||
auto fn_py_obj = wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj();
|
||||
(*global_params_dict)[py::str(element_name)] = fn_py_obj;
|
||||
MS_LOG(DEBUG) << "Found global python function object for " << element_name << ", add it to global dict.";
|
||||
}
|
||||
|
@ -2138,10 +2138,10 @@ class PartialEvaluator : public Evaluator {
|
|||
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
|
||||
// Sometimes, node[0] in out_conf becomes phi0;
|
||||
if (func->isa<PrimitiveAbstractClosure>()) {
|
||||
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
|
||||
auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func);
|
||||
MS_EXCEPTION_IF_NULL(prim_func->prim());
|
||||
if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
|
||||
prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
|
||||
auto do_signature_prim = dyn_cast_ptr<prim::DoSignaturePrimitive>(prim_func->prim());
|
||||
return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
|
||||
}
|
||||
}
|
||||
|
@ -2179,7 +2179,7 @@ class PartialEvaluator : public Evaluator {
|
|||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
auto cnode = out_conf->node()->cast<CNodePtr>();
|
||||
auto cnode = out_conf->node()->cast_ptr<CNode>();
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cnode is nullptr";
|
||||
}
|
||||
|
@ -2234,7 +2234,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
std::string exception_string;
|
||||
// Processed in units of nodes. Raise ValueError(xxxx)
|
||||
size_t index_begin = 2;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
bool need_out_symbol = inputs.size() > 3;
|
||||
|
@ -2268,17 +2268,17 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
bool CheckNeedSymbol(const AnfNodePtr &, const AbstractBasePtr &abs) const {
|
||||
bool need_symbol = false;
|
||||
if (abs->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
|
||||
auto scalar_value = scalar->BuildValue();
|
||||
if (scalar_value->isa<StringImm>()) {
|
||||
need_symbol = true;
|
||||
}
|
||||
} else if (abs->isa<abstract::AbstractSequence>()) {
|
||||
auto abs_list = abs->cast<abstract::AbstractSequencePtr>();
|
||||
auto abs_list = abs->cast_ptr<abstract::AbstractSequence>();
|
||||
const auto &elements = abs_list->elements();
|
||||
for (auto &element : elements) {
|
||||
if (element->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = element->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar = element->cast_ptr<abstract::AbstractScalar>();
|
||||
auto scalar_value = scalar->BuildValue();
|
||||
if (scalar_value->isa<StringImm>()) {
|
||||
need_symbol = true;
|
||||
|
@ -2310,7 +2310,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
std::string GetTupleString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
|
||||
std::string exception_str;
|
||||
// Process raise ValueError("str")
|
||||
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
|
||||
auto arg_tuple = arg->cast_ptr<abstract::AbstractTuple>();
|
||||
const auto &arg_tuple_elements = arg_tuple->elements();
|
||||
if (arg_tuple_elements.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty.";
|
||||
|
@ -2334,7 +2334,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
std::string GetListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
|
||||
std::string exception_str;
|
||||
// Process raise ValueError("str")
|
||||
auto arg_list = arg->cast<abstract::AbstractListPtr>();
|
||||
auto arg_list = arg->cast_ptr<abstract::AbstractList>();
|
||||
const auto &arg_list_elements = arg_list->elements();
|
||||
if (arg_list_elements.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The arg_list_elements can't be empty.";
|
||||
|
@ -2358,7 +2358,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
std::string GetExceptionType(const AbstractBasePtr &abs) const {
|
||||
std::string str;
|
||||
if (abs->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
|
||||
auto scalar_value = scalar->BuildValue();
|
||||
if (scalar_value->isa<StringImm>()) {
|
||||
str = GetValue<std::string>(scalar_value);
|
||||
|
@ -2491,8 +2491,8 @@ namespace {
|
|||
bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(model);
|
||||
auto x_tuple = dyn_cast<AbstractTuple>(x);
|
||||
auto model_tuple = dyn_cast<Tuple>(model);
|
||||
auto x_tuple = dyn_cast_ptr<AbstractTuple>(x);
|
||||
auto model_tuple = dyn_cast_ptr<Tuple>(model);
|
||||
|
||||
if (x_tuple == nullptr || model_tuple == nullptr) {
|
||||
return false;
|
||||
|
@ -2518,8 +2518,8 @@ bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
|
|||
bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(model);
|
||||
auto x_tensor = dyn_cast<AbstractTensor>(x);
|
||||
auto model_tensor = dyn_cast<TensorType>(model);
|
||||
auto x_tensor = dyn_cast_ptr<AbstractTensor>(x);
|
||||
auto model_tensor = dyn_cast_ptr<TensorType>(model);
|
||||
|
||||
if (x_tensor == nullptr || model_tensor == nullptr) {
|
||||
return false;
|
||||
|
@ -2535,8 +2535,8 @@ bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
|
|||
bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(model);
|
||||
auto x_list = dyn_cast<AbstractList>(x);
|
||||
auto model_list = dyn_cast<List>(model);
|
||||
auto x_list = dyn_cast_ptr<AbstractList>(x);
|
||||
auto model_list = dyn_cast_ptr<List>(model);
|
||||
|
||||
if (x_list == nullptr || model_list == nullptr) {
|
||||
return false;
|
||||
|
@ -2563,10 +2563,10 @@ bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(model);
if (dyn_cast<AbstractScalar>(x) == nullptr) {
if (dyn_cast_ptr<AbstractScalar>(x) == nullptr) {
return false;
}
TypePtr x_type = x->GetTypeTrack();
auto &x_type = x->GetTypeTrack();
return IsSubType(x_type, model);
}
} // namespace
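The IsSubtype* helpers only need a yes/no answer, so the pointer returned by dyn_cast_ptr is used purely as a type test and never retained. A tiny sketch of that idiom (standard C++, illustrative types):

    #include <memory>

    struct AbstractBase { virtual ~AbstractBase() = default; };
    struct AbstractScalar : AbstractBase {};

    // The cast result is only compared against nullptr; nothing is kept,
    // so a non-owning raw pointer is all that is needed.
    bool IsScalarLike(const std::shared_ptr<AbstractBase> &x) {
      return dynamic_cast<const AbstractScalar *>(x.get()) != nullptr;
    }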
@ -63,7 +63,7 @@ bool CanSpecializeValueNode(const AnfNodePtr &node) {
|
|||
}
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
if (node->abstract() != nullptr) {
|
||||
auto abs_func = node->abstract()->cast<FuncGraphAbstractClosurePtr>();
|
||||
auto abs_func = node->abstract()->cast_ptr<FuncGraphAbstractClosure>();
|
||||
// If this funcgraph had specialized in ProcessCNode of FirstPass,
|
||||
// then ignore it.
|
||||
if (abs_func != nullptr && abs_func->specialized()) {
|
||||
|
@ -110,12 +110,12 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
}

// Handle MakeTuple/MakeList CNode.
auto cnode = dyn_cast<CNode>(node);
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode != nullptr) {
if (pos + 1 >= cnode->inputs().size()) {
continue;
}
auto input_value = GetValueNode<StringImmPtr>(cnode->input(pos + 1));
auto input_value = GetValuePtr<StringImm>(cnode->input(pos + 1));
if (input_value == nullptr || input_value->value() != kDeadNodeName) {
continue;
}
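GetValuePtr<T> plays the same role for value nodes that cast_ptr plays for node casts: it is assumed to return a non-owning T* for the value held by a ValueNode (or nullptr when the node is not a ValueNode holding a T), instead of materialising another shared_ptr owner the way GetValueNode<TPtr> does. A stand-alone sketch of that assumed shape, with illustrative stand-in types:

    #include <memory>
    #include <string>

    struct Value { virtual ~Value() = default; };
    struct StringImm : Value { std::string text; };
    struct AnfNode { virtual ~AnfNode() = default; };
    struct ValueNode : AnfNode { std::shared_ptr<Value> value; };

    // Assumed behaviour: peek at the held value without adding an owner.
    template <typename T>
    const T *GetValuePtrSketch(const std::shared_ptr<AnfNode> &node) {
      auto *vnode = dynamic_cast<const ValueNode *>(node.get());
      if (vnode == nullptr) {
        return nullptr;
      }
      return dynamic_cast<const T *>(vnode->value.get());
    }

Call sites then just compare against a sentinel such as kDeadNodeName, as in the hunk above, without any reference-count traffic.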
@ -129,7 +129,7 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
|
|||
|
||||
// Change the abstract.
|
||||
(*flags)[pos] = false; // Change the use flag as 0.
|
||||
auto sequence_abs = dyn_cast<AbstractSequence>(node->abstract());
|
||||
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
|
||||
if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
|
||||
MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
|
||||
<< ", node: " << node->DebugString(recursive_level);
|
||||
|
@ -138,13 +138,13 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
|
|||
}
|
||||
// Handle ValueTuple/ValueList.
|
||||
if (IsValueNode<ValueTuple>(node) || IsValueNode<ValueList>(node)) {
|
||||
auto sequence_value = GetValueNode<ValueSequencePtr>(node);
|
||||
auto sequence_value = GetValuePtr<ValueSequence>(node);
|
||||
MS_EXCEPTION_IF_NULL(sequence_value);
|
||||
if (pos >= sequence_value->value().size()) {
|
||||
continue;
|
||||
}
|
||||
ValuePtr element_value = sequence_value->value()[pos];
|
||||
auto element_str_value = element_value->cast<StringImmPtr>();
|
||||
auto element_str_value = element_value->cast_ptr<StringImm>();
|
||||
if (element_str_value == nullptr || element_str_value->value() != kDeadNodeName) {
|
||||
continue;
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
|
|||
|
||||
// Change the abstract.
|
||||
(*flags)[pos] = false; // Change the use flag as 0.
|
||||
auto sequence_abs = dyn_cast<AbstractSequence>(node->abstract());
|
||||
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
|
||||
if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
|
||||
constexpr int recursive_level = 2;
|
||||
MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
|
||||
|
@ -250,7 +250,7 @@ AbstractBasePtr ProgramSpecializer::SpecializeAbstractFuncRecursively(const Abst
|
|||
auto build_new_abs = [this, &func_atoms](const AbstractFuncAtomPtr &poss) {
|
||||
auto resolved_atom = poss;
|
||||
if (poss->isa<AsyncAbstractFuncAtom>()) {
|
||||
const auto &async_abs_func = poss->cast<AsyncAbstractFuncAtomPtr>();
|
||||
auto async_abs_func = poss->cast_ptr<AsyncAbstractFuncAtom>();
|
||||
const auto &resolved_func = async_abs_func->GetUnique();
|
||||
resolved_atom = resolved_func->cast<AbstractFuncAtomPtr>();
|
||||
MS_EXCEPTION_IF_NULL(resolved_atom);
|
||||
|
@ -306,7 +306,7 @@ void ProgramSpecializer::SpecializeCNodeInput0FuncGraph() {
|
|||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto &input0 = node->cast<CNodePtr>()->input(0);
|
||||
auto &input0 = node->cast_ptr<CNode>()->input(0);
|
||||
MS_EXCEPTION_IF_NULL(input0);
|
||||
if (IsValueNode<FuncGraph>(input0)) {
|
||||
continue;
|
||||
|
@ -383,7 +383,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
|
|||
|
||||
void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto c_node = node->cast<CNodePtr>();
|
||||
auto c_node = node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(c_node);
|
||||
auto inputs = c_node->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
|
@ -403,7 +403,7 @@ void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const An
|
|||
return new_inp;
|
||||
});
|
||||
|
||||
auto c_new_node = new_node->cast<CNodePtr>();
|
||||
auto c_new_node = new_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(c_new_node);
|
||||
c_new_node->set_inputs(new_inputs);
|
||||
}
|
||||
|
@ -534,7 +534,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
|
|||
if (new_node == old_node) {
|
||||
return;
|
||||
}
|
||||
AbstractSequencePtr old_sequence_abs = dyn_cast<AbstractSequence>(old_abs);
|
||||
auto old_sequence_abs = dyn_cast_ptr<AbstractSequence>(old_abs);
|
||||
if (old_sequence_abs == nullptr || old_sequence_abs->sequence_nodes() == nullptr ||
|
||||
old_sequence_abs->sequence_nodes()->empty()) {
|
||||
MS_LOG(DEBUG) << "No sequence node in old abs, " << old_node->DebugString() << " --> " << new_node->DebugString();
|
||||
|
@ -590,7 +590,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
|
|||
}
|
||||
MS_LOG(ERROR) << "New abstract, " << old_node->DebugString() << " --> " << new_node->DebugString()
|
||||
<< ", elements_use_flags: " << (*flags);
|
||||
AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_abs);
|
||||
auto new_sequence_abs = dyn_cast_ptr<AbstractSequence>(new_abs);
|
||||
if (new_sequence_abs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString();
|
||||
}
|
||||
|
@ -608,7 +608,7 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
|
|||
template <typename T>
|
||||
void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer *const specializer) {
|
||||
const auto &old_input = cnode->input(index);
|
||||
auto sequence_value = GetValueNode<std::shared_ptr<T>>(old_input);
|
||||
auto sequence_value = GetValuePtr<T>(old_input);
|
||||
if (sequence_value == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
@ -625,7 +625,7 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial
|
|||
}
|
||||
for (size_t i = 0; i < sequence_value_size; ++i) {
|
||||
ValuePtr old_sequence_value = sequence_value->value()[i];
|
||||
auto old_sequence_str_value = old_sequence_value->cast<StringImmPtr>();
|
||||
auto old_sequence_str_value = old_sequence_value->cast_ptr<StringImm>();
|
||||
if (!(*flags)[i]) {
|
||||
auto zero = MakeValue(0);
|
||||
(void)elements.emplace_back(zero);
|
||||
|
@ -643,7 +643,7 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial
|
|||
auto new_sequence_value = std::make_shared<T>(elements);
|
||||
auto new_input = NewValueNode(new_sequence_value);
|
||||
auto new_input_abs = new_sequence_value->ToAbstract();
|
||||
AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_input_abs);
|
||||
auto new_sequence_abs = dyn_cast<AbstractSequence>(new_input_abs);
|
||||
MS_EXCEPTION_IF_NULL(new_sequence_abs);
|
||||
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
|
||||
(void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_input));
|
||||
|
@ -704,7 +704,7 @@ void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) co
|
|||
(void)inputs.emplace_back(cnode->input(0));
|
||||
for (size_t i = 0; i < (*flags).size(); ++i) {
|
||||
auto old_input = cnode->input(i + 1);
|
||||
auto old_input_value = GetValueNode<StringImmPtr>(old_input);
|
||||
auto old_input_value = GetValuePtr<StringImm>(old_input);
|
||||
if (!(*flags)[i]) {
|
||||
auto zero_value = NewValueNode(MakeValue(0));
|
||||
zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
|
||||
|
@ -754,7 +754,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
MS_LOG(EXCEPTION) << "Fail to get abstract value with " << conf->ToString() << ", for " << new_node->DebugString();
|
||||
}
|
||||
if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
|
||||
auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
|
||||
auto partial_abstract = dyn_cast_ptr<PartialAbstractClosure>(new_node->abstract());
|
||||
if (partial_abstract->node() == node) {
|
||||
partial_abstract->set_node(new_node);
|
||||
}
|
||||
|
@ -768,8 +768,8 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
}
|
||||
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
|
||||
auto attrs = conf->ObtainEvalResult()->attribute();
|
||||
auto c_old = node->cast<CNodePtr>();
|
||||
auto c_new = new_node->cast<CNodePtr>();
|
||||
auto c_old = node->cast_ptr<CNode>();
|
||||
auto c_new = new_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(c_new);
|
||||
auto new_inputs = c_new->inputs();
|
||||
auto old_inputs = c_old->inputs();
|
||||
|
@ -786,7 +786,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
bool ignore_build_value = false;
|
||||
AnfNodePtr replace_node = nullptr;
|
||||
if (specializer_->engine()->check_isolated_side_effect()) {
|
||||
auto cnode_input = dyn_cast<CNode>(node_input);
|
||||
auto cnode_input = dyn_cast_ptr<CNode>(node_input);
|
||||
ignore_build_value = (cnode_input != nullptr && cnode_input->has_isolated_side_effect_node());
|
||||
if (ignore_build_value) {
|
||||
MS_LOG(INFO) << "Don't build value node for CNode which contains isolated side-effect inputs, node: "
|
||||
|
@ -878,7 +878,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con
|
|||
const AbstractBasePtr &abs, const AbstractBasePtrList &argvals) {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
|
||||
auto real_a = dyn_cast_ptr<AbstractFunction>(abs);
|
||||
MS_EXCEPTION_IF_NULL(real_a);
|
||||
|
||||
AbstractFunctionPtr func_abs = real_a->GetUnique();
|
||||
|
@ -905,7 +905,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con
|
|||
// Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
|
||||
MS_EXCEPTION_IF_NULL(func_abs);
|
||||
if (func_abs->isa<MetaFuncGraphAbstractClosure>()) {
|
||||
auto specialized_fg = GetValueNode<FuncGraphPtr>(repl);
|
||||
auto specialized_fg = GetValuePtr<FuncGraph>(repl);
|
||||
if (specialized_fg != nullptr && (argvals.size() > 1) && argvals.back() != nullptr &&
|
||||
argvals.back()->isa<AbstractUMonad>()) {
|
||||
specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
|
||||
|
@ -924,7 +924,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode
|
|||
MS_EXCEPTION_IF_NULL(errcode);
|
||||
*errcode = kSpecializeSuccess;
|
||||
|
||||
auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func_abs);
|
||||
auto real_func = dyn_cast_ptr<TypedPrimitiveAbstractClosure>(func_abs);
|
||||
if (real_func != nullptr) {
|
||||
return BuildValueNode(real_func->prim(), abs);
|
||||
}
|
||||
|
@ -942,7 +942,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode
|
|||
argvals = result.first;
|
||||
AbstractBasePtr unique_output = result.second;
|
||||
|
||||
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func_abs);
|
||||
auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func_abs);
|
||||
if (prim_func != nullptr) {
|
||||
auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
|
||||
return BuildValueNode(prim_func->prim(), type_func);
|
||||
|
@ -993,7 +993,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &c
|
|||
AbstractBasePtrList args;
|
||||
auto backed_fnval = fnval;
|
||||
if (fnval->isa<PartialAbstractClosure>()) {
|
||||
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
|
||||
auto partial_closure = dyn_cast_ptr<PartialAbstractClosure>(fnval);
|
||||
backed_fnval = partial_closure->fn();
|
||||
args = partial_closure->args();
|
||||
}
|
||||
|
@ -1133,7 +1133,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
|||
// CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
|
||||
const size_t arg_start_index = 2;
|
||||
while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
|
||||
std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
|
||||
std::vector<AnfNodePtr> inputs = func->cast_ptr<CNode>()->inputs();
|
||||
// First element is partial, second is func so arg is start from 2
|
||||
(void)args.insert(args.cbegin(), inputs.cbegin() + SizeToInt(arg_start_index), inputs.cend());
|
||||
func = inputs[1];
|
||||
|
@ -1222,10 +1222,10 @@ bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argv
|
|||
return true;
|
||||
}
|
||||
if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
|
||||
auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
|
||||
auto meta_func_graph_wrapper = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
|
||||
auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
|
||||
if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
|
||||
auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
|
||||
auto do_signature = dyn_cast_ptr<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
|
||||
if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
|
||||
MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
|
||||
return true;
|
||||
|
@ -1331,7 +1331,7 @@ AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNode
|
|||
const AbstractFunctionPtr &abs) {
|
||||
ValuePtr value = nullptr;
|
||||
if (abs->isa<PrimitiveAbstractClosure>()) {
|
||||
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
|
||||
auto real_fn = dyn_cast_ptr<PrimitiveAbstractClosure>(abs);
|
||||
// For primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
|
||||
if (attrs != nullptr) {
|
||||
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
|
||||
|
@ -1339,20 +1339,20 @@ AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNode
|
|||
value = real_fn->prim();
|
||||
}
|
||||
} else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
|
||||
auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
|
||||
auto real_fn = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(abs);
|
||||
value = real_fn->meta_func_graph();
|
||||
} else if (abs->isa<FuncGraphAbstractClosure>()) {
|
||||
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
|
||||
auto real_fn = dyn_cast_ptr<FuncGraphAbstractClosure>(abs);
|
||||
value = real_fn->func_graph();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
|
||||
(IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
|
||||
if (!value->isa<FuncGraph>() || value->cast_ptr<FuncGraph>()->parent() == nullptr ||
|
||||
(IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast_ptr<FuncGraph>()->parent()))) {
|
||||
return BuildValueNode(value, ival);
|
||||
} else if (IsPrimitiveCNode(cnode, prim::kPrimJ) && origin_node->isa<Parameter>() &&
|
||||
!value->cast<FuncGraphPtr>()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
|
||||
!value->cast_ptr<FuncGraph>()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
|
||||
// Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph.
|
||||
MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString();
|
||||
return BuildValueNode(value, ival);
|
||||
|
|
|
@ -43,7 +43,7 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr
|
|||
MS_EXCEPTION_IF_NULL(graph_func);
|
||||
MS_EXCEPTION_IF_NULL(fg_evaluator);
|
||||
AnalysisContextPtr parent_context = nullptr;
|
||||
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
|
||||
auto func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(graph_func);
|
||||
if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
|
||||
parent_context = func_graph_abs->context();
|
||||
} else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure.
|
||||
|
@ -164,7 +164,7 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
|
|||
<< ", current_context_: " << current_context_->ToString();
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
|
||||
EvalResultPtr node_eval_result = nullptr;
|
||||
const auto &fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator());
|
||||
auto fg_evaluator = dyn_cast_ptr<BaseFuncGraphEvaluator>(evaluator());
|
||||
if (fg_evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator()->ToString();
|
||||
}
|
||||
|
@ -194,7 +194,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
|
|||
// Check if child func graph contains isolated side-effect.
|
||||
if (engine->check_isolated_side_effect()) {
|
||||
if (last_stack_frame->func_graph()->has_isolated_side_effect_node()) {
|
||||
auto cnode = dyn_cast<CNode>(CurrentNode());
|
||||
auto cnode = dyn_cast_ptr<CNode>(CurrentNode());
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_has_isolated_side_effect_node(true);
|
||||
cnode->func_graph()->set_has_isolated_side_effect_node(true);
|
||||
|
@ -205,7 +205,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
|
|||
auto evaluator = last_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
evaluator->evaluator_cache_mgr()->SetValue(last_stack_frame->args_abs_list(), result);
|
||||
const auto &fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);
|
||||
auto fg_evaluator = dyn_cast_ptr<BaseFuncGraphEvaluator>(evaluator);
|
||||
if (fg_evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator->ToString();
|
||||
}
|
||||
|
|
|
@ -315,7 +315,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
|
||||
auto func = dyn_cast_ptr<AbstractFunction>(possible_func);
|
||||
if (func == nullptr) {
|
||||
CheckInterpretedObject(possible_func);
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
|
||||
|
@ -357,11 +357,11 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
|
||||
// Check if func graph contains isolated side-effect, and sync.
|
||||
if (check_isolated_side_effect()) {
|
||||
FuncGraphAbstractClosurePtr func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(func);
|
||||
auto func_graph_abs = mindspore::cast<FuncGraphAbstractClosure>(func);
|
||||
if (func_graph_abs != nullptr) {
|
||||
contains_isolated_side_effect |= func_graph_abs->func_graph()->has_isolated_side_effect_node();
|
||||
}
|
||||
MetaFuncGraphAbstractClosurePtr meta_func_graph_abs = dyn_cast<MetaFuncGraphAbstractClosure>(func);
|
||||
auto meta_func_graph_abs = mindspore::cast<MetaFuncGraphAbstractClosure>(func);
|
||||
if (meta_func_graph_abs != nullptr) {
|
||||
contains_isolated_side_effect |= meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node();
|
||||
}
|
||||
|
@ -652,7 +652,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c
|
|||
// again, so update the config_map with new_conf;
|
||||
anfnode_config_map_[orig_conf] = new_conf;
|
||||
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->ToString() << ", to new_conf: " << new_conf->ToString();
|
||||
auto old_cnode = orig_conf->node()->cast<CNodePtr>();
|
||||
auto old_cnode = orig_conf->node()->cast_ptr<CNode>();
|
||||
auto new_cnode = new_conf->node()->cast<CNodePtr>();
|
||||
if (old_cnode != nullptr && new_cnode != nullptr) {
|
||||
if (old_cnode->func_graph() == new_cnode->func_graph()) {
|
||||
|
@ -757,20 +757,20 @@ std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBa
|
|||
<< spec->ToString() << ", and that of the previous branch is " << last_spec->ToString() << ".\n"
|
||||
<< "The node is " << node->DebugString(recursive_level);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>()->input(0);
|
||||
auto cnode = node->cast_ptr<CNode>()->input(0);
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
|
||||
// {prim::kPrimSwitch, cond, true_branch, false_branch}
|
||||
constexpr int true_index = 2;
|
||||
constexpr int false_index = 3;
|
||||
auto inputs = cnode->cast<CNodePtr>()->inputs();
|
||||
const auto &inputs = cnode->cast_ptr<CNode>()->inputs();
|
||||
buffer << ", true branch: " << inputs.at(true_index)->ToString()
|
||||
<< ", false branch: " << inputs.at(false_index)->ToString();
|
||||
} else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
|
||||
// {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}}
|
||||
constexpr int branch_index = 2;
|
||||
auto tuple_node = cnode->cast<CNodePtr>()->input(branch_index);
|
||||
const auto &tuple_node = cnode->cast_ptr<CNode>()->input(branch_index);
|
||||
if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
|
||||
auto tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
|
||||
const auto &tuple_inputs = tuple_node->cast_ptr<CNode>()->inputs();
|
||||
for (size_t i = 1; i < tuple_inputs.size(); i++) {
|
||||
buffer << ", branch" << i << ": " << tuple_inputs.at(i);
|
||||
}
|
||||
|
@ -827,7 +827,7 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
|||
return true;
|
||||
}
|
||||
if (abstract->isa<AbstractSequence>()) {
|
||||
auto elements = abstract->cast<AbstractSequencePtr>()->elements();
|
||||
auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
|
||||
if (std::any_of(elements.begin(), elements.end(),
|
||||
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
|
||||
return true;
|
||||
|
@ -888,7 +888,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
|
|||
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
|
||||
const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
|
||||
const std::vector<std::size_t> &index) {
|
||||
const auto sequence_abs = dyn_cast<AbstractSequence>(orig_abs);
|
||||
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
|
||||
if (sequence_abs != nullptr) {
|
||||
const auto &orig_elements = sequence_abs->elements();
|
||||
AbstractBasePtrList new_elements;
|
||||
|
@ -1153,7 +1153,7 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont
|
|||
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
|
||||
if (enable_eliminate_unused_element && value->isa<ValueSequence>()) {
|
||||
auto abs = value->ToAbstract();
|
||||
auto sequence_abs = dyn_cast<AbstractSequence>(abs);
|
||||
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(abs);
|
||||
MS_EXCEPTION_IF_NULL(sequence_abs);
|
||||
if (anf_node != nullptr) {
|
||||
SetSequenceNodeElementsUseFlags(anf_node, std::make_shared<std::vector<bool>>(sequence_abs->elements().size()));
|
||||
|
@ -1179,12 +1179,12 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi
|
|||
if (evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
|
||||
}
|
||||
auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
|
||||
auto trivial_evaluator = dyn_cast_ptr<TrivialPrimEvaluator>(evaluator);
|
||||
if (trivial_evaluator != nullptr) {
|
||||
return trivial_evaluator->EvalPrim(nullptr, arg_specs);
|
||||
}
|
||||
// Support MakeTuple/MakeList ops in PyNative mode.
|
||||
auto transition_evaluator = dyn_cast<TransitionPrimEvaluator>(evaluator);
|
||||
auto transition_evaluator = dyn_cast_ptr<TransitionPrimEvaluator>(evaluator);
|
||||
if (transition_evaluator != nullptr &&
|
||||
(transition_evaluator->isa<MakeTupleEvaluator>() || transition_evaluator->isa<MakeListEvaluator>())) {
|
||||
return transition_evaluator->EvalPrim(nullptr, arg_specs, nullptr, nullptr);
|
||||
|
|
|
@ -146,15 +146,15 @@ void ValidateValueNode(const AnfNodePtr &node) {
}

void CheckValueTuple(const AnfNodePtr &node) {
const auto &value_node = node->cast<ValueNodePtr>();
auto value_node = node->cast_ptr<ValueNode>();
MS_EXCEPTION_IF_NULL(value_node);
const auto value = value_node->value();
const auto &value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
const auto value_tuple = value->cast<ValueTuplePtr>();
auto value_tuple = value->cast_ptr<ValueTuple>();
MS_EXCEPTION_IF_NULL(value_tuple);
const auto tuple_values = value_tuple->value();
const auto &tuple_values = value_tuple->value();
for (size_t i = 0; i < tuple_values.size(); ++i) {
const auto input_node = NewValueNode(tuple_values[i]);
auto input_node = NewValueNode(tuple_values[i]);
ValidateOperation(input_node);
ValidateValueNode(input_node);
}
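Besides the pointer casts, this hunk also turns several by-value bindings (const auto value = ...) into references (const auto &value = ...) so that returned containers and shared_ptrs are not copied on every call. A small illustration of the difference, using a made-up accessor in place of value_tuple->value():

    #include <memory>
    #include <vector>

    struct ValueTuple {
      // Hypothetical accessor mirroring value_tuple->value() in the diff.
      const std::vector<std::shared_ptr<int>> &value() const { return values_; }
      std::vector<std::shared_ptr<int>> values_;
    };

    int CountElements(const ValueTuple &tuple) {
      const auto copy = tuple.value();   // copies the vector and bumps every element's refcount
      const auto &view = tuple.value();  // binds to the stored vector; no copies at all
      return static_cast<int>(copy.size() + view.size()) / 2;
    }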
@ -167,7 +167,7 @@ void Validate(const FuncGraphPtr &func_graph) {
|
|||
for (auto node : all_nodes) {
|
||||
TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
|
||||
while (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
node = node->cast<CNodePtr>()->input(1);
|
||||
node = node->cast_ptr<CNode>()->input(1);
|
||||
}
|
||||
if (IsValueNode<ValueTuple>(node)) {
|
||||
CheckValueTuple(node);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@@ -41,7 +41,7 @@ AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
}
return std::make_shared<AbstractFuncUnion>(this_func, other);
}
auto other_union = dyn_cast<AbstractFuncUnion>(other);
auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
MS_EXCEPTION_IF_NULL(other_union);
if (other_union->IsSuperSet(this_func)) {
return other;

@@ -122,7 +122,7 @@ AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
}
return std::make_shared<AbstractFuncUnion>(this_func, other);
}
auto other_union = dyn_cast<AbstractFuncUnion>(other);
auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
MS_EXCEPTION_IF_NULL(other_union);
if (other_union->IsSuperSet(this_func)) {
return other;

@@ -173,14 +173,14 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
if (*this == *other) {
return shared_from_base<AbstractBase>();
}
auto type_self = GetTypeTrack();
auto type_other = other->GetTypeTrack();
const auto &type_self = GetTypeTrack();
const auto &type_other = other->GetTypeTrack();
TypePtr res_type = TypeJoin(type_self, type_other);
if (res_type == kAnyType) {
TypeJoinLogging(type_self, type_other, shared_from_base<AbstractBase>(), other);
}
auto value_self = GetValueTrack();
auto value_other = other->GetValueTrack();
const auto &value_self = GetValueTrack();
const auto &value_other = other->GetValueTrack();
ValuePtr res_value = ValueJoin(value_self, value_other);
if (res_value == value_self) {
return shared_from_base<AbstractBase>();

@@ -193,7 +193,7 @@ AbstractBasePtr AbstractType::Clone() const {
if (value_self == nullptr || !value_self->isa<Type>()) {
return nullptr;
}
TypePtr type_self = value_self->cast<TypePtr>();
auto type_self = value_self->cast_ptr<Type>();
return std::make_shared<AbstractType>(type_self->Clone());
}

@@ -201,7 +201,8 @@ bool AbstractType::operator==(const AbstractBase &other) const {
if (this == &other) {
return true;
}
return tid() == other.tid() && IsEqual(dyn_cast<Type>(GetValueTrack()), dyn_cast<Type>(other.GetValueTrack()));
return tid() == other.tid() &&
IsEqual(dyn_cast_ptr<Type>(GetValueTrack()), dyn_cast_ptr<Type>(other.GetValueTrack()));
}

std::string AbstractType::ToString() const {

@@ -215,7 +216,7 @@ std::string AbstractType::ToString() const {
buffer << type_name() << "(Value: nullptr)";
return buffer.str();
}
TypePtr type_self = value_self->cast<TypePtr>();
auto type_self = value_self->cast_ptr<Type>();
buffer << type_name() << "("
<< "Value: " << type_self->ToString() << ")";
return buffer.str();

@@ -496,7 +497,7 @@ AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &ot
if (!enable_eliminate_unused_element || this->sequence_nodes() == nullptr) {
return sequence_nodes;
}
auto other_sequence = dyn_cast<AbstractSequence>(other);
auto other_sequence = dyn_cast_ptr<AbstractSequence>(other);
if (other_sequence == nullptr) {
return sequence_nodes;
}

@@ -714,7 +715,7 @@ template MS_CORE_API ValuePtr AbstractSequence::ElementsBuildValue<ValueList>()
template <typename T>
AbstractBasePtr AbstractSequence::ElementsJoin(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
auto other_sequeue = dyn_cast<T>(other);
auto other_sequeue = dyn_cast_ptr<T>(other);
if (other_sequeue == nullptr) {
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
}

@@ -761,7 +762,7 @@ bool AbstractSequence::operator==(const AbstractSequence &other) const {
}

void AbstractTuple::set_shape(const BaseShapePtr &shape) {
auto tuple_shape = dyn_cast<TupleShape>(shape);
auto tuple_shape = dyn_cast_ptr<TupleShape>(shape);
MS_EXCEPTION_IF_NULL(tuple_shape);
if (tuple_shape->shape().size() != elements_.size()) {
MS_LOG(EXCEPTION) << "Size mismatch: " << tuple_shape->shape().size() << " vs " << elements_.size();

@@ -789,7 +790,7 @@ bool AbstractTuple::ContainsAllBroadenTensors() const {
for (size_t i = 0; i < elements_.size(); ++i) {
if (!(elements_[i]->isa<abstract::AbstractUndetermined>() && elements_[i]->IsBroaden()) &&
!(elements_[i]->isa<abstract::AbstractTuple>() &&
elements_[i]->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors())) {
elements_[i]->cast_ptr<abstract::AbstractTuple>()->ContainsAllBroadenTensors())) {
return false;
}
}

@@ -927,7 +928,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {

// AbstractTensor join with AbstractUndetermined
if (other_type->type_id() == kObjectTypeUndeterminedType) {
auto other_undetermined_tensor = dyn_cast<AbstractUndetermined>(other);
auto other_undetermined_tensor = dyn_cast_ptr<AbstractUndetermined>(other);
MS_EXCEPTION_IF_NULL(other_undetermined_tensor);
// Check shape
auto res_shape = ShapeJoin(shape(), other_undetermined_tensor->shape());

@@ -941,7 +942,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}

// AbstractTensor join with AbstractTensor
auto other_tensor = dyn_cast<AbstractTensor>(other);
auto other_tensor = dyn_cast_ptr<AbstractTensor>(other);
if (other_tensor == nullptr) {
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
}

@@ -964,8 +965,8 @@ bool AbstractTensor::equal_to(const AbstractTensor &other) const {
return true;
}
// Check value. for AbstractTensor, both value should be AnyValue.
auto v1 = GetValueTrack();
auto v2 = other.GetValueTrack();
const auto &v1 = GetValueTrack();
const auto &v2 = other.GetValueTrack();
if (v1 != v2 && (v1 == nullptr || !v1->isa<AnyValue>() || v2 == nullptr || !v2->isa<AnyValue>())) {
return false;
}

@@ -1142,7 +1143,7 @@ TypePtr AbstractJTagged::BuildType() const {

AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
auto other_jtagged = dyn_cast<AbstractJTagged>(other);
auto other_jtagged = dyn_cast_ptr<AbstractJTagged>(other);
if (other_jtagged == nullptr) {
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
}

@@ -1182,8 +1183,8 @@ AbstractRefTensor::AbstractRefTensor(const AbstractTensorPtr &ref_value, const V

TypePtr AbstractRefTensor::BuildType() const {
auto type = AbstractTensor::BuildType();
MS_EXCEPTION_IF_NULL(type);
auto subtype = type->cast<TensorTypePtr>();
auto subtype = dyn_cast_ptr<TensorType>(type);
MS_EXCEPTION_IF_NULL(subtype);
return std::make_shared<RefType>(subtype);
}

@@ -1430,7 +1431,7 @@ BaseShapePtrList AbstractSparseTensor::ElementsShapeTupleRecursive() const {
BaseShapePtrList element_shape_list;
for (const auto &element : elements()) {
MS_EXCEPTION_IF_NULL(element);
auto abs_tuple = element->cast<AbstractTuplePtr>();
auto abs_tuple = element->cast_ptr<AbstractTuple>();
if (abs_tuple == nullptr) {
element_shape_list.push_back(element->BuildShape());
} else {

@@ -120,17 +120,17 @@ class MS_CORE_API AbstractBase : public Base {
/// \brief Get the abstract value, which is tracked.
///
/// \return A pointer to the Value.
ValuePtr GetValueTrack() const { return value_; }
const ValuePtr &GetValueTrack() const { return value_; }

/// \brief Get the abstract type, which is tracked.
///
/// \return A pointer to the Type.
TypePtr GetTypeTrack() const { return type_; }
const TypePtr &GetTypeTrack() const { return type_; }

/// \brief Get the abstract shape, which is tracked.
///
/// \return A pointer to the BaseShape.
BaseShapePtr GetShapeTrack() const { return shape_; }
const BaseShapePtr &GetShapeTrack() const { return shape_; }

/// \brief Try to build a real value from an abstract value.
///

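The getters above now hand back the tracked members by const reference instead of by value, so callers that only inspect the value no longer pay for a shared_ptr copy (an atomic increment and decrement per call). A minimal self-contained sketch of the effect, using a hypothetical Holder class rather than AbstractBase:

#include <iostream>
#include <memory>
#include <string>

using ValuePtr = std::shared_ptr<std::string>;

class Holder {
 public:
  explicit Holder(ValuePtr v) : value_(std::move(v)) {}

  // By value: every call copies the shared_ptr and touches the atomic refcount.
  ValuePtr GetValueByValue() const { return value_; }

  // By const reference: no copy, no refcount traffic; the caller decides
  // whether it actually needs to take ownership.
  const ValuePtr &GetValueByRef() const { return value_; }

 private:
  ValuePtr value_;
};

int main() {
  Holder holder(std::make_shared<std::string>("abc"));

  auto copy = holder.GetValueByValue();      // use_count: 2 (holder + copy)
  const auto &ref = holder.GetValueByRef();  // use_count stays 2 (no new owner)

  std::cout << copy->size() << " " << ref->size() << " "
            << copy.use_count() << std::endl;  // prints: 3 3 2
  return 0;
}
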
@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -141,8 +141,8 @@ bool AnalysisContext::operator==(const AnalysisContext &other) const {
if (func_graph_->has_flag(GRAPH_FLAG_IS_WHILE_HEADER) &&
args_spec_list_[i]->isa<abstract::FuncGraphAbstractClosure>() &&
other.args_spec_list_[i]->isa<abstract::FuncGraphAbstractClosure>()) {
auto temp_this = args_spec_list_[i]->cast<abstract::FuncGraphAbstractClosurePtr>()->Copy();
auto temp_other = other.args_spec_list_[i]->cast<abstract::FuncGraphAbstractClosurePtr>()->Copy();
auto temp_this = args_spec_list_[i]->cast_ptr<abstract::FuncGraphAbstractClosure>()->Copy();
auto temp_other = other.args_spec_list_[i]->cast_ptr<abstract::FuncGraphAbstractClosure>()->Copy();
temp_this->set_tracking_id(nullptr);
temp_other->set_tracking_id(nullptr);
if (!(*temp_this == *temp_other)) {

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -41,7 +41,7 @@ ABSTRACT_REPORT_NAME_DEC(KeywordArg)
TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) {
auto ori_type = type;
if (type->isa<TensorType>()) {
auto tensor = type->cast<TensorTypePtr>();
auto tensor = type->cast_ptr<TensorType>();
type = tensor->element();
MS_EXCEPTION_IF_NULL(type);
}

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -212,7 +212,7 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac
}

AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
AbstractFunctionPtr f_spec = dyn_cast<AbstractFunction>(spec);
auto f_spec = dyn_cast_ptr<AbstractFunction>(spec);
if (f_spec != nullptr) {
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
}

@@ -285,7 +285,7 @@ AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {

auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
if (type->isa<TensorType>()) {
auto tensor_type = type->cast<TensorTypePtr>();
auto tensor_type = type->cast_ptr<TensorType>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);

@@ -319,8 +319,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
}
return MakeAbstractTensor(shape, type);
} else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
auto shape_tuple = base_shape->cast<TupleShapePtr>();
auto type_tuple = type->cast<TuplePtr>();
auto shape_tuple = base_shape->cast_ptr<TupleShape>();
auto type_tuple = type->cast_ptr<Tuple>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_tuple->size(); ++it) {
auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);

@@ -329,8 +329,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
return tuple;
} else if (base_shape->isa<ListShape>() && type->isa<List>()) {
auto shape_list = base_shape->cast<ListShapePtr>();
auto type_list = type->cast<ListPtr>();
auto shape_list = base_shape->cast_ptr<ListShape>();
auto type_list = type->cast_ptr<List>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_list->size(); ++it) {
auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -42,12 +42,12 @@ bool BaseRef::operator==(const BaseRef &other) const {
return false;
}
if (m_ptr->isa<Value>()) {
return *(m_ptr->cast<ValuePtr>()) == *(other.m_ptr->cast<ValuePtr>());
return *(m_ptr->cast_ptr<Value>()) == *(other.m_ptr->cast_ptr<Value>());
}

// for noderef equal
if (m_ptr->isa<BaseRef>()) {
return *std::static_pointer_cast<BaseRef>(m_ptr) == *std::static_pointer_cast<BaseRef>(other.m_ptr);
return *(m_ptr->cast_ptr<BaseRef>()) == *(other.m_ptr->cast_ptr<BaseRef>());
}

// for node equal

@@ -1173,7 +1173,11 @@ inline S GetValueNode(const AnfNodePtr &node) {

template <typename S, typename std::enable_if<std::is_base_of<Value, S>::value, S>::type * = nullptr>
inline S *GetValuePtr(const AnfNodePtr &node) {
auto value = GetValuePtr(node);
auto value_node = dyn_cast_ptr<ValueNode>(node);
if (value_node == nullptr) {
return nullptr;
}
const auto &value = value_node->value();
return (value == nullptr) ? nullptr : value->cast_ptr<S>();
}

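The reworked GetValuePtr<S> above resolves a value node straight to a raw pointer of the held value, with no temporary shared_ptr along the way. The following is a simplified, self-contained mock of the same shape (the Value/ValueNode/AnfNode types here are toy stand-ins, not the real IR classes), showing how such a helper is typically called:

#include <iostream>
#include <memory>

// Minimal stand-ins for the IR classes; the real ones live elsewhere in the repo.
struct Value {
  virtual ~Value() = default;
};
struct FuncGraph : Value {
  const char *name = "fg";
};

struct AnfNode {
  virtual ~AnfNode() = default;
};
struct ValueNode : AnfNode {
  explicit ValueNode(std::shared_ptr<Value> v) : value(std::move(v)) {}
  std::shared_ptr<Value> value;
};
using AnfNodePtr = std::shared_ptr<AnfNode>;

// Same shape as the helper in the hunk above: no shared_ptr is created,
// the caller just borrows a pointer that lives as long as the node does.
template <typename S>
S *GetValuePtr(const AnfNodePtr &node) {
  auto *value_node = dynamic_cast<ValueNode *>(node.get());
  if (value_node == nullptr || value_node->value == nullptr) {
    return nullptr;
  }
  return dynamic_cast<S *>(value_node->value.get());
}

int main() {
  AnfNodePtr node = std::make_shared<ValueNode>(std::make_shared<FuncGraph>());
  if (auto *fg = GetValuePtr<FuncGraph>(node)) {  // cheap: no refcount changes
    std::cout << fg->name << std::endl;           // prints: fg
  }
  return 0;
}
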
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -55,6 +55,11 @@ class MS_CORE_API RefType final : public TensorType {
/// \param[in] subtype Define the TensorType for RefType object to refer to.
explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {}

/// \brief Constructor for RefType.
///
/// \param[in] subtype Define the TensorType for RefType object to refer to.
explicit RefType(const TensorType *subtype) : TensorType(subtype->element()) {}

/// \brief Destructor of RefType.
~RefType() override {}
MS_DECLARE_PARENT(RefType, TensorType)

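The new raw-pointer constructor above lets call sites that only hold a cast_ptr/dyn_cast_ptr result (for example AbstractRefTensor::BuildType earlier in this diff) construct a RefType without first wrapping the pointer in a shared_ptr. A small self-contained sketch of that overload pair, using hypothetical TensorType/RefLike types for illustration:

#include <iostream>
#include <memory>

struct TensorType {
  int element = 7;
};

// A type constructible either from an owning pointer or from a borrowed raw
// pointer, so callers holding only a cast_ptr-style raw pointer do not need
// to manufacture a shared_ptr first.
class RefLike {
 public:
  explicit RefLike(const std::shared_ptr<TensorType> &subtype) : element_(subtype->element) {}
  explicit RefLike(const TensorType *subtype) : element_(subtype->element) {}
  int element() const { return element_; }

 private:
  int element_;
};

int main() {
  auto owned = std::make_shared<TensorType>();
  const TensorType *borrowed = owned.get();  // what a cast_ptr-style call hands back

  RefLike from_shared(owned);
  RefLike from_raw(borrowed);  // no temporary shared_ptr needed
  std::cout << from_shared.element() << " " << from_raw.element() << std::endl;  // 7 7
  return 0;
}
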
@@ -592,7 +592,7 @@ std::string FuncGraph::GetVariableArgName() {

const auto &param_node = GetVariableArgParameter();
MS_EXCEPTION_IF_NULL(param_node);
const auto &parameter = param_node->cast<ParameterPtr>();
auto parameter = param_node->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}

@@ -614,7 +614,7 @@ std::string FuncGraph::GetVariableKwargName() {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
<< ", parameters is less than 1 + fv_param_count";
}
const auto &parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast<ParameterPtr>();
auto parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}

@@ -666,7 +666,7 @@ int FuncGraph::GetPositionalArgsCount() const {
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
for (size_t i = 0; i < parameters_.size(); ++i) {
MS_EXCEPTION_IF_NULL(parameters_[i]);
auto param_cast = parameters_[i]->cast<ParameterPtr>();
auto param_cast = parameters_[i]->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(param_cast);
if (param_cast->name() == name) {
return parameters_[i];

@@ -87,7 +87,7 @@ void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
auto old_param = node->cast<ParameterPtr>();
auto old_param = node->cast_ptr<Parameter>();
MS_EXCEPTION_IF_NULL(old_param);
auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
auto new_param = (is_add ? target->add_parameter(std::move(debug_info))

@@ -136,7 +136,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) {
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
new_const->set_has_new_value(node->cast_ptr<ValueNode>()->has_new_value());
repl_node_[node] = std::move(new_const);
}

@@ -148,7 +148,7 @@ void Cloner::CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
new_const->set_has_new_value(node->cast_ptr<ValueNode>()->has_new_value());
repl_node_[node] = std::move(new_const);
}

@@ -229,7 +229,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func

auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>();
auto parent = cnode.first->first->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(parent);
const auto &valuenode = parent->input(IntToSize(cnode.first->second));
CloneFuncGraphValueNode(valuenode, target_func_graph);

@@ -291,7 +291,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
auto free_var_node = utils::cast<AnfNodePtr>(free_var);
// Don't lift weight parameter to top func_graph.
if (IsLiftTopFuncGraph(func_graph) && free_var_node->isa<Parameter>()) {
auto free_var_param = free_var_node->cast<ParameterPtr>();
auto free_var_param = free_var_node->cast_ptr<Parameter>();
if (free_var_param->has_default()) {
MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->DebugString()
<< " for top_func_graph: " << lift_top_func_graph->ToString();

@@ -306,7 +306,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
}

MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
auto fv_parameter = AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var));
auto fv_parameter = AddParameter(func_graph, free_var_node);
fv_parameter->set_user_data<bool>("lifted_from_fv", std::make_shared<bool>(true));
auto &fg_params = repl_func_graph_params_[func_graph];
(void)fg_params.emplace_back(fv_parameter);

@@ -316,7 +316,7 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) const {
param->set_abstract(node->abstract());
if (node->isa<Parameter>()) {
ParameterPtr old_param = node->cast<ParameterPtr>();
auto old_param = node->cast_ptr<Parameter>();
if (old_param->has_default()) {
// Default parameter can be shared since it is readonly.
param->set_default_param(old_param->default_param());

@@ -728,11 +728,11 @@ void Cloner::CloneNodes() {

void Cloner::LinkEdges() {
for (auto &repl : repl_node_) {
CNodePtr old_node = dyn_cast<CNode>(repl.first);
auto old_node = dyn_cast_ptr<CNode>(repl.first);
if (old_node == nullptr) {
continue;
}
CNodePtr new_node = repl.second->cast<CNodePtr>();
auto new_node = repl.second->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(new_node);
for (auto &input : old_node->inputs()) {
auto iter = repl_node_.find(input);

@@ -131,7 +131,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
[param_name](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto param = node->cast<ParameterPtr>();
auto param = node->cast_ptr<Parameter>();
return param != nullptr && param->name() == param_name;
});
if (find_kw_arg_in_list) {

@@ -186,8 +186,8 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
return vecs;
}

if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto graph = GetValuePtr<FuncGraph>(node);
if (graph != nullptr) {
auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);

@@ -209,19 +209,16 @@ std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) {
return vecs;
}

if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto graph = GetValuePtr<FuncGraph>(node);
if (graph != nullptr) {
auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);
}
return vecs;
} else {
if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
} else if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
}

std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {

@@ -234,26 +231,24 @@ std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {
}

std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) {
std::vector<AnfNodePtr> vecs;
if (node == nullptr) {
return vecs;
}
auto cnode = dyn_cast<CNode>(node);
if (cnode != nullptr) {
auto &inputs = cnode->inputs();
// Check if free variables used.
for (const auto &input : inputs) {
auto input_fg = GetValueNode<FuncGraphPtr>(input);
if (input_fg) {
for (auto &fv : input_fg->free_variables_nodes()) {
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
vecs.push_back(fv);
}
if (cnode == nullptr) {
return {};
}
std::vector<AnfNodePtr> vecs;
const auto &inputs = cnode->inputs();
// Check if free variables used.
for (const auto &input : inputs) {
auto input_fg = GetValuePtr<FuncGraph>(input);
if (input_fg != nullptr) {
for (auto &fv : input_fg->free_variables_nodes()) {
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
vecs.push_back(fv);
}
}
}
FetchCNodeSuccessors(cnode, &vecs);
}
FetchCNodeSuccessors(cnode, &vecs);
return vecs;
}

@@ -263,28 +258,24 @@ std::vector<AnfNodePtr> SuccWithFilter(const GraphFilterFunc &graph_filter, cons
return vecs;
}

if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto graph = GetValueNode<FuncGraphPtr>(node);
if (graph != nullptr) {
if (graph_filter != nullptr && graph_filter(graph)) {
return vecs;
}

auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);
}
return vecs;
} else {
if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
} else if (node->isa<CNode>()) {
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
}

const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node) {
static std::vector<AnfNodePtr> empty_inputs;
auto cnode = dyn_cast<CNode>(node);
auto cnode = dyn_cast_ptr<CNode>(node);
if (cnode != nullptr) {
return cnode->inputs();
}

@@ -92,8 +92,8 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher {
if (!IsValueNode<FuncGraph>(vnode)) {
return;
}
auto fg = GetValueNode<FuncGraphPtr>(vnode);
AnfNodePtr ret = fg->return_node();
auto fg = GetValuePtr<FuncGraph>(vnode);
const auto &ret = fg->return_node();
DeepFirstSearcher::Visit(ret);
}

@@ -661,7 +661,7 @@ void FuncGraphManager::MoveAllCNodeDropGraph(const FuncGraphPtr &source, const F
const ScopePtr &scope) {
AnfNodePtr source_return = source->get_return();
AnfNodePtr source_output = source->output();
AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
AnfNodePtr source_prim = source_return->cast_ptr<CNode>()->input(0);

int index = 0;
(void)node_users_[source_prim].erase(make_pair(source_return, index));

@@ -1105,7 +1105,7 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
std::find_if(meta_fg_prim_values.begin(), meta_fg_prim_values.end(), [seen_num](const auto &iter) {
// Check g1->MetaFgPrim(fg)->g2->g cycle.
if (IsValueNode<FuncGraph>(iter.first)) {
auto func_graph = GetValueNode<FuncGraphPtr>(iter.first);
auto func_graph = GetValuePtr<FuncGraph>(iter.first);
return func_graph->seen_ != seen_num;
}
if (IsValueNode<Primitive>(iter.first)) {

@@ -435,7 +435,7 @@ IMM_TRAITS(StringImmPtr, const char *)

/// \brief RefKey defines a class whose real type is String.
/// \brief Notice: RefKey is keep for compatible only, we use RefKey just as StringImm.
class MS_CORE_API RefKey : public StringImm {
class MS_CORE_API RefKey final : public StringImm {
public:
/// \brief Constructor of RefKey.
///

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -27,8 +27,8 @@ void AnfIrVisitor::Visit(const CNodePtr &cnode) {
}

void AnfIrVisitor::Visit(const ValueNodePtr &vnode) {
if (IsValueNode<FuncGraph>(vnode)) {
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
auto func_graph = GetValuePtr<FuncGraph>(vnode);
if (func_graph != nullptr) {
Visit(func_graph->output());
}
}

@@ -36,36 +36,34 @@ void AnfIrVisitor::Visit(const ValueNodePtr &vnode) {
void AnfIrVisitor::Visit(const ParameterPtr &) {}

VisitFuncType AnfIrVisitor::Match(const PrimitivePtr &prim, const std::vector<PredicateFuncType> &funcs) {
auto fn = [prim, funcs, this](const AnfNodePtr &node) {
return [prim, funcs, this](const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim)) {
return;
}

auto &inputs = node->cast<CNodePtr>()->inputs();
auto &inputs = node->cast_ptr<CNode>()->inputs();
auto funcs_size = funcs.size();
auto inputs_size = inputs.size();

// check the inputs are matched with the predicate functions
// Check the inputs are matched with the predicate functions.
if (funcs_size > 0) {
// use the predicate function list to check the number of inputs
// Use the predicate function list to check the number of inputs.
if (funcs_size != (inputs_size - 1)) {
return;
}

// check per input
for (size_t i = 0; i < funcs_size; i++) {
// Check inputs.
for (size_t i = 0; i < funcs_size; ++i) {
if (!funcs[i](inputs[i + 1])) {
return;
}
}
}

// visit the inputs
for (size_t i = 1; i < inputs_size; i++) {
// Visit argument inputs.
for (size_t i = 1; i < inputs_size; ++i) {
this->Visit(inputs[i]);
}
};

return fn;
}
} // namespace mindspore

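A minor point in the hunk above: the lambda is now returned directly instead of being bound to a local fn first. A small sketch of the same idiom follows; VisitFuncType is assumed here to be a std::function-style alias, which may not match the real typedef exactly.

#include <functional>
#include <iostream>
#include <vector>

using VisitFuncType = std::function<void(int)>;

// Returning the lambda directly avoids naming a temporary; the capture list
// still copies what it needs, so the callable stays valid after return.
VisitFuncType MakeMatcher(int wanted, const std::vector<int> &log_values) {
  return [wanted, log_values](int candidate) {
    if (candidate != wanted) {
      return;
    }
    for (int v : log_values) {
      std::cout << v << " ";
    }
    std::cout << "\n";
  };
}

int main() {
  auto matcher = MakeMatcher(3, {1, 2, 3});
  matcher(1);  // no output
  matcher(3);  // prints: 1 2 3
  return 0;
}
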
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -103,7 +103,7 @@ static inline bool UseMPI() {
}

template <typename T>
inline bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
bool IsEqual(const T *a, const T *b) {
if (a == b) {
return true;
}

@@ -113,6 +113,11 @@ inline bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
return *a == *b;
}

template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
return IsEqual(a.get(), b.get());
}

template <typename T>
bool IsAttrsEqual(const T &a, const T &b) {
if (&a == &b) {

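The hunk above splits IsEqual into a raw-pointer core with a thin shared_ptr forwarder, so smart-pointer callers and cast_ptr-style callers share one implementation. A self-contained copy of that shape for illustration (simplified, not the exact mindspore/core header):

#include <iostream>
#include <memory>
#include <string>

// Raw-pointer core: handles identity, null, and deep comparison.
template <typename T>
bool IsEqual(const T *a, const T *b) {
  if (a == b) {
    return true;
  }
  if (a == nullptr || b == nullptr) {
    return false;
  }
  return *a == *b;
}

// shared_ptr convenience overload just forwards to the raw-pointer version.
template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
  return IsEqual(a.get(), b.get());
}

int main() {
  auto s1 = std::make_shared<std::string>("abc");
  auto s2 = std::make_shared<std::string>("abc");
  std::shared_ptr<std::string> none;

  std::cout << std::boolalpha;
  std::cout << IsEqual(s1, s2) << "\n";              // true  (deep equality)
  std::cout << IsEqual(s1, none) << "\n";            // false (one side null)
  std::cout << IsEqual(s1.get(), s1.get()) << "\n";  // true  (same pointer)
  return 0;
}
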