Return a new abstract without tracking_id for fg ValueNode in CSE.

This commit is contained in:
Zhang Qinghua 2020-10-22 19:23:06 +08:00
parent 3763e201b5
commit 077bde0767
5 changed files with 13 additions and 10 deletions

View File

@ -32,19 +32,23 @@ using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr;
BasePtr AbsOf(const AnfNodePtr &node) {
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
MS_EXCEPTION_IF_NULL(node);
auto node_abs = node->abstract();
// in testcase: TestOptOpt.CSE, node->abstract() is null;
// In testcase: TestOptOpt.CSE, node->abstract() is null.
if (node_abs == nullptr) {
return kAnyValue;
}
// Ignore the tracking_id and prim pointer hash;
if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
// Ignore the tracking_id and prim pointer hash.
auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>();
return prim_abs->prim();
} else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) {
// Ignore the tracking_id.
auto new_fg_abs = node_abs->cast<abstract::AbstractFunctionPtr>()->Copy();
new_fg_abs->set_tracking_id(nullptr);
return new_fg_abs;
}
return node_abs;
}
@ -68,7 +72,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
ValueNodePtr value_node = node->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
h = hash_combine(value->hash(), (AbsOf(value_node)->hash()));
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
@ -134,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
auto main_value = GetValueNode(main);
auto node_value = GetValueNode(node);
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value);
} else if (main->isa<CNode>() && node->isa<CNode>()) {
auto c_main = main->cast<CNodePtr>();
auto c_node = node->cast<CNodePtr>();

View File

@ -46,7 +46,7 @@ class CSE {
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
};
BasePtr AbsOf(const AnfNodePtr &node);
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id = false);
} // namespace opt
} // namespace mindspore

View File

@ -44,8 +44,6 @@ class CSEPass : public CSE {
private:
bool report_changes_;
};
BasePtr AbsOf(const AnfNodePtr &node);
} // namespace opt
} // namespace mindspore

View File

@ -467,7 +467,6 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
} else {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
}
return nullptr;
}
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {

View File

@ -113,6 +113,8 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); }
AbstractFunctionPtr Copy() const override {
return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id());
}