rm abstract ref key

This commit is contained in:
chenfei 2022-05-10 20:54:21 +08:00
parent 8eed9eca2d
commit 7ce82746ab
17 changed files with 93 additions and 209 deletions

View File

@ -62,12 +62,16 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
}
ValuePtr tensor_value = nullptr;
RefKeyPtr ref_key = nullptr;
abstract::AbstractSequencePtr sequence_abs = nullptr;
auto abstract = node->abstract();
if (abstract != nullptr) {
if (abstract->isa<abstract::AbstractTensor>()) {
tensor_value = abstract->BuildValue();
}
if (abstract->isa<abstract::AbstractRef>()) {
ref_key = abstract->cast<abstract::AbstractRefPtr>()->ref_key_value()->cast<RefKeyPtr>();
}
sequence_abs = dyn_cast<abstract::AbstractSequence>(abstract);
}
@ -78,6 +82,9 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
if (tensor_value != nullptr && tensor_value != kAnyValue) {
buffer << ", value=...";
}
if (ref_key != nullptr) {
buffer << ", ref_key=:" << ref_key->name();
}
PrintTupleNodeUsedFlags(buffer, sequence_abs);
buffer << ">";
} else if (type != nullptr) {
@ -85,6 +92,9 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
if (tensor_value != nullptr && tensor_value != kAnyValue) {
buffer << ", value=...";
}
if (ref_key != nullptr) {
buffer << ", ref_key=:" << ref_key->name();
}
PrintTupleNodeUsedFlags(buffer, sequence_abs);
buffer << ">";
} else {

View File

@ -50,7 +50,7 @@ std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
if (abs_ref == nullptr) {
return std::nullopt;
}
auto ref_key = abs_ref->ref_key_value();
auto ref_key = abs_ref->ref_key_value()->cast<RefKeyPtr>();
if (ref_key == nullptr) {
return std::nullopt;
}

View File

@ -36,14 +36,6 @@ AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceB
return nullptr;
}
bool IsSideEffectOp(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
auto effect_info = GetPrimEffectInfo(GetCNodePrimitive(node));
return effect_info.memory || effect_info.io;
}
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) {
AnfNodePtr expanded_node = nullptr;
if (IsValueNode<FuncGraph>(vnode)) {

View File

@ -578,8 +578,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(value);
auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
auto ref_key = std::make_shared<RefKey>(param_node->name());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_value, ref_key);
context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
args_spec.push_back(abs_ref);
context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);

View File

@ -509,17 +509,10 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &) {
AbstractBasePtrList args_spec_list;
auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
auto abstract = conf->ObtainEvalResult()->abstract();
MS_EXCEPTION_IF_NULL(abstract);
// Broaden the ref_key, while infer python prim for cache
if (is_py_eval && abstract->isa<AbstractRef>()) {
auto abs_ref = abstract->cast<AbstractRefPtr>();
abstract = std::make_shared<AbstractRef>(abs_ref->ref_key()->Broaden(), abs_ref);
}
return abstract;
});
return EvalPrim(engine, args_spec_list);

View File

@ -413,7 +413,7 @@ class OrderEnforcer {
if (abs_ref == nullptr) {
return "";
}
auto ref_key = abs_ref->ref_key_value();
auto ref_key = abs_ref->ref_key_value()->cast<RefKeyPtr>();
if (ref_key == nullptr) {
return "";
}
@ -433,7 +433,8 @@ class OrderEnforcer {
std::vector<CNodePtr> GetSpecialLoads(const std::map<std::string, std::vector<CNodePtr>> &loads_map1,
const std::map<std::string, std::vector<CNodePtr>> &loads_map2,
const std::map<std::string, std::vector<CNodePtr>> &loads_map3) {
const std::map<std::string, std::vector<CNodePtr>> &loads_map3,
const std::set<CNodePtr> &call_lodes) {
std::vector<CNodePtr> need_insert_loads;
for (auto &refkey_load : loads_map1) {
auto &loads = refkey_load.second;
@ -456,6 +457,12 @@ class OrderEnforcer {
(void)need_insert_loads.emplace_back(loads[0]);
}
}
// Add call node will output is a AbstractRef and ref_key is kAnyValue.
for (const auto &call_lode : call_lodes) {
if (std::find(need_insert_loads.begin(), need_insert_loads.end(), call_lode) == need_insert_loads.end()) {
need_insert_loads.push_back(call_lode);
}
}
return need_insert_loads;
}
@ -476,11 +483,15 @@ class OrderEnforcer {
std::map<std::string, std::vector<CNodePtr>> refkey_loads;
std::map<std::string, std::vector<CNodePtr>> refkey_loads_in_call_or_partial;
std::map<std::string, std::vector<CNodePtr>> refkey_loads_input_is_call_or_partial;
std::set<CNodePtr> ref_call_lodes;
for (auto &node : check_nodes) {
// Record load refkey
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
auto load = node->cast<CNodePtr>();
auto input = load->input(1);
if (CheckLoadInput(input)) {
(void)ref_call_lodes.insert(load);
}
auto refkey = GetRefKey(input);
if (refkey == "") {
MS_LOG(INFO) << "Load without ref key:" << load->DebugString();
@ -517,7 +528,8 @@ class OrderEnforcer {
}
}
}
return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial);
return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial,
ref_call_lodes);
}
void InsertTensorMoveForLoad(const CNodePtr &node) {

View File

@ -596,7 +596,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv
auto dic = py::dict();
if (abs_base->isa<AbstractTensor>()) {
ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>()) {
ShapeVector shape;
dic[ATTR_SHAPE] = shape;
dic[ATTR_DTYPE] = abs_base->BuildType();
@ -1475,16 +1475,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
return nullptr;
}
auto key_abs = ref_abs->ref_key();
if (key_abs == nullptr) {
MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
return nullptr;
}
auto key_value = key_abs->BuildValue();
if (key_value == nullptr) {
MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
return nullptr;
}
// Check if the input of RefEmbed is a weight parameter, if not, don't create the
// specific SymbolicKey.
// Notes: when different weight parameter have same type and shape passed as parameter to same funcgraph
@ -1497,7 +1487,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_EXCEPTION_IF_NULL(param);
ifEmbedIsWeight = param->has_default();
}
auto refkey = key_value->cast<RefKeyPtr>();
auto refkey = ref_abs->ref_key_value()->cast<RefKeyPtr>();
if (refkey == nullptr || !ifEmbedIsWeight) {
auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();

View File

@ -816,7 +816,7 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
if (abstract->isa<AbstractSequence>()) {
auto elements = abstract->cast<AbstractSequencePtr>()->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &item) { return item->isa<AbstractFunction>(); })) {
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
return true;
}
}

View File

@ -118,12 +118,11 @@ void ValidateAbstract(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
return;
}
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
abstract->isa<AbstractCOOTensor>() || abstract->isa<AbstractCSRTensor>() ||
abstract->isa<abstract::AbstractRefKey>() || abstract->isa<AbstractRef>() ||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
bool is_legal_abstract =
abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() || abstract->isa<AbstractTuple>() ||
abstract->isa<AbstractList>() || abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
abstract->isa<AbstractCOOTensor>() || abstract->isa<AbstractCSRTensor>() || abstract->isa<AbstractRef>() ||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
if (is_legal_abstract) {
return;
}

View File

@ -603,7 +603,7 @@ bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::T
MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRef.";
return false;
}
auto ref_key_value = abs_ref->ref_key_value();
auto ref_key_value = abs_ref->ref_key_value()->cast<RefKeyPtr>();
if (ref_key_value == nullptr) {
MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr";
return true;

View File

@ -1153,11 +1153,12 @@ std::string AbstractJTagged::ToString() const {
return buffer.str();
}
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
: AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
AbstractRef::AbstractRef(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value)
: AbstractTensor(*ref_value), ref_key_value_(ref_key_value) {
set_type(std::make_shared<RefType>());
if (ref_key && ref_key->isa<AbstractRefKey>()) {
ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
MS_EXCEPTION_IF_NULL(ref_key_value);
if (ref_key_value != kAnyValue && !ref_key_value->isa<RefKey>()) {
MS_LOG(EXCEPTION) << "ref_key_value must be kAnyValue or RefKey, but got:" << ref_key_value->ToString();
}
}
@ -1172,65 +1173,66 @@ bool AbstractRef::operator==(const AbstractRef &other) const {
if (this == &other) {
return true;
}
return IsEqual(ref_key_, other.ref_key_) && AbstractTensor::equal_to(other);
// Check whether the ref_key_value is equal.
if (!IsEqual(ref_key_value_, other.ref_key_value_)) {
return false;
}
// Check whether Tensor value is equal.
return AbstractTensor::equal_to(other);
}
bool AbstractRef::operator==(const AbstractBase &other) const {
if (this == &other) {
return true;
}
if (!other.isa<AbstractRef>()) {
return false;
}
return *this == static_cast<const AbstractRef &>(other);
}
AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
AbstractBasePtr AbstractRef::Join(const std::shared_ptr<AbstractRef> &other) {
if (*this == *other) {
auto ret = shared_from_base<AbstractBase>();
return ret;
return shared_from_base<AbstractRef>();
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
if (res_value == value_self) {
auto ret = shared_from_base<AbstractBase>();
return ret;
}
auto ret = std::make_shared<AbstractRefKey>();
ret->set_value(res_value);
return ret;
// Firstly, join the ref_key_value.
auto joined_ref_key = ValueJoin(ref_key_value_, other->ref_key_value_);
// Secondly , join the tensor value.
auto joined_tensor = AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(joined_tensor);
return std::make_shared<AbstractRef>(joined_tensor, joined_ref_key);
}
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) {
auto join_abs = AbstractTensor::Join(other);
MS_EXCEPTION_IF_NULL(join_abs);
return join_abs->cast<AbstractTensorPtr>();
// Abstract ref join abstract ref
if (other->isa<AbstractRef>()) {
return AbstractRef::Join(other->cast<AbstractRefPtr>());
}
MS_EXCEPTION_IF_NULL(ref_key_);
MS_EXCEPTION_IF_NULL(other_ref->ref_key_);
if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
return shared_from_base<AbstractBase>();
// Abstract ref join other abstract are same to AbstractTensor::Join.
auto joined_tensor = AbstractTensor::Join(other);
if (!joined_tensor->isa<AbstractTensor>()) {
MS_LOG(EXCEPTION) << "Expect an AbstractTensor, but got:" << joined_tensor->ToString()
<< ", other:" << other->ToString();
}
auto ref_key = ref_key_->Join(other_ref->ref_key_);
auto joined_abs_tensor = other_ref->ref();
MS_EXCEPTION_IF_NULL(joined_abs_tensor);
auto ref = AbstractTensor::Join(joined_abs_tensor);
MS_EXCEPTION_IF_NULL(ref);
auto ref_tensor = ref->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(ref_tensor);
return std::make_shared<AbstractRef>(ref_key, ref_tensor);
return joined_tensor;
}
AbstractBasePtr AbstractRef::Clone() const {
auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
return std::make_shared<AbstractRef>(abs_tensor, ref_key_value_);
}
AbstractBasePtr AbstractRef::Broaden() const {
// always broaden for ref
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
// Broaden the tensor value and keep the ref_key_value.
auto ret = std::make_shared<AbstractRef>(abs_tensor, ref_key_value_);
return ret;
}
std::string AbstractRef::ToString() const {
std::ostringstream buffer;
MS_EXCEPTION_IF_NULL(ref_key_);
MS_EXCEPTION_IF_NULL(ref_key_value_);
buffer << type_name() << "("
<< "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
<< "key: " << ref_key_value_->ToString() << " ref_value: " << AbstractTensor::ToString();
auto value = GetValueTrack();
if (value != nullptr) {
buffer << ", value: " << value->ToString();
@ -1258,41 +1260,6 @@ std::string AbstractNone::ToString() const {
ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
ValuePtr v1 = GetValueTrack();
ValuePtr v2 = other.GetValueTrack();
if (v1 == v2) {
return true;
}
if (v1 == nullptr || v2 == nullptr) {
return false;
}
if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
return true;
}
return IsEqual(dyn_cast<RefKey>(v1), dyn_cast<RefKey>(v2));
}
bool AbstractRefKey::operator==(const AbstractBase &other) const {
if (this == &other) {
return true;
}
if (!other.isa<AbstractRefKey>()) {
return false;
}
return *this == static_cast<const AbstractRefKey &>(other);
}
std::string AbstractRefKey::ToString() const {
std::ostringstream buffer;
buffer << type_name();
auto value = GetValueTrack();
if (value != nullptr) {
buffer << "(value: " << value->ToString() << ")";
}
return buffer.str();
}
bool AbstractNull::operator==(const AbstractNull &) const { return true; }
bool AbstractNull::operator==(const AbstractBase &other) const {

View File

@ -1233,55 +1233,6 @@ class MS_CORE_API AbstractEllipsis final : public AbstractBase {
};
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
/// \brief Class AbstractRefKey describes a RefKey node's abstract value.
class MS_CORE_API AbstractRefKey final : public AbstractBase {
public:
/// \brief Constructor of AbstractRefKey.
AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); }
/// \brief Destructor of AbstractRefKey.
~AbstractRefKey() override = default;
MS_DECLARE_PARENT(AbstractRefKey, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); }
/// \brief Overwrite the operator '==' to compare other v.
///
/// \param[in] other The other instance of AbstractTimeOut.
///
/// \return A boolean, which indicates whether the other abstract is same.
bool operator==(const AbstractRefKey &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override {
auto cloned = std::make_shared<AbstractRefKey>();
cloned->set_value(GetValueTrack());
return cloned;
}
inline void set_value(const ValuePtr &value) {
AbstractBase::set_value(value);
if (value != nullptr) {
ref_key_value_ = value->cast<RefKeyPtr>();
}
}
/// \brief Get the ref key.
///
/// \return The pointer to RefKey.
RefKeyPtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::string ToString() const override;
private:
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr ref_key_value_{nullptr};
};
using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
/// \brief Class AbstractRef describes a RefTensor's abstract value.
class MS_CORE_API AbstractRef final : public AbstractTensor {
public:
@ -1289,7 +1240,7 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
///
/// \param[in] ref_key The ref key of tensor.
/// \param[in] ref_value The tensor.
AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
AbstractRef(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value);
/// \brief Destructor of AbstractEllipsis.
~AbstractRef() override = default;
@ -1306,13 +1257,7 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override {
auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
if (abs_tensor == nullptr) {
return nullptr;
}
return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor);
}
AbstractBasePtr Clone() const override;
/// \brief Use parent's AbstractTensor::Clone() to clone an abstract.
///
@ -1326,33 +1271,21 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
/// \return A pointer to the abstract tensor.
inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
/// \brief Get the ref key.
///
/// \return A pointer to the abstract key.
inline AbstractBasePtr ref_key() const { return ref_key_; }
/// \brief Get the ref key value.
///
/// \return A point to the RefKey.
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
inline ValuePtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Broaden() const override {
// always broaden for ref
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
if (abs_tensor == nullptr) {
return nullptr;
}
return std::make_shared<AbstractRef>(ref_key_->Broaden(), abs_tensor);
}
AbstractBasePtr Broaden() const override;
virtual AbstractBasePtr Join(const std::shared_ptr<AbstractRef> &other);
AbstractBasePtr Join(const AbstractBasePtr &other) override;
AbstractBasePtr PartialBroaden() const override;
private:
AbstractBasePtr ref_key_;
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr ref_key_value_;
// ref_key_value is the reference key of AbstractRef, the value can be a string value or kAnyValue
ValuePtr ref_key_value_;
};
using AbstractRefPtr = std::shared_ptr<AbstractRef>;

View File

@ -32,8 +32,7 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
if (is_parameter_) {
auto param_name = param_info_->name();
auto ref_key = std::make_shared<RefKey>(param_name);
auto abs_ref_key = ref_key->ToAbstract();
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_tensor, ref_key);
} else {
abs_tensor->set_value(shared_from_base<MetaTensor>());
}

View File

@ -747,8 +747,7 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
if (is_parameter_) {
auto param_name = param_info_->name();
auto ref_key = std::make_shared<RefKey>(param_name);
auto abs_ref_key = ref_key->ToAbstract();
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_tensor, ref_key);
} else {
abs_tensor->set_value(shared_from_base<Tensor>());
}

View File

@ -32,9 +32,7 @@ abstract::AbstractBasePtr StringImm::ToAbstract() {
}
abstract::AbstractBasePtr RefKey::ToAbstract() {
auto refkey = std::make_shared<abstract::AbstractRefKey>();
refkey->set_value(shared_from_base<Value>());
return refkey;
MS_LOG(EXCEPTION) << "Ref key can't convert to abstract, ref_key:" << ToString();
}
abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); }

View File

@ -414,8 +414,7 @@ abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const
auto tensor_info = std::make_shared<abstract::AbstractTensor>(TypeIdToType(iter->second), tensor_shape);
if (tensor_proto.has_ref_key()) {
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, tensor_info);
auto abs_ref = std::make_shared<abstract::AbstractRef>(tensor_info, ref_key);
return abs_ref;
}
return tensor_info;
@ -550,7 +549,8 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
auto parameter_abs = parameter->abstract();
if (parameter_abs->isa<abstract::AbstractRef>()) {
auto parameter_abs_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
if (parameter_abs_value->name() == tensor_proto.ref_key()) {
auto ref_key_value = parameter_abs_value->cast<RefKeyPtr>();
if (ref_key_value != nullptr && ref_key_value->name() == tensor_proto.ref_key()) {
node->set_default_param(parameter->cast<ParameterPtr>()->default_param());
break;
}

View File

@ -77,13 +77,6 @@ TEST_F(TestValue, testToAbstract) {
ret = tv->ToAbstract();
ASSERT_TRUE(ret);
ASSERT_EQ(*(ret), *(ta));
ValuePtr rv = std::make_shared<RefKey>("net.weight");
abstract::AbstractRefKeyPtr ra = std::make_shared<abstract::AbstractRefKey>();
ra->set_value(rv);
ret = rv->ToAbstract();
ASSERT_TRUE(ret);
ASSERT_EQ(*(ret), *(ra));
}
TEST_F(TestValue, GetValue) {