!3370 same graph used in different context should be treat differently.
Merge pull request !3370 from xychow/fix-context-dup
This commit is contained in:
commit
2da29bce66
|
@ -36,6 +36,11 @@ BasePtr AbsOf(const AnfNodePtr &node) {
|
||||||
if (node_abs == nullptr) {
|
if (node_abs == nullptr) {
|
||||||
return kAnyValue;
|
return kAnyValue;
|
||||||
}
|
}
|
||||||
|
// Ignore the tracking_id and prim pointer hash;
|
||||||
|
if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
|
||||||
|
auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>();
|
||||||
|
return prim_abs->prim();
|
||||||
|
}
|
||||||
|
|
||||||
return node_abs;
|
return node_abs;
|
||||||
}
|
}
|
||||||
|
|
|
@ -470,7 +470,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
||||||
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
|
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(func);
|
MS_EXCEPTION_IF_NULL(func);
|
||||||
if (func->tracking_id() == nullptr) {
|
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
|
||||||
|
func->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||||
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
|
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
|
||||||
return evaluator;
|
return evaluator;
|
||||||
}
|
}
|
||||||
|
@ -639,12 +640,12 @@ EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph,
|
abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph,
|
||||||
const abstract::AnalysisContextPtr &context) {
|
const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) {
|
||||||
AnalysisContextPtr temp_context = context;
|
AnalysisContextPtr temp_context = context;
|
||||||
if (temp_context == nullptr) {
|
if (temp_context == nullptr) {
|
||||||
temp_context = abstract::AnalysisContext::DummyContext();
|
temp_context = abstract::AnalysisContext::DummyContext();
|
||||||
}
|
}
|
||||||
return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context);
|
return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, anf_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) {
|
abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) {
|
||||||
|
@ -652,7 +653,8 @@ abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_
|
||||||
if (anf_node == nullptr) {
|
if (anf_node == nullptr) {
|
||||||
meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph);
|
meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph);
|
||||||
} else {
|
} else {
|
||||||
meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node->scope());
|
meta_func_graph_fn =
|
||||||
|
std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope());
|
||||||
}
|
}
|
||||||
return meta_func_graph_fn;
|
return meta_func_graph_fn;
|
||||||
}
|
}
|
||||||
|
@ -663,14 +665,14 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
|
AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
|
||||||
if (value->isa<FuncGraph>()) {
|
|
||||||
auto func_graph = value->cast<FuncGraphPtr>();
|
|
||||||
return MakeAbstractClosure(func_graph, context);
|
|
||||||
}
|
|
||||||
AnfNodePtr anf_node = nullptr;
|
AnfNodePtr anf_node = nullptr;
|
||||||
if (conf != nullptr) {
|
if (conf != nullptr) {
|
||||||
anf_node = conf->node();
|
anf_node = conf->node();
|
||||||
}
|
}
|
||||||
|
if (value->isa<FuncGraph>()) {
|
||||||
|
auto func_graph = value->cast<FuncGraphPtr>();
|
||||||
|
return MakeAbstractClosure(func_graph, context, anf_node);
|
||||||
|
}
|
||||||
if (value->isa<MetaFuncGraph>()) {
|
if (value->isa<MetaFuncGraph>()) {
|
||||||
auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
|
auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
|
||||||
return MakeAbstractClosure(meta_func_graph, anf_node);
|
return MakeAbstractClosure(meta_func_graph, anf_node);
|
||||||
|
|
|
@ -232,7 +232,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||||
|
|
||||||
const PrimEvaluatorMap &prim_constructors_;
|
const PrimEvaluatorMap &prim_constructors_;
|
||||||
FuncGraphManagerPtr func_graph_manager_;
|
FuncGraphManagerPtr func_graph_manager_;
|
||||||
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;
|
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
|
||||||
AnfNodeConfigMap anfnode_config_map_;
|
AnfNodeConfigMap anfnode_config_map_;
|
||||||
// Use a list to trace multiple evaluators.
|
// Use a list to trace multiple evaluators.
|
||||||
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
|
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
|
||||||
|
|
|
@ -143,14 +143,23 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); }
|
std::size_t PrimitiveAbstractClosure::hash() const {
|
||||||
|
auto hash_value = hash_combine(tid(), prim_->hash());
|
||||||
|
// Keep in sync with operator==() which compares the prim_ pointer;
|
||||||
|
hash_value = hash_combine(hash_value, std::hash<Primitive *>{}(prim_.get()));
|
||||||
|
if (tracking_id() != nullptr) {
|
||||||
|
hash_value = hash_combine(hash_value, tracking_id()->hash());
|
||||||
|
}
|
||||||
|
return hash_value;
|
||||||
|
}
|
||||||
|
|
||||||
bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
|
bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||||
if (!other.isa<FuncGraphAbstractClosure>()) {
|
if (!other.isa<FuncGraphAbstractClosure>()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other);
|
auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other);
|
||||||
if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) {
|
if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ &&
|
||||||
|
tracking_id() == other_fg->tracking_id()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -159,9 +168,11 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||||
std::size_t FuncGraphAbstractClosure::hash() const {
|
std::size_t FuncGraphAbstractClosure::hash() const {
|
||||||
auto hash_value = hash_combine(tid(), func_graph_->hash());
|
auto hash_value = hash_combine(tid(), func_graph_->hash());
|
||||||
hash_value = hash_combine(hash_value, context_->hash());
|
hash_value = hash_combine(hash_value, context_->hash());
|
||||||
|
if (tracking_id() != nullptr) {
|
||||||
|
hash_value = hash_combine(hash_value, tracking_id()->hash());
|
||||||
|
}
|
||||||
return hash_value;
|
return hash_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string FuncGraphAbstractClosure::ToString() const {
|
std::string FuncGraphAbstractClosure::ToString() const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "FuncGraphAbstractClosure: "
|
ss << "FuncGraphAbstractClosure: "
|
||||||
|
@ -174,7 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other);
|
auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other);
|
||||||
if (meta_func_graph_ == other_meta_fg->meta_func_graph_) {
|
if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -182,6 +193,9 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con
|
||||||
|
|
||||||
std::size_t MetaFuncGraphAbstractClosure::hash() const {
|
std::size_t MetaFuncGraphAbstractClosure::hash() const {
|
||||||
auto hash_value = hash_combine(tid(), meta_func_graph_->hash());
|
auto hash_value = hash_combine(tid(), meta_func_graph_->hash());
|
||||||
|
if (tracking_id() != nullptr) {
|
||||||
|
hash_value = hash_combine(hash_value, tracking_id()->hash());
|
||||||
|
}
|
||||||
return hash_value;
|
return hash_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -92,13 +92,15 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom {
|
||||||
// one reference cycle example is Graph::set_output() input0 local variable.
|
// one reference cycle example is Graph::set_output() input0 local variable.
|
||||||
AnfNodeWeakPtr tracking_id_;
|
AnfNodeWeakPtr tracking_id_;
|
||||||
};
|
};
|
||||||
|
using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>;
|
||||||
|
|
||||||
class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
public:
|
public:
|
||||||
// Represents a Graph in a certain Context.
|
// Represents a Graph in a certain Context.
|
||||||
// context: The context, or Context.empty()
|
// context: The context, or Context.empty()
|
||||||
FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
|
FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
|
||||||
: func_graph_(func_graph), context_(context) {
|
const AnfNodePtr &tracking_id = nullptr)
|
||||||
|
: func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
}
|
}
|
||||||
|
@ -109,8 +111,10 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
|
|
||||||
AnalysisContextPtr context() const override { return context_; }
|
AnalysisContextPtr context() const override { return context_; }
|
||||||
|
|
||||||
|
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
|
||||||
|
|
||||||
AbstractFunctionPtr Copy() const override {
|
AbstractFunctionPtr Copy() const override {
|
||||||
return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_);
|
return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const AbstractFunction &other) const override;
|
bool operator==(const AbstractFunction &other) const override;
|
||||||
|
@ -121,13 +125,22 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
private:
|
private:
|
||||||
FuncGraphPtr func_graph_;
|
FuncGraphPtr func_graph_;
|
||||||
AnalysisContextPtr context_;
|
AnalysisContextPtr context_;
|
||||||
|
// To discriminate different usage of same graph by using this tracking_id,
|
||||||
|
// so different tracking_id will produce different FuncGraphAbstractClosure,
|
||||||
|
// different FuncGraphEvaluator.
|
||||||
|
// Espcecially usefull for recursive func graph call, so it will not mess up
|
||||||
|
// the graph_context_ in FuncGraphEvaluator.
|
||||||
|
// Notes: Be careful to use nullptr for this variable.
|
||||||
|
// store it as weak_ptr to break reference cycle.
|
||||||
|
AnfNodeWeakPtr tracking_id_;
|
||||||
};
|
};
|
||||||
using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
|
using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
|
||||||
|
|
||||||
class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
|
class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
public:
|
public:
|
||||||
explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope)
|
explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph,
|
||||||
: meta_func_graph_(meta_func_graph), scope_(scope) {}
|
const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope)
|
||||||
|
: meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {}
|
||||||
~MetaFuncGraphAbstractClosure() override = default;
|
~MetaFuncGraphAbstractClosure() override = default;
|
||||||
MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom)
|
MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom)
|
||||||
|
|
||||||
|
@ -137,7 +150,11 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
|
|
||||||
ScopePtr GetScope() { return scope_; }
|
ScopePtr GetScope() { return scope_; }
|
||||||
|
|
||||||
AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); }
|
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
|
||||||
|
|
||||||
|
AbstractFunctionPtr Copy() const override {
|
||||||
|
return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id());
|
||||||
|
}
|
||||||
bool operator==(const AbstractFunction &other) const override;
|
bool operator==(const AbstractFunction &other) const override;
|
||||||
std::size_t hash() const override;
|
std::size_t hash() const override;
|
||||||
|
|
||||||
|
@ -145,6 +162,9 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MetaFuncGraphPtr meta_func_graph_;
|
MetaFuncGraphPtr meta_func_graph_;
|
||||||
|
// refer the comment in FuncGraphAbstractClosure;
|
||||||
|
// store it as weak_ptr to break reference cycle.
|
||||||
|
AnfNodeWeakPtr tracking_id_;
|
||||||
ScopePtr scope_;
|
ScopePtr scope_;
|
||||||
};
|
};
|
||||||
using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>;
|
using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>;
|
||||||
|
|
|
@ -67,3 +67,62 @@ def test_assign_in_while():
|
||||||
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
|
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
|
||||||
net = Net(input_shape)
|
net = Net(input_shape)
|
||||||
net(x, y, z)
|
net(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dup_context():
|
||||||
|
''' different func_with_fv in net1 and net2 should produce 2 different FuncGraphAbstractClosure and
|
||||||
|
Evaluator.
|
||||||
|
'''
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
def identity(f):
|
||||||
|
return f
|
||||||
|
|
||||||
|
def func_with_fv():
|
||||||
|
return x
|
||||||
|
|
||||||
|
def net1():
|
||||||
|
local_func = identity(func_with_fv)
|
||||||
|
out = local_func() + 20.0
|
||||||
|
return out
|
||||||
|
|
||||||
|
def net2():
|
||||||
|
local_func = identity(func_with_fv)
|
||||||
|
out = local_func() + 15.0
|
||||||
|
return out
|
||||||
|
|
||||||
|
return net1() + net2()
|
||||||
|
|
||||||
|
Net()(5.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_maybe_poly_func():
|
||||||
|
''' different func_with_fv in net1 and net2 may produce poly node. '''
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
def identity(f, inp):
|
||||||
|
return f(inp)
|
||||||
|
|
||||||
|
def func_with_fv(yy):
|
||||||
|
return (x, yy)
|
||||||
|
|
||||||
|
def make_call():
|
||||||
|
out1 = identity(func_with_fv, y)
|
||||||
|
out2 = identity(func_with_fv, z)
|
||||||
|
return (out1, out2)
|
||||||
|
|
||||||
|
return make_call()
|
||||||
|
|
||||||
|
y_input = Tensor(np.array([1, 2]).astype(np.int32))
|
||||||
|
z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32))
|
||||||
|
Net()(1, y_input, z_input)
|
||||||
|
|
Loading…
Reference in New Issue