Only replace the primal user with tuple_getitem once

This commit is contained in:
yujianfeng 2021-07-31 14:14:31 +08:00
parent 41b89ad60c
commit cf4121126a
2 changed files with 142 additions and 25 deletions

View File

@ -909,9 +909,9 @@ CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std:
return primal_user;
}
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
const FuncGraphPtr &primal_graph) {
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
static std::unordered_map<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
const FuncGraphPtr &primal_graph) {
std::vector<CNodePtr> j_users;
std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
const auto &node_user_map = manager->node_users();
// Search primal graph user cnodes.
@ -930,20 +930,22 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
primal_map[fg] = {cnode};
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
// To find J user.
auto j_user = GetJUser(node_user_map, cnode, index);
(void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
j_users.emplace_back(GetJUser(node_user_map, cnode, index));
}
}
for (auto &[primal_user, j_user] : primal_j_pair) {
std::unordered_map<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
for (const auto &j_user : j_users) {
MS_EXCEPTION_IF_NULL(j_user);
auto primal = GetPrimalUser(j_user, primal_map);
if (primal != nullptr) {
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
<< " and J user is: " << j_user->DebugString();
primal_user = primal;
if (primal == nullptr) {
continue;
}
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
<< " and J user is: " << j_user->DebugString();
primal_user_to_j_users[primal].emplace_back(j_user);
}
return primal_j_pair;
return primal_user_to_j_users;
}
static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
@ -1007,26 +1009,32 @@ void DFunctor::EliminatePrimalGraph() {
// Find primal user and paired J user cnodes.
auto manager = primal_graph_->manager();
MS_EXCEPTION_IF_NULL(manager);
auto prim_j_pair = FindPrimalJPair(manager, primal_graph_);
for (auto &[primal_user, j_user] : prim_j_pair) {
if (primal_user == nullptr || j_user == nullptr) {
// Skip if one of them not found.
return;
auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
for (const auto &iter : primal_user_to_j_users) {
auto primal_user = iter.first;
auto &j_users = iter.second;
MS_EXCEPTION_IF_NULL(primal_user);
if (j_users.size() == 1) {
// If both inputs are same except monads, we copy primal monad args to k graph
// so that they can be combined in CSE (common subexpression elimination) pass.
// Only do this when the size of j_users is 1 in order to keep the execution order.
const bool has_monad = CopyMonadArguments(primal_user, j_users[0]);
// Remove the UpdateState nodes after primal_user if need.
if (has_monad) {
RemovePrimalUpdateStates(manager, primal_user);
}
} else {
MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
}
// Replace primal graph with k graph.
auto k_vnode = NewValueNode(k_graph_);
primal_user->set_input(0, k_vnode);
primal_user->set_abstract(j_user->abstract());
// If both inputs are same except monads, we copy primal monad args to k graph
// so that they can be combined in CSE (common subexpression elimination) pass.
const bool has_monad = CopyMonadArguments(primal_user, j_user);
// Remove the UpdateState nodes after primal_user if need.
if (has_monad) {
RemovePrimalUpdateStates(manager, primal_user);
if (j_users.empty()) {
MS_LOG(EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
<< " should be used by at least one other node.";
}
primal_user->set_abstract(j_users[0]->abstract());
// Insert tuple_getitem after primal user cnode.
auto construct_wrapper = primal_user->func_graph();
auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);

View File

@ -252,3 +252,112 @@ def test_limit_lift_fv_scope():
grad_net = GradNet(net)
grad_net.add_flags_recursive(defer_inline=True)
grad_net(x, y)
# Smoke test: inside a single construct, the same sub-network is called
# directly once and is also passed to ops.GradOperation() twice.
# There are no assertions; the test passes when grad(x) runs without raising.
# NOTE(review): judging by the test name, each self.grad(self.net) application
# is presumably a separate J user of the same primal graph -- confirm against
# the auto-diff change this test accompanies.
def test_same_primal_used_by_multi_j():
# Network under test: construct returns its input unchanged.
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
return x
# Wrapper that holds the network and one GradOperation built with default arguments.
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad = ops.GradOperation()
def construct(self, x):
# Direct call of the network.
out = self.net(x)
# Two independent applications of self.grad to the same network, both fed x.
gout = self.grad(self.net)(x)
gout1 = self.grad(self.net)(x)
return out, gout, gout1
# Single-element float32 tensor used as the only input.
x = Tensor(np.array([1.0], dtype=np.float32))
net = Net()
grad = GradNet(net)
# Run once; the returned tuple is not inspected.
grad(x)
# Smoke test: a network that applies P.Adam to its own Parameters is called
# directly once, and self.grad_fn(self.network) is applied twice (two separate
# applications), each with a different sens tensor.
# There are no assertions; the test passes when the final call runs without raising.
# NOTE(review): the "with_monad" suffix presumably refers to side-effect (monad)
# arguments introduced by the Adam parameter update -- confirm.
def test_same_primal_used_by_multi_j_with_monad1():
# Network under test: runs Adam on (var, m, v) and returns self.var.
class AdamNet(nn.Cell):
def __init__(self, var, m, v):
super(AdamNet, self).__init__()
self.apply_adam = P.Adam()
self.var = Parameter(var, name="var")
self.m = Parameter(m, name="m")
self.v = Parameter(v, name="v")
def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
# The result of apply_adam is discarded; self.var is what gets returned.
self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
return self.var
# Wrapper holding the network, a GradOperation with sens_param=True and two sens tensors.
class AdamGradNet(nn.Cell):
def __init__(self, network):
super(AdamGradNet, self).__init__()
self.grad_fn = ops.GradOperation(sens_param=True)
# Two all-ones float32 tensors of shape [3, 3, 3], passed as the trailing sens argument below.
self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
self.network = network
def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
# Direct call of the network.
out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
# self.grad_fn is applied to the network separately for each gradient call.
gout1 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
gout2 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
return out, gout1, gout2
# Initial values for the three Parameters: all-ones float32 tensors of shape [3, 3, 3].
var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
# Single-element float32 tensors for the scalar Adam inputs.
beta1_power = Tensor(np.array([0.9], dtype=np.float32))
beta2_power = Tensor(np.array([0.999], dtype=np.float32))
lr = Tensor(np.array([0.001], dtype=np.float32))
beta1 = Tensor(np.array([0.9], dtype=np.float32))
beta2 = Tensor(np.array([0.999], dtype=np.float32))
epsilon = Tensor(np.array([1e-8], dtype=np.float32))
# Random gradient input (unseeded), same [3, 3, 3] shape as the Parameters.
grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
net = AdamNet(var, m, v)
grad_net = AdamGradNet(net)
# Run once; the returned tuple is not inspected.
grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
# Smoke test: a network that applies P.Adam to its own Parameters is called
# directly once, and a single grad_fn = self.grad(self.network) is built once
# and then invoked twice, each time with a different sens tensor.
# There are no assertions; the test passes when the final call runs without raising.
# NOTE(review): the "with_monad" suffix presumably refers to side-effect (monad)
# arguments introduced by the Adam parameter update -- confirm.
def test_same_primal_used_by_multi_j_with_monad2():
# Network under test: runs Adam on (var, m, v) and returns self.var.
class AdamNet(nn.Cell):
def __init__(self, var, m, v):
super(AdamNet, self).__init__()
self.apply_adam = P.Adam()
self.var = Parameter(var, name="var")
self.m = Parameter(m, name="m")
self.v = Parameter(v, name="v")
def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
# The result of apply_adam is discarded; self.var is what gets returned.
self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
return self.var
# Wrapper holding the network, a GradOperation with sens_param=True and two sens tensors.
class AdamGradNet(nn.Cell):
def __init__(self, network):
super(AdamGradNet, self).__init__()
self.grad = ops.GradOperation(sens_param=True)
# Two all-ones float32 tensors of shape [3, 3, 3], passed as the trailing sens argument below.
self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
self.network = network
def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
# Direct call of the network.
out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
# The gradient function is created once here and reused for both gradient calls.
grad_fn = self.grad(self.network)
gout1 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
gout2 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
return out, gout1, gout2
# Initial values for the three Parameters: all-ones float32 tensors of shape [3, 3, 3].
var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
# Single-element float32 tensors for the scalar Adam inputs.
beta1_power = Tensor(np.array([0.9], dtype=np.float32))
beta2_power = Tensor(np.array([0.999], dtype=np.float32))
lr = Tensor(np.array([0.001], dtype=np.float32))
beta1 = Tensor(np.array([0.9], dtype=np.float32))
beta2 = Tensor(np.array([0.999], dtype=np.float32))
epsilon = Tensor(np.array([1e-8], dtype=np.float32))
# Random gradient input (unseeded), same [3, 3, 3] shape as the Parameters.
grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
net = AdamNet(var, m, v)
grad_net = AdamGradNet(net)
# Run once; the returned tuple is not inspected.
grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)