!20505 limit the scope of lift free variable

Merge pull request !20505 from xychow/limit-lift-scope
This commit is contained in:
i-robot 2021-07-20 06:10:17 +00:00 committed by Gitee
commit ed4c9682b5
3 changed files with 49 additions and 9 deletions

View File

@ -444,9 +444,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
}
}
void Cloner::Lift() {
void Cloner::Lift(const std::vector<FuncGraphPtr> &sorted) {
// lift inner graph first
auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin()));
for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
auto func_graph = *r_iter;
auto iter = repl_func_graph_params_.find(func_graph);
@ -459,14 +458,14 @@ void Cloner::Lift() {
}
}
void Cloner::LiftParameters() {
void Cloner::LiftParameters(const FuncGraphPtr &lift_top_func_graph) {
MS_EXCEPTION_IF_NULL(manager_);
transaction_ = manager_->Transact();
const FuncGraphSet &func_graphs = manager_->func_graphs();
const auto &func_graphs = BroadFirstSearchGraphUsed(lift_top_func_graph);
for (auto &func_graph : func_graphs) {
GenParameters(func_graph);
}
Lift();
Lift(func_graphs);
for (auto &func_graph : func_graphs) {
SetEdges(func_graph);
}
@ -542,7 +541,7 @@ void Cloner::Run() {
// Lifting Clone
CloneInfo item = todo_.back();
manager_ = Manage(item.origin);
LiftParameters();
LiftParameters(item.origin);
}
}

View File

@ -99,8 +99,8 @@ class Cloner {
void SetEdges(const FuncGraphPtr &func_graph);
void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList &params);
void Lift();
void LiftParameters();
void Lift(const std::vector<FuncGraphPtr> &sorted);
void LiftParameters(const FuncGraphPtr &lift_top_func_graph);
bool clone_all_valuenodes_;
bool clone_all_child_graphs_;

View File

@ -180,7 +180,7 @@ def test_ad_fv_cnode_order():
# True and False branch of switch have different number of parameters.
def test_if_branch_with_different_params():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -211,3 +211,44 @@ def test_if_branch_with_different_params():
net = Net()
grad_net = GradNet(net)
grad_net(idx, end, x)
# Only lift fv in the scope of lift_top_func_graph rather than all func_graphs inside the manager.
# Otherwise, "Illegal AnfNode for evaluating" may be reported
# because weight1 in Net may use the old parameter rather than the replicated one.
# Regression test added by this commit: exercises the free-variable lifting
# pass on a graph hierarchy (GradNet wrapping Net, with nested inner
# functions and a Parameter used as a free variable).
# NOTE: indentation below reflects the diff rendering, not runnable layout.
def test_limit_lift_fv_scope():
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
# Net.construct adds its two inputs via a local helper, then adds the
# weight1 Parameter — weight1 is the free variable the lift pass handles.
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
def construct(self, x, y):
def inner_add(a, b):
return a + b
out = inner_add(x, y) + self.weight1
return out
# GradNet.construct differentiates Net w.r.t. its trainable params
# (presumably producing d(out)/d(weight1) — from grad_by_list's name),
# then combines the gradient with y through another local helper.
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, x, y):
def inner_grad_add(a, b):
return a + b
d_weight = grad_by_list(self.net, self.weights)(x, y)[0]
d_out = inner_grad_add(d_weight, y)
return d_out
x = Tensor(np.array([2.0], dtype=np.float32))
y = Tensor(np.array([2.0], dtype=np.float32))
net = Net()
# defer_inline on both cells keeps the sub-graphs un-inlined, so the
# parameter-lifting code path under test is actually reached.
net.add_flags_recursive(defer_inline=True)
grad_net = GradNet(net)
grad_net.add_flags_recursive(defer_inline=True)
# Success criterion is simply that compilation/execution does not raise
# "Illegal AnfNode for evaluating" (see comment above this test).
grad_net(x, y)