From f926650c64f7ce4f1dc75dd3675086272ca13ea6 Mon Sep 17 00:00:00 2001 From: zhousiyi Date: Sat, 25 Jul 2020 01:57:07 +0000 Subject: [PATCH] if AbstractFunction comparison succeed in NewContext, then the evaluator should use the same one, otherwise one of the evaluator will not be evaluated. if funcgraph or metafuncgraph call it recursively, then anf_node should be used as tracking_id to discriminate the first occurcance and the recursive occurance. add anf_node to PrimitiveAbstractClosure hash() to reduce cost of GetEvaluatorFor(). ignore the tracking_id to make cse work. --- mindspore/ccsrc/frontend/optimizer/cse.cc | 5 ++ .../jit/static_analysis/static_analysis.cc | 18 +++--- .../jit/static_analysis/static_analysis.h | 2 +- mindspore/core/abstract/abstract_function.cc | 22 +++++-- mindspore/core/abstract/abstract_function.h | 32 ++++++++-- .../python/pipeline/infer/test_net_infer.py | 59 +++++++++++++++++++ 6 files changed, 119 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc index 4d968d6d74..c80b54097d 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.cc +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -36,6 +36,11 @@ BasePtr AbsOf(const AnfNodePtr &node) { if (node_abs == nullptr) { return kAnyValue; } + // Ignore the tracking_id and prim pointer hash; + if (node_abs->isa()) { + auto prim_abs = node_abs->cast(); + return prim_abs->prim(); + } return node_abs; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index da71f3996c..338743b1da 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -470,7 +470,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); } MS_EXCEPTION_IF_NULL(func); - if (func->tracking_id() == nullptr) { + if (func->tracking_id() == nullptr || func->isa() || + func->isa()) { EvaluatorPtr evaluator = _GetEvaluatorFor(func); return evaluator; } @@ -639,12 +640,12 @@ EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { } abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, - const abstract::AnalysisContextPtr &context) { + const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) { AnalysisContextPtr temp_context = context; if (temp_context == nullptr) { temp_context = abstract::AnalysisContext::DummyContext(); } - return std::make_shared(func_graph, temp_context); + return std::make_shared(func_graph, temp_context, anf_node); } abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) { @@ -652,7 +653,8 @@ abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_ if (anf_node == nullptr) { meta_func_graph_fn = std::make_shared(meta_func_graph); } else { - meta_func_graph_fn = std::make_shared(meta_func_graph, anf_node->scope()); + meta_func_graph_fn = + std::make_shared(meta_func_graph, anf_node, anf_node->scope()); } return meta_func_graph_fn; } @@ -663,14 +665,14 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con } AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { - if (value->isa()) { - auto func_graph = value->cast(); - return MakeAbstractClosure(func_graph, context); - } AnfNodePtr anf_node = nullptr; if (conf != nullptr) { anf_node = conf->node(); } + if (value->isa()) { + auto func_graph = value->cast(); + return MakeAbstractClosure(func_graph, context, anf_node); + } if (value->isa()) { auto meta_func_graph = value->cast(); return MakeAbstractClosure(meta_func_graph, anf_node); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 59b4a33b51..57e78dcec8 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -232,7 +232,7 @@ class AnalysisEngine : public std::enable_shared_from_this { const PrimEvaluatorMap &prim_constructors_; FuncGraphManagerPtr func_graph_manager_; - std::unordered_map constructors_; + std::unordered_map constructors_; AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators. std::list> eval_trace_; diff --git a/mindspore/core/abstract/abstract_function.cc b/mindspore/core/abstract/abstract_function.cc index 402b9327c5..2d46862af1 100644 --- a/mindspore/core/abstract/abstract_function.cc +++ b/mindspore/core/abstract/abstract_function.cc @@ -143,14 +143,23 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { return false; } -std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } +std::size_t PrimitiveAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), prim_->hash()); + // Keep in sync with operator==() which compares the prim_ pointer; + hash_value = hash_combine(hash_value, std::hash{}(prim_.get())); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } + return hash_value; +} bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; } auto other_fg = static_cast(&other); - if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { + if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ && + tracking_id() == other_fg->tracking_id()) { return true; } return false; @@ -159,9 +168,11 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { std::size_t FuncGraphAbstractClosure::hash() const { auto hash_value = hash_combine(tid(), func_graph_->hash()); hash_value = hash_combine(hash_value, context_->hash()); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } return hash_value; } - std::string FuncGraphAbstractClosure::ToString() const { std::stringstream ss; ss << "FuncGraphAbstractClosure: " @@ -174,7 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con return false; } auto other_meta_fg = static_cast(&other); - if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { + if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) { return true; } return false; @@ -182,6 +193,9 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con std::size_t MetaFuncGraphAbstractClosure::hash() const { auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } return hash_value; } diff --git a/mindspore/core/abstract/abstract_function.h b/mindspore/core/abstract/abstract_function.h index 5e33384218..79a1d6c1d7 100644 --- a/mindspore/core/abstract/abstract_function.h +++ b/mindspore/core/abstract/abstract_function.h @@ -92,13 +92,15 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { // one reference cycle example is Graph::set_output() input0 local variable. AnfNodeWeakPtr tracking_id_; }; +using PrimitiveAbstractClosurePtr = std::shared_ptr; class FuncGraphAbstractClosure : public AbstractFuncAtom { public: // Represents a Graph in a certain Context. // context: The context, or Context.empty() - FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) - : func_graph_(func_graph), context_(context) { + FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, + const AnfNodePtr &tracking_id = nullptr) + : func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(context); } @@ -109,8 +111,10 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { AnalysisContextPtr context() const override { return context_; } + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + AbstractFunctionPtr Copy() const override { - return std::make_shared(func_graph_, context_); + return std::make_shared(func_graph_, context_, tracking_id()); } bool operator==(const AbstractFunction &other) const override; @@ -121,13 +125,22 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { private: FuncGraphPtr func_graph_; AnalysisContextPtr context_; + // To discriminate different usage of same graph by using this tracking_id, + // so different tracking_id will produce different FuncGraphAbstractClosure, + // different FuncGraphEvaluator. + // Espcecially usefull for recursive func graph call, so it will not mess up + // the graph_context_ in FuncGraphEvaluator. + // Notes: Be careful to use nullptr for this variable. + // store it as weak_ptr to break reference cycle. + AnfNodeWeakPtr tracking_id_; }; using FuncGraphAbstractClosurePtr = std::shared_ptr; class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { public: - explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope) - : meta_func_graph_(meta_func_graph), scope_(scope) {} + explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, + const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope) + : meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {} ~MetaFuncGraphAbstractClosure() override = default; MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) @@ -137,7 +150,11 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { ScopePtr GetScope() { return scope_; } - AbstractFunctionPtr Copy() const override { return std::make_shared(meta_func_graph_); } + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + + AbstractFunctionPtr Copy() const override { + return std::make_shared(meta_func_graph_, tracking_id()); + } bool operator==(const AbstractFunction &other) const override; std::size_t hash() const override; @@ -145,6 +162,9 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { private: MetaFuncGraphPtr meta_func_graph_; + // refer the comment in FuncGraphAbstractClosure; + // store it as weak_ptr to break reference cycle. + AnfNodeWeakPtr tracking_id_; ScopePtr scope_; }; using MetaFuncGraphAbstractClosurePtr = std::shared_ptr; diff --git a/tests/ut/python/pipeline/infer/test_net_infer.py b/tests/ut/python/pipeline/infer/test_net_infer.py index 61f19e7d6b..51bfcf87cd 100644 --- a/tests/ut/python/pipeline/infer/test_net_infer.py +++ b/tests/ut/python/pipeline/infer/test_net_infer.py @@ -67,3 +67,62 @@ def test_assign_in_while(): z = Tensor(np.random.randn(*input_shape).astype(np.float32)) net = Net(input_shape) net(x, y, z) + + +def test_dup_context(): + ''' different func_with_fv in net1 and net2 should produce 2 different FuncGraphAbstractClosure and + Evaluator. + ''' + context.set_context(mode=context.GRAPH_MODE) + + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + def identity(f): + return f + + def func_with_fv(): + return x + + def net1(): + local_func = identity(func_with_fv) + out = local_func() + 20.0 + return out + + def net2(): + local_func = identity(func_with_fv) + out = local_func() + 15.0 + return out + + return net1() + net2() + + Net()(5.0) + + +def test_maybe_poly_func(): + ''' different func_with_fv in net1 and net2 may produce poly node. ''' + context.set_context(mode=context.GRAPH_MODE) + + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x, y, z): + def identity(f, inp): + return f(inp) + + def func_with_fv(yy): + return (x, yy) + + def make_call(): + out1 = identity(func_with_fv, y) + out2 = identity(func_with_fv, z) + return (out1, out2) + + return make_call() + + y_input = Tensor(np.array([1, 2]).astype(np.int32)) + z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32)) + Net()(1, y_input, z_input)