From 7ce82746ab442c456b444fedb1a12694af644c82 Mon Sep 17 00:00:00 2001 From: chenfei Date: Tue, 10 May 2022 20:54:21 +0800 Subject: [PATCH] rm abstract ref key --- mindspore/ccsrc/common/debug/anf_ir_dump.cc | 10 ++ .../optimizer/auto_monad_eliminate.cc | 2 +- .../optimizer/irpass/gradient_eliminate.cc | 8 -- mindspore/ccsrc/pipeline/jit/action.cc | 3 +- .../pipeline/jit/static_analysis/evaluator.cc | 9 +- .../jit/static_analysis/order_enforce.cc | 18 ++- .../pipeline/jit/static_analysis/prim.cc | 14 +-- .../jit/static_analysis/static_analysis.cc | 2 +- mindspore/ccsrc/pipeline/jit/validator.cc | 11 +- .../transform/express_ir/mindir_exporter.cc | 2 +- mindspore/core/abstract/abstract_value.cc | 119 +++++++----------- mindspore/core/abstract/abstract_value.h | 81 ++---------- mindspore/core/ir/meta_tensor_extends.cc | 3 +- mindspore/core/ir/tensor.cc | 3 +- mindspore/core/ir/value_extends.cc | 4 +- .../core/load_mindir/anf_model_parser.cc | 6 +- tests/ut/cpp/ir/value_test.cc | 7 -- 17 files changed, 93 insertions(+), 209 deletions(-) diff --git a/mindspore/ccsrc/common/debug/anf_ir_dump.cc b/mindspore/ccsrc/common/debug/anf_ir_dump.cc index 5c0b7093d92..a037f9068ce 100644 --- a/mindspore/ccsrc/common/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/common/debug/anf_ir_dump.cc @@ -62,12 +62,16 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) { } ValuePtr tensor_value = nullptr; + RefKeyPtr ref_key = nullptr; abstract::AbstractSequencePtr sequence_abs = nullptr; auto abstract = node->abstract(); if (abstract != nullptr) { if (abstract->isa()) { tensor_value = abstract->BuildValue(); } + if (abstract->isa()) { + ref_key = abstract->cast()->ref_key_value()->cast(); + } sequence_abs = dyn_cast(abstract); } @@ -78,6 +82,9 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) { if (tensor_value != nullptr && tensor_value != kAnyValue) { buffer << ", value=..."; } + if (ref_key != nullptr) { + buffer << ", ref_key=:" << ref_key->name(); + } PrintTupleNodeUsedFlags(buffer, sequence_abs); buffer << ">"; } else if (type != nullptr) { @@ -85,6 +92,9 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) { if (tensor_value != nullptr && tensor_value != kAnyValue) { buffer << ", value=..."; } + if (ref_key != nullptr) { + buffer << ", ref_key=:" << ref_key->name(); + } PrintTupleNodeUsedFlags(buffer, sequence_abs); buffer << ">"; } else { diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc index b368f761891..4392f3c0082 100644 --- a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc @@ -50,7 +50,7 @@ std::optional GetRefKey(const AnfNodePtr &node) { if (abs_ref == nullptr) { return std::nullopt; } - auto ref_key = abs_ref->ref_key_value(); + auto ref_key = abs_ref->ref_key_value()->cast(); if (ref_key == nullptr) { return std::nullopt; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc index d6d078aa3d7..773402e9efa 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc @@ -36,14 +36,6 @@ AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceB return nullptr; } -bool IsSideEffectOp(const AnfNodePtr &node) { - if (!node->isa()) { - return false; - } - auto effect_info = GetPrimEffectInfo(GetCNodePrimitive(node)); - return effect_info.memory || effect_info.io; -} - AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) { AnfNodePtr expanded_node = nullptr; if (IsValueNode(vnode)) { diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 58ad426123d..9b263b85d2a 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -578,8 +578,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(value); auto abs_value = value->ToAbstract()->cast(); auto ref_key = std::make_shared(param_node->name()); - auto abs_ref_key = ref_key->ToAbstract(); - auto abs_ref = std::make_shared(abs_ref_key, abs_value); + auto abs_ref = std::make_shared(abs_value, ref_key); context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref); args_spec.push_back(abs_ref); context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index e26599a00da..a6b7afa78ed 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -509,17 +509,10 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) { AbstractBasePtrList args_spec_list; - auto is_py_eval = (identifier_ == "PythonPrimEvaluator"); (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr { + [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); auto abstract = conf->ObtainEvalResult()->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // Broaden the ref_key, while infer python prim for cache - if (is_py_eval && abstract->isa()) { - auto abs_ref = abstract->cast(); - abstract = std::make_shared(abs_ref->ref_key()->Broaden(), abs_ref); - } return abstract; }); return EvalPrim(engine, args_spec_list); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc index 29066a8b6b1..9ca0fa6ccc4 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc @@ -413,7 +413,7 @@ class OrderEnforcer { if (abs_ref == nullptr) { return ""; } - auto ref_key = abs_ref->ref_key_value(); + auto ref_key = abs_ref->ref_key_value()->cast(); if (ref_key == nullptr) { return ""; } @@ -433,7 +433,8 @@ class OrderEnforcer { std::vector GetSpecialLoads(const std::map> &loads_map1, const std::map> &loads_map2, - const std::map> &loads_map3) { + const std::map> &loads_map3, + const std::set &call_lodes) { std::vector need_insert_loads; for (auto &refkey_load : loads_map1) { auto &loads = refkey_load.second; @@ -456,6 +457,12 @@ class OrderEnforcer { (void)need_insert_loads.emplace_back(loads[0]); } } + // Add call node will output is a AbstractRef and ref_key is kAnyValue. + for (const auto &call_lode : call_lodes) { + if (std::find(need_insert_loads.begin(), need_insert_loads.end(), call_lode) == need_insert_loads.end()) { + need_insert_loads.push_back(call_lode); + } + } return need_insert_loads; } @@ -476,11 +483,15 @@ class OrderEnforcer { std::map> refkey_loads; std::map> refkey_loads_in_call_or_partial; std::map> refkey_loads_input_is_call_or_partial; + std::set ref_call_lodes; for (auto &node : check_nodes) { // Record load refkey if (IsPrimitiveCNode(node, prim::kPrimLoad)) { auto load = node->cast(); auto input = load->input(1); + if (CheckLoadInput(input)) { + (void)ref_call_lodes.insert(load); + } auto refkey = GetRefKey(input); if (refkey == "") { MS_LOG(INFO) << "Load without ref key:" << load->DebugString(); @@ -517,7 +528,8 @@ class OrderEnforcer { } } } - return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial); + return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial, + ref_call_lodes); } void InsertTensorMoveForLoad(const CNodePtr &node) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 18080542dff..1d62aae3cae 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -596,7 +596,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv auto dic = py::dict(); if (abs_base->isa()) { ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic); - } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { + } else if (abs_base->isa() || abs_base->isa()) { ShapeVector shape; dic[ATTR_SHAPE] = shape; dic[ATTR_DTYPE] = abs_base->BuildType(); @@ -1475,16 +1475,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); return nullptr; } - auto key_abs = ref_abs->ref_key(); - if (key_abs == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr."; - return nullptr; - } - auto key_value = key_abs->BuildValue(); - if (key_value == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr."; - return nullptr; - } // Check if the input of RefEmbed is a weight parameter, if not, don't create the // specific SymbolicKey. // Notes: when different weight parameter have same type and shape passed as parameter to same funcgraph @@ -1497,7 +1487,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { MS_EXCEPTION_IF_NULL(param); ifEmbedIsWeight = param->has_default(); } - auto refkey = key_value->cast(); + auto refkey = ref_abs->ref_key_value()->cast(); if (refkey == nullptr || !ifEmbedIsWeight) { auto ret = std::make_shared(type); auto ref_value = ref_abs->ref(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index df8a478dddb..2d876c324e5 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -816,7 +816,7 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) { if (abstract->isa()) { auto elements = abstract->cast()->elements(); if (std::any_of(elements.begin(), elements.end(), - [](const AbstractBasePtr &item) { return item->isa(); })) { + [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) { return true; } } diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 7c232315aba..cf5ea7558c8 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -118,12 +118,11 @@ void ValidateAbstract(const AnfNodePtr &node) { MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString(); return; } - bool is_legal_abstract = abstract->isa() || abstract->isa() || - abstract->isa() || abstract->isa() || - abstract->isa() || abstract->isa() || - abstract->isa() || abstract->isa() || - abstract->isa() || abstract->isa() || - abstract->isa() || abstract->isa(); + bool is_legal_abstract = + abstract->isa() || abstract->isa() || abstract->isa() || + abstract->isa() || abstract->isa() || abstract->isa() || + abstract->isa() || abstract->isa() || abstract->isa() || + abstract->isa() || abstract->isa(); if (is_legal_abstract) { return; } diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index c568179c80d..01aa41cccbb 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -603,7 +603,7 @@ bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::T MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRef."; return false; } - auto ref_key_value = abs_ref->ref_key_value(); + auto ref_key_value = abs_ref->ref_key_value()->cast(); if (ref_key_value == nullptr) { MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr"; return true; diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 79a7de63e6d..b8dd4c02756 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -1153,11 +1153,12 @@ std::string AbstractJTagged::ToString() const { return buffer.str(); } -AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value) - : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) { +AbstractRef::AbstractRef(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value) + : AbstractTensor(*ref_value), ref_key_value_(ref_key_value) { set_type(std::make_shared()); - if (ref_key && ref_key->isa()) { - ref_key_value_ = ref_key->cast()->ref_key_value(); + MS_EXCEPTION_IF_NULL(ref_key_value); + if (ref_key_value != kAnyValue && !ref_key_value->isa()) { + MS_LOG(EXCEPTION) << "ref_key_value must be kAnyValue or RefKey, but got:" << ref_key_value->ToString(); } } @@ -1172,65 +1173,66 @@ bool AbstractRef::operator==(const AbstractRef &other) const { if (this == &other) { return true; } - return IsEqual(ref_key_, other.ref_key_) && AbstractTensor::equal_to(other); + // Check whether the ref_key_value is equal. + if (!IsEqual(ref_key_value_, other.ref_key_value_)) { + return false; + } + // Check whether Tensor value is equal. + return AbstractTensor::equal_to(other); } bool AbstractRef::operator==(const AbstractBase &other) const { - if (this == &other) { - return true; - } if (!other.isa()) { return false; } return *this == static_cast(other); } -AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) { - MS_EXCEPTION_IF_NULL(other); +AbstractBasePtr AbstractRef::Join(const std::shared_ptr &other) { if (*this == *other) { - auto ret = shared_from_base(); - return ret; + return shared_from_base(); } - auto value_self = GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_self); - ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); - if (res_value == value_self) { - auto ret = shared_from_base(); - return ret; - } - auto ret = std::make_shared(); - ret->set_value(res_value); - return ret; + // Firstly, join the ref_key_value. + auto joined_ref_key = ValueJoin(ref_key_value_, other->ref_key_value_); + // Secondly , join the tensor value. + auto joined_tensor = AbstractTensor::Join(other)->cast(); + MS_EXCEPTION_IF_NULL(joined_tensor); + return std::make_shared(joined_tensor, joined_ref_key); } AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); - auto other_ref = other->cast(); - if (other_ref == nullptr) { - auto join_abs = AbstractTensor::Join(other); - MS_EXCEPTION_IF_NULL(join_abs); - return join_abs->cast(); + // Abstract ref join abstract ref + if (other->isa()) { + return AbstractRef::Join(other->cast()); } - MS_EXCEPTION_IF_NULL(ref_key_); - MS_EXCEPTION_IF_NULL(other_ref->ref_key_); - if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { - return shared_from_base(); + // Abstract ref join other abstract are same to AbstractTensor::Join. + auto joined_tensor = AbstractTensor::Join(other); + if (!joined_tensor->isa()) { + MS_LOG(EXCEPTION) << "Expect an AbstractTensor, but got:" << joined_tensor->ToString() + << ", other:" << other->ToString(); } - auto ref_key = ref_key_->Join(other_ref->ref_key_); - auto joined_abs_tensor = other_ref->ref(); - MS_EXCEPTION_IF_NULL(joined_abs_tensor); - auto ref = AbstractTensor::Join(joined_abs_tensor); - MS_EXCEPTION_IF_NULL(ref); - auto ref_tensor = ref->cast(); - MS_EXCEPTION_IF_NULL(ref_tensor); - return std::make_shared(ref_key, ref_tensor); + return joined_tensor; +} + +AbstractBasePtr AbstractRef::Clone() const { + auto abs_tensor = AbstractTensor::Clone()->cast(); + return std::make_shared(abs_tensor, ref_key_value_); +} + +AbstractBasePtr AbstractRef::Broaden() const { + // always broaden for ref + auto abs_tensor = AbstractTensor::Broaden()->cast(); + // Broaden the tensor value and keep the ref_key_value. + auto ret = std::make_shared(abs_tensor, ref_key_value_); + return ret; } std::string AbstractRef::ToString() const { std::ostringstream buffer; - MS_EXCEPTION_IF_NULL(ref_key_); + MS_EXCEPTION_IF_NULL(ref_key_value_); buffer << type_name() << "(" - << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString(); + << "key: " << ref_key_value_->ToString() << " ref_value: " << AbstractTensor::ToString(); auto value = GetValueTrack(); if (value != nullptr) { buffer << ", value: " << value->ToString(); @@ -1258,41 +1260,6 @@ std::string AbstractNone::ToString() const { ValuePtr AbstractNone::RealBuildValue() const { return kNone; } -bool AbstractRefKey::operator==(const AbstractRefKey &other) const { - ValuePtr v1 = GetValueTrack(); - ValuePtr v2 = other.GetValueTrack(); - if (v1 == v2) { - return true; - } - if (v1 == nullptr || v2 == nullptr) { - return false; - } - if (v1->isa() && v2->isa()) { - return true; - } - return IsEqual(dyn_cast(v1), dyn_cast(v2)); -} - -bool AbstractRefKey::operator==(const AbstractBase &other) const { - if (this == &other) { - return true; - } - if (!other.isa()) { - return false; - } - return *this == static_cast(other); -} - -std::string AbstractRefKey::ToString() const { - std::ostringstream buffer; - buffer << type_name(); - auto value = GetValueTrack(); - if (value != nullptr) { - buffer << "(value: " << value->ToString() << ")"; - } - return buffer.str(); -} - bool AbstractNull::operator==(const AbstractNull &) const { return true; } bool AbstractNull::operator==(const AbstractBase &other) const { diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index e5a60f4e7cb..9732dd4c86d 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -1233,55 +1233,6 @@ class MS_CORE_API AbstractEllipsis final : public AbstractBase { }; using AbstractEllipsisPtr = std::shared_ptr; -/// \brief Class AbstractRefKey describes a RefKey node's abstract value. -class MS_CORE_API AbstractRefKey final : public AbstractBase { - public: - /// \brief Constructor of AbstractRefKey. - AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared()); } - - /// \brief Destructor of AbstractRefKey. - ~AbstractRefKey() override = default; - MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) - - TypePtr BuildType() const override { return std::make_shared(); } - - /// \brief Overwrite the operator '==' to compare other v. - /// - /// \param[in] other The other instance of AbstractTimeOut. - /// - /// \return A boolean, which indicates whether the other abstract is same. - bool operator==(const AbstractRefKey &other) const; - - bool operator==(const AbstractBase &other) const override; - - AbstractBasePtr Clone() const override { - auto cloned = std::make_shared(); - cloned->set_value(GetValueTrack()); - return cloned; - } - - inline void set_value(const ValuePtr &value) { - AbstractBase::set_value(value); - if (value != nullptr) { - ref_key_value_ = value->cast(); - } - } - - /// \brief Get the ref key. - /// - /// \return The pointer to RefKey. - RefKeyPtr ref_key_value() const { return ref_key_value_; } - - AbstractBasePtr Join(const AbstractBasePtr &other) override; - - std::string ToString() const override; - - private: - // cache for ref_key after build value, when value is null, return nullptr. - RefKeyPtr ref_key_value_{nullptr}; -}; -using AbstractRefKeyPtr = std::shared_ptr; - /// \brief Class AbstractRef describes a RefTensor's abstract value. class MS_CORE_API AbstractRef final : public AbstractTensor { public: @@ -1289,7 +1240,7 @@ class MS_CORE_API AbstractRef final : public AbstractTensor { /// /// \param[in] ref_key The ref key of tensor. /// \param[in] ref_value The tensor. - AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value); + AbstractRef(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value); /// \brief Destructor of AbstractEllipsis. ~AbstractRef() override = default; @@ -1306,13 +1257,7 @@ class MS_CORE_API AbstractRef final : public AbstractTensor { bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override { - auto abs_tensor = AbstractTensor::Clone()->cast(); - if (abs_tensor == nullptr) { - return nullptr; - } - return std::make_shared(ref_key_->Clone(), abs_tensor); - } + AbstractBasePtr Clone() const override; /// \brief Use parent's AbstractTensor::Clone() to clone an abstract. /// @@ -1326,33 +1271,21 @@ class MS_CORE_API AbstractRef final : public AbstractTensor { /// \return A pointer to the abstract tensor. inline AbstractTensorPtr ref() { return shared_from_base(); } - /// \brief Get the ref key. - /// - /// \return A pointer to the abstract key. - inline AbstractBasePtr ref_key() const { return ref_key_; } - /// \brief Get the ref key value. /// /// \return A point to the RefKey. - inline RefKeyPtr ref_key_value() const { return ref_key_value_; } + inline ValuePtr ref_key_value() const { return ref_key_value_; } - AbstractBasePtr Broaden() const override { - // always broaden for ref - auto abs_tensor = AbstractTensor::Broaden()->cast(); - if (abs_tensor == nullptr) { - return nullptr; - } - return std::make_shared(ref_key_->Broaden(), abs_tensor); - } + AbstractBasePtr Broaden() const override; + virtual AbstractBasePtr Join(const std::shared_ptr &other); AbstractBasePtr Join(const AbstractBasePtr &other) override; AbstractBasePtr PartialBroaden() const override; private: - AbstractBasePtr ref_key_; - // cache for ref_key after build value, when value is null, return nullptr. - RefKeyPtr ref_key_value_; + // ref_key_value is the reference key of AbstractRef, the value can be a string value or kAnyValue + ValuePtr ref_key_value_; }; using AbstractRefPtr = std::shared_ptr; diff --git a/mindspore/core/ir/meta_tensor_extends.cc b/mindspore/core/ir/meta_tensor_extends.cc index 947f9dd782d..813cd7d8f7c 100644 --- a/mindspore/core/ir/meta_tensor_extends.cc +++ b/mindspore/core/ir/meta_tensor_extends.cc @@ -32,8 +32,7 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() { if (is_parameter_) { auto param_name = param_info_->name(); auto ref_key = std::make_shared(param_name); - auto abs_ref_key = ref_key->ToAbstract(); - abs_tensor = std::make_shared(abs_ref_key, abs_tensor); + abs_tensor = std::make_shared(abs_tensor, ref_key); } else { abs_tensor->set_value(shared_from_base()); } diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index 643c65d26ca..e3f1fab1e31 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -747,8 +747,7 @@ abstract::AbstractBasePtr Tensor::ToAbstract() { if (is_parameter_) { auto param_name = param_info_->name(); auto ref_key = std::make_shared(param_name); - auto abs_ref_key = ref_key->ToAbstract(); - abs_tensor = std::make_shared(abs_ref_key, abs_tensor); + abs_tensor = std::make_shared(abs_tensor, ref_key); } else { abs_tensor->set_value(shared_from_base()); } diff --git a/mindspore/core/ir/value_extends.cc b/mindspore/core/ir/value_extends.cc index 3c29e1e9e0c..8daef757331 100644 --- a/mindspore/core/ir/value_extends.cc +++ b/mindspore/core/ir/value_extends.cc @@ -32,9 +32,7 @@ abstract::AbstractBasePtr StringImm::ToAbstract() { } abstract::AbstractBasePtr RefKey::ToAbstract() { - auto refkey = std::make_shared(); - refkey->set_value(shared_from_base()); - return refkey; + MS_LOG(EXCEPTION) << "Ref key can't convert to abstract, ref_key:" << ToString(); } abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared(); } diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index 91d917b5e66..d36b6573e54 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -414,8 +414,7 @@ abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const auto tensor_info = std::make_shared(TypeIdToType(iter->second), tensor_shape); if (tensor_proto.has_ref_key()) { auto ref_key = std::make_shared(tensor_proto.ref_key()); - auto abs_ref_key = ref_key->ToAbstract(); - auto abs_ref = std::make_shared(abs_ref_key, tensor_info); + auto abs_ref = std::make_shared(tensor_info, ref_key); return abs_ref; } return tensor_info; @@ -550,7 +549,8 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi auto parameter_abs = parameter->abstract(); if (parameter_abs->isa()) { auto parameter_abs_value = parameter_abs->cast()->ref_key_value(); - if (parameter_abs_value->name() == tensor_proto.ref_key()) { + auto ref_key_value = parameter_abs_value->cast(); + if (ref_key_value != nullptr && ref_key_value->name() == tensor_proto.ref_key()) { node->set_default_param(parameter->cast()->default_param()); break; } diff --git a/tests/ut/cpp/ir/value_test.cc b/tests/ut/cpp/ir/value_test.cc index 37966f4a6d7..d6fe555a888 100644 --- a/tests/ut/cpp/ir/value_test.cc +++ b/tests/ut/cpp/ir/value_test.cc @@ -77,13 +77,6 @@ TEST_F(TestValue, testToAbstract) { ret = tv->ToAbstract(); ASSERT_TRUE(ret); ASSERT_EQ(*(ret), *(ta)); - - ValuePtr rv = std::make_shared("net.weight"); - abstract::AbstractRefKeyPtr ra = std::make_shared(); - ra->set_value(rv); - ret = rv->ToAbstract(); - ASSERT_TRUE(ret); - ASSERT_EQ(*(ret), *(ra)); } TEST_F(TestValue, GetValue) {