forked from mindspore-Ecosystem/mindspore
Return a new abstract without tracking_id for fg ValueNode in CSE.
This commit is contained in:
parent
3763e201b5
commit
077bde0767
|
@ -32,19 +32,23 @@ using mindspore::abstract::AbstractBase;
|
|||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractFunctionPtr;
|
||||
|
||||
BasePtr AbsOf(const AnfNodePtr &node) {
|
||||
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_abs = node->abstract();
|
||||
// in testcase: TestOptOpt.CSE, node->abstract() is null;
|
||||
// In testcase: TestOptOpt.CSE, node->abstract() is null.
|
||||
if (node_abs == nullptr) {
|
||||
return kAnyValue;
|
||||
}
|
||||
// Ignore the tracking_id and prim pointer hash;
|
||||
if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
|
||||
// Ignore the tracking_id and prim pointer hash.
|
||||
auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>();
|
||||
return prim_abs->prim();
|
||||
} else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
// Ignore the tracking_id.
|
||||
auto new_fg_abs = node_abs->cast<abstract::AbstractFunctionPtr>()->Copy();
|
||||
new_fg_abs->set_tracking_id(nullptr);
|
||||
return new_fg_abs;
|
||||
}
|
||||
|
||||
return node_abs;
|
||||
}
|
||||
|
||||
|
@ -68,7 +72,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
|
|||
ValueNodePtr value_node = node->cast<ValueNodePtr>();
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
h = hash_combine(value->hash(), (AbsOf(value_node)->hash()));
|
||||
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
|
||||
} else if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
|
@ -134,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec
|
|||
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
|
||||
auto main_value = GetValueNode(main);
|
||||
auto node_value = GetValueNode(node);
|
||||
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
||||
return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value);
|
||||
} else if (main->isa<CNode>() && node->isa<CNode>()) {
|
||||
auto c_main = main->cast<CNodePtr>();
|
||||
auto c_node = node->cast<CNodePtr>();
|
||||
|
|
|
@ -46,7 +46,7 @@ class CSE {
|
|||
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
|
||||
};
|
||||
|
||||
BasePtr AbsOf(const AnfNodePtr &node);
|
||||
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id = false);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -44,8 +44,6 @@ class CSEPass : public CSE {
|
|||
private:
|
||||
bool report_changes_;
|
||||
};
|
||||
|
||||
BasePtr AbsOf(const AnfNodePtr &node);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -467,7 +467,6 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
||||
|
|
|
@ -113,6 +113,8 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
|||
|
||||
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
|
||||
|
||||
void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); }
|
||||
|
||||
AbstractFunctionPtr Copy() const override {
|
||||
return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue