!20505 limit the scope of lift free variable
Merge pull request !20505 from xychow/limit-lift-scope
This commit is contained in:
commit
ed4c9682b5
|
@ -444,9 +444,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
|
|||
}
|
||||
}
|
||||
|
||||
void Cloner::Lift() {
|
||||
void Cloner::Lift(const std::vector<FuncGraphPtr> &sorted) {
|
||||
// lift inner graph first
|
||||
auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin()));
|
||||
for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
|
||||
auto func_graph = *r_iter;
|
||||
auto iter = repl_func_graph_params_.find(func_graph);
|
||||
|
@ -459,14 +458,14 @@ void Cloner::Lift() {
|
|||
}
|
||||
}
|
||||
|
||||
void Cloner::LiftParameters() {
|
||||
void Cloner::LiftParameters(const FuncGraphPtr &lift_top_func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
transaction_ = manager_->Transact();
|
||||
const FuncGraphSet &func_graphs = manager_->func_graphs();
|
||||
const auto &func_graphs = BroadFirstSearchGraphUsed(lift_top_func_graph);
|
||||
for (auto &func_graph : func_graphs) {
|
||||
GenParameters(func_graph);
|
||||
}
|
||||
Lift();
|
||||
Lift(func_graphs);
|
||||
for (auto &func_graph : func_graphs) {
|
||||
SetEdges(func_graph);
|
||||
}
|
||||
|
@ -542,7 +541,7 @@ void Cloner::Run() {
|
|||
// Lifting Clone
|
||||
CloneInfo item = todo_.back();
|
||||
manager_ = Manage(item.origin);
|
||||
LiftParameters();
|
||||
LiftParameters(item.origin);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -99,8 +99,8 @@ class Cloner {
|
|||
void SetEdges(const FuncGraphPtr &func_graph);
|
||||
void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
|
||||
const AnfNodePtrList ¶ms);
|
||||
void Lift();
|
||||
void LiftParameters();
|
||||
void Lift(const std::vector<FuncGraphPtr> &sorted);
|
||||
void LiftParameters(const FuncGraphPtr &lift_top_func_graph);
|
||||
|
||||
bool clone_all_valuenodes_;
|
||||
bool clone_all_child_graphs_;
|
||||
|
|
|
@ -180,7 +180,7 @@ def test_ad_fv_cnode_order():
|
|||
|
||||
# True and False branch of switch have different number of parameters.
|
||||
def test_if_branch_with_different_params():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -211,3 +211,44 @@ def test_if_branch_with_different_params():
|
|||
net = Net()
|
||||
grad_net = GradNet(net)
|
||||
grad_net(idx, end, x)
|
||||
|
||||
|
||||
# Only lift fv in scope of lift_top_func_graph other than all func_graphs inside manager.
|
||||
# Otherwise, "Illegal AnfNode for evaluating" may be reported
|
||||
# because weight1 in Net may use old_parameter other than replicated one.
|
||||
def test_limit_lift_fv_scope():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
|
||||
|
||||
def construct(self, x, y):
|
||||
def inner_add(a, b):
|
||||
return a + b
|
||||
|
||||
out = inner_add(x, y) + self.weight1
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, x, y):
|
||||
def inner_grad_add(a, b):
|
||||
return a + b
|
||||
|
||||
d_weight = grad_by_list(self.net, self.weights)(x, y)[0]
|
||||
d_out = inner_grad_add(d_weight, y)
|
||||
return d_out
|
||||
|
||||
x = Tensor(np.array([2.0], dtype=np.float32))
|
||||
y = Tensor(np.array([2.0], dtype=np.float32))
|
||||
|
||||
net = Net()
|
||||
net.add_flags_recursive(defer_inline=True)
|
||||
grad_net = GradNet(net)
|
||||
grad_net.add_flags_recursive(defer_inline=True)
|
||||
grad_net(x, y)
|
||||
|
|
Loading…
Reference in New Issue