diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc index 4d968d6d74..c80b54097d 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.cc +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -36,6 +36,11 @@ BasePtr AbsOf(const AnfNodePtr &node) { if (node_abs == nullptr) { return kAnyValue; } + // Ignore the tracking_id and prim pointer hash; + if (node_abs->isa()) { + auto prim_abs = node_abs->cast(); + return prim_abs->prim(); + } return node_abs; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index da71f3996c..338743b1da 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -470,7 +470,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); } MS_EXCEPTION_IF_NULL(func); - if (func->tracking_id() == nullptr) { + if (func->tracking_id() == nullptr || func->isa() || + func->isa()) { EvaluatorPtr evaluator = _GetEvaluatorFor(func); return evaluator; } @@ -639,12 +640,12 @@ EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { } abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, - const abstract::AnalysisContextPtr &context) { + const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) { AnalysisContextPtr temp_context = context; if (temp_context == nullptr) { temp_context = abstract::AnalysisContext::DummyContext(); } - return std::make_shared(func_graph, temp_context); + return std::make_shared(func_graph, temp_context, anf_node); } abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) { @@ -652,7 +653,8 @@ abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_ if (anf_node == nullptr) { meta_func_graph_fn = std::make_shared(meta_func_graph); } else { - meta_func_graph_fn = std::make_shared(meta_func_graph, anf_node->scope()); + meta_func_graph_fn = + std::make_shared(meta_func_graph, anf_node, anf_node->scope()); } return meta_func_graph_fn; } @@ -663,14 +665,14 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con } AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { - if (value->isa()) { - auto func_graph = value->cast(); - return MakeAbstractClosure(func_graph, context); - } AnfNodePtr anf_node = nullptr; if (conf != nullptr) { anf_node = conf->node(); } + if (value->isa()) { + auto func_graph = value->cast(); + return MakeAbstractClosure(func_graph, context, anf_node); + } if (value->isa()) { auto meta_func_graph = value->cast(); return MakeAbstractClosure(meta_func_graph, anf_node); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 59b4a33b51..57e78dcec8 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -232,7 +232,7 @@ class AnalysisEngine : public std::enable_shared_from_this { const PrimEvaluatorMap &prim_constructors_; FuncGraphManagerPtr func_graph_manager_; - std::unordered_map constructors_; + std::unordered_map constructors_; AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators. std::list> eval_trace_; diff --git a/mindspore/core/abstract/abstract_function.cc b/mindspore/core/abstract/abstract_function.cc index 402b9327c5..2d46862af1 100644 --- a/mindspore/core/abstract/abstract_function.cc +++ b/mindspore/core/abstract/abstract_function.cc @@ -143,14 +143,23 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { return false; } -std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } +std::size_t PrimitiveAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), prim_->hash()); + // Keep in sync with operator==() which compares the prim_ pointer; + hash_value = hash_combine(hash_value, std::hash{}(prim_.get())); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } + return hash_value; +} bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; } auto other_fg = static_cast(&other); - if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { + if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ && + tracking_id() == other_fg->tracking_id()) { return true; } return false; @@ -159,9 +168,11 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { std::size_t FuncGraphAbstractClosure::hash() const { auto hash_value = hash_combine(tid(), func_graph_->hash()); hash_value = hash_combine(hash_value, context_->hash()); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } return hash_value; } - std::string FuncGraphAbstractClosure::ToString() const { std::stringstream ss; ss << "FuncGraphAbstractClosure: " @@ -174,7 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con return false; } auto other_meta_fg = static_cast(&other); - if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { + if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) { return true; } return false; @@ -182,6 +193,9 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con std::size_t MetaFuncGraphAbstractClosure::hash() const { auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } return hash_value; } diff --git a/mindspore/core/abstract/abstract_function.h b/mindspore/core/abstract/abstract_function.h index 5e33384218..79a1d6c1d7 100644 --- a/mindspore/core/abstract/abstract_function.h +++ b/mindspore/core/abstract/abstract_function.h @@ -92,13 +92,15 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { // one reference cycle example is Graph::set_output() input0 local variable. AnfNodeWeakPtr tracking_id_; }; +using PrimitiveAbstractClosurePtr = std::shared_ptr; class FuncGraphAbstractClosure : public AbstractFuncAtom { public: // Represents a Graph in a certain Context. // context: The context, or Context.empty() - FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) - : func_graph_(func_graph), context_(context) { + FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, + const AnfNodePtr &tracking_id = nullptr) + : func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(context); } @@ -109,8 +111,10 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { AnalysisContextPtr context() const override { return context_; } + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + AbstractFunctionPtr Copy() const override { - return std::make_shared(func_graph_, context_); + return std::make_shared(func_graph_, context_, tracking_id()); } bool operator==(const AbstractFunction &other) const override; @@ -121,13 +125,22 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { private: FuncGraphPtr func_graph_; AnalysisContextPtr context_; + // To discriminate different usage of same graph by using this tracking_id, + // so different tracking_id will produce different FuncGraphAbstractClosure, + // different FuncGraphEvaluator. + // Espcecially usefull for recursive func graph call, so it will not mess up + // the graph_context_ in FuncGraphEvaluator. + // Notes: Be careful to use nullptr for this variable. + // store it as weak_ptr to break reference cycle. + AnfNodeWeakPtr tracking_id_; }; using FuncGraphAbstractClosurePtr = std::shared_ptr; class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { public: - explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope) - : meta_func_graph_(meta_func_graph), scope_(scope) {} + explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, + const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope) + : meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {} ~MetaFuncGraphAbstractClosure() override = default; MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) @@ -137,7 +150,11 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { ScopePtr GetScope() { return scope_; } - AbstractFunctionPtr Copy() const override { return std::make_shared(meta_func_graph_); } + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + + AbstractFunctionPtr Copy() const override { + return std::make_shared(meta_func_graph_, tracking_id()); + } bool operator==(const AbstractFunction &other) const override; std::size_t hash() const override; @@ -145,6 +162,9 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { private: MetaFuncGraphPtr meta_func_graph_; + // refer the comment in FuncGraphAbstractClosure; + // store it as weak_ptr to break reference cycle. + AnfNodeWeakPtr tracking_id_; ScopePtr scope_; }; using MetaFuncGraphAbstractClosurePtr = std::shared_ptr; diff --git a/tests/ut/python/pipeline/infer/test_net_infer.py b/tests/ut/python/pipeline/infer/test_net_infer.py index 61f19e7d6b..51bfcf87cd 100644 --- a/tests/ut/python/pipeline/infer/test_net_infer.py +++ b/tests/ut/python/pipeline/infer/test_net_infer.py @@ -67,3 +67,62 @@ def test_assign_in_while(): z = Tensor(np.random.randn(*input_shape).astype(np.float32)) net = Net(input_shape) net(x, y, z) + + +def test_dup_context(): + ''' different func_with_fv in net1 and net2 should produce 2 different FuncGraphAbstractClosure and + Evaluator. + ''' + context.set_context(mode=context.GRAPH_MODE) + + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + def identity(f): + return f + + def func_with_fv(): + return x + + def net1(): + local_func = identity(func_with_fv) + out = local_func() + 20.0 + return out + + def net2(): + local_func = identity(func_with_fv) + out = local_func() + 15.0 + return out + + return net1() + net2() + + Net()(5.0) + + +def test_maybe_poly_func(): + ''' different func_with_fv in net1 and net2 may produce poly node. ''' + context.set_context(mode=context.GRAPH_MODE) + + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x, y, z): + def identity(f, inp): + return f(inp) + + def func_with_fv(yy): + return (x, yy) + + def make_call(): + out1 = identity(func_with_fv, y) + out2 = identity(func_with_fv, z) + return (out1, out2) + + return make_call() + + y_input = Tensor(np.array([1, 2]).astype(np.int32)) + z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32)) + Net()(1, y_input, z_input)