forked from mindspore-Ecosystem/mindspore
rm abstract ref key
This commit is contained in:
parent
8eed9eca2d
commit
7ce82746ab
|
@ -62,12 +62,16 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
ValuePtr tensor_value = nullptr;
|
||||
RefKeyPtr ref_key = nullptr;
|
||||
abstract::AbstractSequencePtr sequence_abs = nullptr;
|
||||
auto abstract = node->abstract();
|
||||
if (abstract != nullptr) {
|
||||
if (abstract->isa<abstract::AbstractTensor>()) {
|
||||
tensor_value = abstract->BuildValue();
|
||||
}
|
||||
if (abstract->isa<abstract::AbstractRef>()) {
|
||||
ref_key = abstract->cast<abstract::AbstractRefPtr>()->ref_key_value()->cast<RefKeyPtr>();
|
||||
}
|
||||
sequence_abs = dyn_cast<abstract::AbstractSequence>(abstract);
|
||||
}
|
||||
|
||||
|
@ -78,6 +82,9 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
|
|||
if (tensor_value != nullptr && tensor_value != kAnyValue) {
|
||||
buffer << ", value=...";
|
||||
}
|
||||
if (ref_key != nullptr) {
|
||||
buffer << ", ref_key=:" << ref_key->name();
|
||||
}
|
||||
PrintTupleNodeUsedFlags(buffer, sequence_abs);
|
||||
buffer << ">";
|
||||
} else if (type != nullptr) {
|
||||
|
@ -85,6 +92,9 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
|
|||
if (tensor_value != nullptr && tensor_value != kAnyValue) {
|
||||
buffer << ", value=...";
|
||||
}
|
||||
if (ref_key != nullptr) {
|
||||
buffer << ", ref_key=:" << ref_key->name();
|
||||
}
|
||||
PrintTupleNodeUsedFlags(buffer, sequence_abs);
|
||||
buffer << ">";
|
||||
} else {
|
||||
|
|
|
@ -50,7 +50,7 @@ std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
|
|||
if (abs_ref == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto ref_key = abs_ref->ref_key_value();
|
||||
auto ref_key = abs_ref->ref_key_value()->cast<RefKeyPtr>();
|
||||
if (ref_key == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
|
|
@ -36,14 +36,6 @@ AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceB
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
bool IsSideEffectOp(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto effect_info = GetPrimEffectInfo(GetCNodePrimitive(node));
|
||||
return effect_info.memory || effect_info.io;
|
||||
}
|
||||
|
||||
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) {
|
||||
AnfNodePtr expanded_node = nullptr;
|
||||
if (IsValueNode<FuncGraph>(vnode)) {
|
||||
|
|
|
@ -578,8 +578,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
|
||||
auto ref_key = std::make_shared<RefKey>(param_node->name());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_value, ref_key);
|
||||
context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
|
||||
args_spec.push_back(abs_ref);
|
||||
context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
|
||||
|
|
|
@ -509,17 +509,10 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
|
|||
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
auto abstract = conf->ObtainEvalResult()->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
// Broaden the ref_key, while infer python prim for cache
|
||||
if (is_py_eval && abstract->isa<AbstractRef>()) {
|
||||
auto abs_ref = abstract->cast<AbstractRefPtr>();
|
||||
abstract = std::make_shared<AbstractRef>(abs_ref->ref_key()->Broaden(), abs_ref);
|
||||
}
|
||||
return abstract;
|
||||
});
|
||||
return EvalPrim(engine, args_spec_list);
|
||||
|
|
|
@ -413,7 +413,7 @@ class OrderEnforcer {
|
|||
if (abs_ref == nullptr) {
|
||||
return "";
|
||||
}
|
||||
auto ref_key = abs_ref->ref_key_value();
|
||||
auto ref_key = abs_ref->ref_key_value()->cast<RefKeyPtr>();
|
||||
if (ref_key == nullptr) {
|
||||
return "";
|
||||
}
|
||||
|
@ -433,7 +433,8 @@ class OrderEnforcer {
|
|||
|
||||
std::vector<CNodePtr> GetSpecialLoads(const std::map<std::string, std::vector<CNodePtr>> &loads_map1,
|
||||
const std::map<std::string, std::vector<CNodePtr>> &loads_map2,
|
||||
const std::map<std::string, std::vector<CNodePtr>> &loads_map3) {
|
||||
const std::map<std::string, std::vector<CNodePtr>> &loads_map3,
|
||||
const std::set<CNodePtr> &call_lodes) {
|
||||
std::vector<CNodePtr> need_insert_loads;
|
||||
for (auto &refkey_load : loads_map1) {
|
||||
auto &loads = refkey_load.second;
|
||||
|
@ -456,6 +457,12 @@ class OrderEnforcer {
|
|||
(void)need_insert_loads.emplace_back(loads[0]);
|
||||
}
|
||||
}
|
||||
// Add call node will output is a AbstractRef and ref_key is kAnyValue.
|
||||
for (const auto &call_lode : call_lodes) {
|
||||
if (std::find(need_insert_loads.begin(), need_insert_loads.end(), call_lode) == need_insert_loads.end()) {
|
||||
need_insert_loads.push_back(call_lode);
|
||||
}
|
||||
}
|
||||
return need_insert_loads;
|
||||
}
|
||||
|
||||
|
@ -476,11 +483,15 @@ class OrderEnforcer {
|
|||
std::map<std::string, std::vector<CNodePtr>> refkey_loads;
|
||||
std::map<std::string, std::vector<CNodePtr>> refkey_loads_in_call_or_partial;
|
||||
std::map<std::string, std::vector<CNodePtr>> refkey_loads_input_is_call_or_partial;
|
||||
std::set<CNodePtr> ref_call_lodes;
|
||||
for (auto &node : check_nodes) {
|
||||
// Record load refkey
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
auto load = node->cast<CNodePtr>();
|
||||
auto input = load->input(1);
|
||||
if (CheckLoadInput(input)) {
|
||||
(void)ref_call_lodes.insert(load);
|
||||
}
|
||||
auto refkey = GetRefKey(input);
|
||||
if (refkey == "") {
|
||||
MS_LOG(INFO) << "Load without ref key:" << load->DebugString();
|
||||
|
@ -517,7 +528,8 @@ class OrderEnforcer {
|
|||
}
|
||||
}
|
||||
}
|
||||
return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial);
|
||||
return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial,
|
||||
ref_call_lodes);
|
||||
}
|
||||
|
||||
void InsertTensorMoveForLoad(const CNodePtr &node) {
|
||||
|
|
|
@ -596,7 +596,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_conv
|
|||
auto dic = py::dict();
|
||||
if (abs_base->isa<AbstractTensor>()) {
|
||||
ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
|
||||
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
|
||||
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>()) {
|
||||
ShapeVector shape;
|
||||
dic[ATTR_SHAPE] = shape;
|
||||
dic[ATTR_DTYPE] = abs_base->BuildType();
|
||||
|
@ -1475,16 +1475,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
auto key_abs = ref_abs->ref_key();
|
||||
if (key_abs == nullptr) {
|
||||
MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto key_value = key_abs->BuildValue();
|
||||
if (key_value == nullptr) {
|
||||
MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
// Check if the input of RefEmbed is a weight parameter, if not, don't create the
|
||||
// specific SymbolicKey.
|
||||
// Notes: when different weight parameter have same type and shape passed as parameter to same funcgraph
|
||||
|
@ -1497,7 +1487,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(param);
|
||||
ifEmbedIsWeight = param->has_default();
|
||||
}
|
||||
auto refkey = key_value->cast<RefKeyPtr>();
|
||||
auto refkey = ref_abs->ref_key_value()->cast<RefKeyPtr>();
|
||||
if (refkey == nullptr || !ifEmbedIsWeight) {
|
||||
auto ret = std::make_shared<AbstractScalar>(type);
|
||||
auto ref_value = ref_abs->ref();
|
||||
|
|
|
@ -816,7 +816,7 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
|||
if (abstract->isa<AbstractSequence>()) {
|
||||
auto elements = abstract->cast<AbstractSequencePtr>()->elements();
|
||||
if (std::any_of(elements.begin(), elements.end(),
|
||||
[](const AbstractBasePtr &item) { return item->isa<AbstractFunction>(); })) {
|
||||
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -118,12 +118,11 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
|
||||
return;
|
||||
}
|
||||
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
|
||||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
|
||||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
|
||||
abstract->isa<AbstractCOOTensor>() || abstract->isa<AbstractCSRTensor>() ||
|
||||
abstract->isa<abstract::AbstractRefKey>() || abstract->isa<AbstractRef>() ||
|
||||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
|
||||
bool is_legal_abstract =
|
||||
abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() || abstract->isa<AbstractTuple>() ||
|
||||
abstract->isa<AbstractList>() || abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
|
||||
abstract->isa<AbstractCOOTensor>() || abstract->isa<AbstractCSRTensor>() || abstract->isa<AbstractRef>() ||
|
||||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
|
||||
if (is_legal_abstract) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -603,7 +603,7 @@ bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::T
|
|||
MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRef.";
|
||||
return false;
|
||||
}
|
||||
auto ref_key_value = abs_ref->ref_key_value();
|
||||
auto ref_key_value = abs_ref->ref_key_value()->cast<RefKeyPtr>();
|
||||
if (ref_key_value == nullptr) {
|
||||
MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr";
|
||||
return true;
|
||||
|
|
|
@ -1153,11 +1153,12 @@ std::string AbstractJTagged::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
|
||||
: AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
|
||||
AbstractRef::AbstractRef(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value)
|
||||
: AbstractTensor(*ref_value), ref_key_value_(ref_key_value) {
|
||||
set_type(std::make_shared<RefType>());
|
||||
if (ref_key && ref_key->isa<AbstractRefKey>()) {
|
||||
ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
|
||||
MS_EXCEPTION_IF_NULL(ref_key_value);
|
||||
if (ref_key_value != kAnyValue && !ref_key_value->isa<RefKey>()) {
|
||||
MS_LOG(EXCEPTION) << "ref_key_value must be kAnyValue or RefKey, but got:" << ref_key_value->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1172,65 +1173,66 @@ bool AbstractRef::operator==(const AbstractRef &other) const {
|
|||
if (this == &other) {
|
||||
return true;
|
||||
}
|
||||
return IsEqual(ref_key_, other.ref_key_) && AbstractTensor::equal_to(other);
|
||||
// Check whether the ref_key_value is equal.
|
||||
if (!IsEqual(ref_key_value_, other.ref_key_value_)) {
|
||||
return false;
|
||||
}
|
||||
// Check whether Tensor value is equal.
|
||||
return AbstractTensor::equal_to(other);
|
||||
}
|
||||
|
||||
bool AbstractRef::operator==(const AbstractBase &other) const {
|
||||
if (this == &other) {
|
||||
return true;
|
||||
}
|
||||
if (!other.isa<AbstractRef>()) {
|
||||
return false;
|
||||
}
|
||||
return *this == static_cast<const AbstractRef &>(other);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
AbstractBasePtr AbstractRef::Join(const std::shared_ptr<AbstractRef> &other) {
|
||||
if (*this == *other) {
|
||||
auto ret = shared_from_base<AbstractBase>();
|
||||
return ret;
|
||||
return shared_from_base<AbstractRef>();
|
||||
}
|
||||
auto value_self = GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_self);
|
||||
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
|
||||
if (res_value == value_self) {
|
||||
auto ret = shared_from_base<AbstractBase>();
|
||||
return ret;
|
||||
}
|
||||
auto ret = std::make_shared<AbstractRefKey>();
|
||||
ret->set_value(res_value);
|
||||
return ret;
|
||||
// Firstly, join the ref_key_value.
|
||||
auto joined_ref_key = ValueJoin(ref_key_value_, other->ref_key_value_);
|
||||
// Secondly , join the tensor value.
|
||||
auto joined_tensor = AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(joined_tensor);
|
||||
return std::make_shared<AbstractRef>(joined_tensor, joined_ref_key);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
auto other_ref = other->cast<AbstractRefPtr>();
|
||||
if (other_ref == nullptr) {
|
||||
auto join_abs = AbstractTensor::Join(other);
|
||||
MS_EXCEPTION_IF_NULL(join_abs);
|
||||
return join_abs->cast<AbstractTensorPtr>();
|
||||
// Abstract ref join abstract ref
|
||||
if (other->isa<AbstractRef>()) {
|
||||
return AbstractRef::Join(other->cast<AbstractRefPtr>());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ref_key_);
|
||||
MS_EXCEPTION_IF_NULL(other_ref->ref_key_);
|
||||
if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
// Abstract ref join other abstract are same to AbstractTensor::Join.
|
||||
auto joined_tensor = AbstractTensor::Join(other);
|
||||
if (!joined_tensor->isa<AbstractTensor>()) {
|
||||
MS_LOG(EXCEPTION) << "Expect an AbstractTensor, but got:" << joined_tensor->ToString()
|
||||
<< ", other:" << other->ToString();
|
||||
}
|
||||
auto ref_key = ref_key_->Join(other_ref->ref_key_);
|
||||
auto joined_abs_tensor = other_ref->ref();
|
||||
MS_EXCEPTION_IF_NULL(joined_abs_tensor);
|
||||
auto ref = AbstractTensor::Join(joined_abs_tensor);
|
||||
MS_EXCEPTION_IF_NULL(ref);
|
||||
auto ref_tensor = ref->cast<AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(ref_tensor);
|
||||
return std::make_shared<AbstractRef>(ref_key, ref_tensor);
|
||||
return joined_tensor;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractRef::Clone() const {
|
||||
auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
|
||||
return std::make_shared<AbstractRef>(abs_tensor, ref_key_value_);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractRef::Broaden() const {
|
||||
// always broaden for ref
|
||||
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
|
||||
// Broaden the tensor value and keep the ref_key_value.
|
||||
auto ret = std::make_shared<AbstractRef>(abs_tensor, ref_key_value_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string AbstractRef::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
MS_EXCEPTION_IF_NULL(ref_key_);
|
||||
MS_EXCEPTION_IF_NULL(ref_key_value_);
|
||||
buffer << type_name() << "("
|
||||
<< "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
|
||||
<< "key: " << ref_key_value_->ToString() << " ref_value: " << AbstractTensor::ToString();
|
||||
auto value = GetValueTrack();
|
||||
if (value != nullptr) {
|
||||
buffer << ", value: " << value->ToString();
|
||||
|
@ -1258,41 +1260,6 @@ std::string AbstractNone::ToString() const {
|
|||
|
||||
ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
|
||||
|
||||
bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
|
||||
ValuePtr v1 = GetValueTrack();
|
||||
ValuePtr v2 = other.GetValueTrack();
|
||||
if (v1 == v2) {
|
||||
return true;
|
||||
}
|
||||
if (v1 == nullptr || v2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
|
||||
return true;
|
||||
}
|
||||
return IsEqual(dyn_cast<RefKey>(v1), dyn_cast<RefKey>(v2));
|
||||
}
|
||||
|
||||
bool AbstractRefKey::operator==(const AbstractBase &other) const {
|
||||
if (this == &other) {
|
||||
return true;
|
||||
}
|
||||
if (!other.isa<AbstractRefKey>()) {
|
||||
return false;
|
||||
}
|
||||
return *this == static_cast<const AbstractRefKey &>(other);
|
||||
}
|
||||
|
||||
std::string AbstractRefKey::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << type_name();
|
||||
auto value = GetValueTrack();
|
||||
if (value != nullptr) {
|
||||
buffer << "(value: " << value->ToString() << ")";
|
||||
}
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
bool AbstractNull::operator==(const AbstractNull &) const { return true; }
|
||||
|
||||
bool AbstractNull::operator==(const AbstractBase &other) const {
|
||||
|
|
|
@ -1233,55 +1233,6 @@ class MS_CORE_API AbstractEllipsis final : public AbstractBase {
|
|||
};
|
||||
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
|
||||
|
||||
/// \brief Class AbstractRefKey describes a RefKey node's abstract value.
|
||||
class MS_CORE_API AbstractRefKey final : public AbstractBase {
|
||||
public:
|
||||
/// \brief Constructor of AbstractRefKey.
|
||||
AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); }
|
||||
|
||||
/// \brief Destructor of AbstractRefKey.
|
||||
~AbstractRefKey() override = default;
|
||||
MS_DECLARE_PARENT(AbstractRefKey, AbstractBase)
|
||||
|
||||
TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); }
|
||||
|
||||
/// \brief Overwrite the operator '==' to compare other v.
|
||||
///
|
||||
/// \param[in] other The other instance of AbstractTimeOut.
|
||||
///
|
||||
/// \return A boolean, which indicates whether the other abstract is same.
|
||||
bool operator==(const AbstractRefKey &other) const;
|
||||
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
|
||||
AbstractBasePtr Clone() const override {
|
||||
auto cloned = std::make_shared<AbstractRefKey>();
|
||||
cloned->set_value(GetValueTrack());
|
||||
return cloned;
|
||||
}
|
||||
|
||||
inline void set_value(const ValuePtr &value) {
|
||||
AbstractBase::set_value(value);
|
||||
if (value != nullptr) {
|
||||
ref_key_value_ = value->cast<RefKeyPtr>();
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Get the ref key.
|
||||
///
|
||||
/// \return The pointer to RefKey.
|
||||
RefKeyPtr ref_key_value() const { return ref_key_value_; }
|
||||
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
// cache for ref_key after build value, when value is null, return nullptr.
|
||||
RefKeyPtr ref_key_value_{nullptr};
|
||||
};
|
||||
using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
|
||||
|
||||
/// \brief Class AbstractRef describes a RefTensor's abstract value.
|
||||
class MS_CORE_API AbstractRef final : public AbstractTensor {
|
||||
public:
|
||||
|
@ -1289,7 +1240,7 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
|
|||
///
|
||||
/// \param[in] ref_key The ref key of tensor.
|
||||
/// \param[in] ref_value The tensor.
|
||||
AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
|
||||
AbstractRef(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value);
|
||||
|
||||
/// \brief Destructor of AbstractEllipsis.
|
||||
~AbstractRef() override = default;
|
||||
|
@ -1306,13 +1257,7 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
|
|||
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
|
||||
AbstractBasePtr Clone() const override {
|
||||
auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
|
||||
if (abs_tensor == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor);
|
||||
}
|
||||
AbstractBasePtr Clone() const override;
|
||||
|
||||
/// \brief Use parent's AbstractTensor::Clone() to clone an abstract.
|
||||
///
|
||||
|
@ -1326,33 +1271,21 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
|
|||
/// \return A pointer to the abstract tensor.
|
||||
inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
|
||||
|
||||
/// \brief Get the ref key.
|
||||
///
|
||||
/// \return A pointer to the abstract key.
|
||||
inline AbstractBasePtr ref_key() const { return ref_key_; }
|
||||
|
||||
/// \brief Get the ref key value.
|
||||
///
|
||||
/// \return A point to the RefKey.
|
||||
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
|
||||
inline ValuePtr ref_key_value() const { return ref_key_value_; }
|
||||
|
||||
AbstractBasePtr Broaden() const override {
|
||||
// always broaden for ref
|
||||
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
|
||||
if (abs_tensor == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<AbstractRef>(ref_key_->Broaden(), abs_tensor);
|
||||
}
|
||||
AbstractBasePtr Broaden() const override;
|
||||
|
||||
virtual AbstractBasePtr Join(const std::shared_ptr<AbstractRef> &other);
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
|
||||
AbstractBasePtr PartialBroaden() const override;
|
||||
|
||||
private:
|
||||
AbstractBasePtr ref_key_;
|
||||
// cache for ref_key after build value, when value is null, return nullptr.
|
||||
RefKeyPtr ref_key_value_;
|
||||
// ref_key_value is the reference key of AbstractRef, the value can be a string value or kAnyValue
|
||||
ValuePtr ref_key_value_;
|
||||
};
|
||||
using AbstractRefPtr = std::shared_ptr<AbstractRef>;
|
||||
|
||||
|
|
|
@ -32,8 +32,7 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
|
|||
if (is_parameter_) {
|
||||
auto param_name = param_info_->name();
|
||||
auto ref_key = std::make_shared<RefKey>(param_name);
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
|
||||
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_tensor, ref_key);
|
||||
} else {
|
||||
abs_tensor->set_value(shared_from_base<MetaTensor>());
|
||||
}
|
||||
|
|
|
@ -747,8 +747,7 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
|
|||
if (is_parameter_) {
|
||||
auto param_name = param_info_->name();
|
||||
auto ref_key = std::make_shared<RefKey>(param_name);
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
|
||||
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_tensor, ref_key);
|
||||
} else {
|
||||
abs_tensor->set_value(shared_from_base<Tensor>());
|
||||
}
|
||||
|
|
|
@ -32,9 +32,7 @@ abstract::AbstractBasePtr StringImm::ToAbstract() {
|
|||
}
|
||||
|
||||
abstract::AbstractBasePtr RefKey::ToAbstract() {
|
||||
auto refkey = std::make_shared<abstract::AbstractRefKey>();
|
||||
refkey->set_value(shared_from_base<Value>());
|
||||
return refkey;
|
||||
MS_LOG(EXCEPTION) << "Ref key can't convert to abstract, ref_key:" << ToString();
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); }
|
||||
|
|
|
@ -414,8 +414,7 @@ abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const
|
|||
auto tensor_info = std::make_shared<abstract::AbstractTensor>(TypeIdToType(iter->second), tensor_shape);
|
||||
if (tensor_proto.has_ref_key()) {
|
||||
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, tensor_info);
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(tensor_info, ref_key);
|
||||
return abs_ref;
|
||||
}
|
||||
return tensor_info;
|
||||
|
@ -550,7 +549,8 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
auto parameter_abs = parameter->abstract();
|
||||
if (parameter_abs->isa<abstract::AbstractRef>()) {
|
||||
auto parameter_abs_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
|
||||
if (parameter_abs_value->name() == tensor_proto.ref_key()) {
|
||||
auto ref_key_value = parameter_abs_value->cast<RefKeyPtr>();
|
||||
if (ref_key_value != nullptr && ref_key_value->name() == tensor_proto.ref_key()) {
|
||||
node->set_default_param(parameter->cast<ParameterPtr>()->default_param());
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -77,13 +77,6 @@ TEST_F(TestValue, testToAbstract) {
|
|||
ret = tv->ToAbstract();
|
||||
ASSERT_TRUE(ret);
|
||||
ASSERT_EQ(*(ret), *(ta));
|
||||
|
||||
ValuePtr rv = std::make_shared<RefKey>("net.weight");
|
||||
abstract::AbstractRefKeyPtr ra = std::make_shared<abstract::AbstractRefKey>();
|
||||
ra->set_value(rv);
|
||||
ret = rv->ToAbstract();
|
||||
ASSERT_TRUE(ret);
|
||||
ASSERT_EQ(*(ret), *(ra));
|
||||
}
|
||||
|
||||
TEST_F(TestValue, GetValue) {
|
||||
|
|
Loading…
Reference in New Issue