From cf4121126a31679693399e9828d359b539e72258 Mon Sep 17 00:00:00 2001
From: yujianfeng
Date: Sat, 31 Jul 2021 14:14:31 +0800
Subject: [PATCH] Only replace the primal user with tuple_getitem once

---
 .../ccsrc/frontend/optimizer/ad/dfunctor.cc |  58 ++++++----
 tests/ut/python/optimizer/test_auto_grad.py | 109 ++++++++++++++++++
 2 files changed, 142 insertions(+), 25 deletions(-)

diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
index 86d0bf78cc0..5f35bc96558 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
@@ -909,9 +909,9 @@ CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std
-static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
-                                                                   const FuncGraphPtr &primal_graph) {
-  std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
+static std::unordered_map<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
+                                                                            const FuncGraphPtr &primal_graph) {
+  std::vector<CNodePtr> j_users;
   std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
   const auto &node_user_map = manager->node_users();
   // Search primal graph user cnodes.
@@ -930,20 +930,22 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
       primal_map[fg] = {cnode};
     } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
       // To find J user.
-      auto j_user = GetJUser(node_user_map, cnode, index);
-      (void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
+      j_users.emplace_back(GetJUser(node_user_map, cnode, index));
     }
   }

-  for (auto &[primal_user, j_user] : primal_j_pair) {
+  std::unordered_map<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
+  for (const auto &j_user : j_users) {
+    MS_EXCEPTION_IF_NULL(j_user);
     auto primal = GetPrimalUser(j_user, primal_map);
-    if (primal != nullptr) {
-      MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
-                    << " and J user is: " << j_user->DebugString();
-      primal_user = primal;
+    if (primal == nullptr) {
+      continue;
     }
+    MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
+                  << " and J user is: " << j_user->DebugString();
+    primal_user_to_j_users[primal].emplace_back(j_user);
   }
-  return primal_j_pair;
+  return primal_user_to_j_users;
 }

 static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
@@ -1007,26 +1009,32 @@ void DFunctor::EliminatePrimalGraph() {
   // Find primal user and paired J user cnodes.
   auto manager = primal_graph_->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  auto prim_j_pair = FindPrimalJPair(manager, primal_graph_);
-  for (auto &[primal_user, j_user] : prim_j_pair) {
-    if (primal_user == nullptr || j_user == nullptr) {
-      // Skip if one of them not found.
-      return;
+  auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
+  for (const auto &iter : primal_user_to_j_users) {
+    auto primal_user = iter.first;
+    auto &j_users = iter.second;
+    MS_EXCEPTION_IF_NULL(primal_user);
+    if (j_users.size() == 1) {
+      // If both inputs are same except monads, we copy primal monad args to k graph
+      // so that they can be combined in CSE (common subexpression elimination) pass.
+      // Only do this when the size of j_users is 1 in order to keep the execution order.
+      const bool has_monad = CopyMonadArguments(primal_user, j_users[0]);
+      // Remove the UpdateState nodes after primal_user if need.
+      if (has_monad) {
+        RemovePrimalUpdateStates(manager, primal_user);
+      }
+    } else {
+      MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
     }

     // Replace primal graph with k graph.
     auto k_vnode = NewValueNode(k_graph_);
     primal_user->set_input(0, k_vnode);
-    primal_user->set_abstract(j_user->abstract());
-
-    // If both inputs are same except monads, we copy primal monad args to k graph
-    // so that they can be combined in CSE (common subexpression elimination) pass.
-    const bool has_monad = CopyMonadArguments(primal_user, j_user);
-    // Remove the UpdateState nodes after primal_user if need.
-    if (has_monad) {
-      RemovePrimalUpdateStates(manager, primal_user);
+    if (j_users.empty()) {
+      MS_LOG(EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
+                        << " should be used by at least one other node.";
     }
-
+    primal_user->set_abstract(j_users[0]->abstract());
     // Insert tuple_getitem after primal user cnode.
     auto construct_wrapper = primal_user->func_graph();
     auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
diff --git a/tests/ut/python/optimizer/test_auto_grad.py b/tests/ut/python/optimizer/test_auto_grad.py
index 3314472176a..ca5e7a85f00 100644
--- a/tests/ut/python/optimizer/test_auto_grad.py
+++ b/tests/ut/python/optimizer/test_auto_grad.py
@@ -252,3 +252,112 @@ def test_limit_lift_fv_scope():
     grad_net = GradNet(net)
     grad_net.add_flags_recursive(defer_inline=True)
     grad_net(x, y)
+
+
+def test_same_primal_used_by_multi_j():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x):
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.grad = ops.GradOperation()
+
+        def construct(self, x):
+            out = self.net(x)
+            gout = self.grad(self.net)(x)
+            gout1 = self.grad(self.net)(x)
+            return out, gout, gout1
+
+    x = Tensor(np.array([1.0], dtype=np.float32))
+    net = Net()
+    grad = GradNet(net)
+    grad(x)
+
+
+def test_same_primal_used_by_multi_j_with_monad1():
+    class AdamNet(nn.Cell):
+        def __init__(self, var, m, v):
+            super(AdamNet, self).__init__()
+            self.apply_adam = P.Adam()
+            self.var = Parameter(var, name="var")
+            self.m = Parameter(m, name="m")
+            self.v = Parameter(v, name="v")
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            return self.var
+
+    class AdamGradNet(nn.Cell):
+        def __init__(self, network):
+            super(AdamGradNet, self).__init__()
+            self.grad_fn = ops.GradOperation(sens_param=True)
+            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
+            self.network = network
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            gout1 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
+            gout2 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
+            return out, gout1, gout2
+
+    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
+    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
+    lr = Tensor(np.array([0.001], dtype=np.float32))
+    beta1 = Tensor(np.array([0.9], dtype=np.float32))
+    beta2 = Tensor(np.array([0.999], dtype=np.float32))
+    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
+    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    net = AdamNet(var, m, v)
+    grad_net = AdamGradNet(net)
+    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+
+
+def test_same_primal_used_by_multi_j_with_monad2():
+    class AdamNet(nn.Cell):
+        def __init__(self, var, m, v):
+            super(AdamNet, self).__init__()
+            self.apply_adam = P.Adam()
+            self.var = Parameter(var, name="var")
+            self.m = Parameter(m, name="m")
+            self.v = Parameter(v, name="v")
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            return self.var
+
+    class AdamGradNet(nn.Cell):
+        def __init__(self, network):
+            super(AdamGradNet, self).__init__()
+            self.grad = ops.GradOperation(sens_param=True)
+            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
+            self.network = network
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            grad_fn = self.grad(self.network)
+            gout1 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
+            gout2 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
+            return out, gout1, gout2
+
+    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
+    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
+    lr = Tensor(np.array([0.001], dtype=np.float32))
+    beta1 = Tensor(np.array([0.9], dtype=np.float32))
+    beta2 = Tensor(np.array([0.999], dtype=np.float32))
+    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
+    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    net = AdamNet(var, m, v)
+    grad_net = AdamGradNet(net)
+    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
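
For illustration only, and not part of the patch itself: a minimal plain-Python sketch of the grouping that the reworked FindPrimalJPair performs, so that the tuple_getitem rewrite in EliminatePrimalGraph touches each primal call once even when several J users share it. The names group_j_users_by_primal and primal_lookup are hypothetical stand-ins (primal_lookup plays the role of GetPrimalUser); the real code operates on MindSpore CNode pointers inside dfunctor.cc.

    # Illustrative sketch only; group_j_users_by_primal and primal_lookup are
    # hypothetical names that do not exist in the MindSpore sources.
    from collections import defaultdict

    def group_j_users_by_primal(j_users, primal_lookup):
        """Map each primal-user node to every J user that shares it, mirroring the
        unordered_map<CNodePtr, vector<CNodePtr>> returned by FindPrimalJPair."""
        primal_user_to_j_users = defaultdict(list)
        for j_user in j_users:
            primal = primal_lookup(j_user)
            if primal is None:
                # No matching primal call for this J user; skip it, as the C++ loop does with `continue`.
                continue
            primal_user_to_j_users[primal].append(j_user)
        return dict(primal_user_to_j_users)

    # Two GradOperation calls on the same network yield two J users that share one
    # primal call; after grouping, the rewrite visits "primal_call" only once.
    pairs = {"j_user_1": "primal_call", "j_user_2": "primal_call"}
    grouped = group_j_users_by_primal(["j_user_1", "j_user_2"], pairs.get)
    assert grouped == {"primal_call": ["j_user_1", "j_user_2"]}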