!21168 Only replace the primal user with tuple_getitem once
Merge pull request !21168 from YuJianfeng/ad
commit 8306abeb67
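Judging from the commit title and the new test cases below, the pattern being fixed is a primal graph that is both called directly and differentiated more than once, so a single primal call ends up with several J users; previously each J user could rewrite that primal call with its own tuple_getitem. A minimal sketch of the triggering pattern, modelled on the added test_same_primal_used_by_multi_j (imports and the scalar input here are illustrative, not part of the patch):

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor


class Net(nn.Cell):
    def construct(self, x):
        return x


class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad = ops.GradOperation()

    def construct(self, x):
        out = self.net(x)               # primal user: direct call of the net
        gout = self.grad(self.net)(x)   # first J user of the same primal
        gout1 = self.grad(self.net)(x)  # second J user of the same primal
        return out, gout, gout1


x = Tensor(np.array([1.0], dtype=np.float32))
GradNet(Net())(x)

With this change, FindPrimalJPair returns a map from each primal user cnode to all of its J users, so the tuple_getitem insertion and abstract update are applied once per primal user rather than once per (primal, J user) pair.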
@@ -909,9 +909,9 @@ CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std:
   return primal_user;
 }
 
-static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
-                                                                  const FuncGraphPtr &primal_graph) {
-  std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
+static std::unordered_map<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
+                                                                           const FuncGraphPtr &primal_graph) {
+  std::vector<CNodePtr> j_users;
   std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
   const auto &node_user_map = manager->node_users();
   // Search primal graph user cnodes.
@@ -930,20 +930,22 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
       primal_map[fg] = {cnode};
     } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
       // To find J user.
-      auto j_user = GetJUser(node_user_map, cnode, index);
-      (void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
+      j_users.emplace_back(GetJUser(node_user_map, cnode, index));
     }
   }
 
-  for (auto &[primal_user, j_user] : primal_j_pair) {
+  std::unordered_map<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
+  for (const auto &j_user : j_users) {
+    MS_EXCEPTION_IF_NULL(j_user);
     auto primal = GetPrimalUser(j_user, primal_map);
-    if (primal != nullptr) {
-      MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
-                    << " and J user is: " << j_user->DebugString();
-      primal_user = primal;
+    if (primal == nullptr) {
+      continue;
     }
+    MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
+                  << " and J user is: " << j_user->DebugString();
+    primal_user_to_j_users[primal].emplace_back(j_user);
   }
-  return primal_j_pair;
+  return primal_user_to_j_users;
 }
 
 static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
@@ -1007,26 +1009,32 @@ void DFunctor::EliminatePrimalGraph() {
   // Find primal user and paired J user cnodes.
   auto manager = primal_graph_->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  auto prim_j_pair = FindPrimalJPair(manager, primal_graph_);
-  for (auto &[primal_user, j_user] : prim_j_pair) {
-    if (primal_user == nullptr || j_user == nullptr) {
-      // Skip if one of them not found.
-      return;
+  auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
+  for (const auto &iter : primal_user_to_j_users) {
+    auto primal_user = iter.first;
+    auto &j_users = iter.second;
+    MS_EXCEPTION_IF_NULL(primal_user);
+    if (j_users.size() == 1) {
+      // If both inputs are same except monads, we copy primal monad args to k graph
+      // so that they can be combined in CSE (common subexpression elimination) pass.
+      // Only do this when the size of j_users is 1 in order to keep the execution order.
+      const bool has_monad = CopyMonadArguments(primal_user, j_users[0]);
+      // Remove the UpdateState nodes after primal_user if need.
+      if (has_monad) {
+        RemovePrimalUpdateStates(manager, primal_user);
+      }
+    } else {
+      MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
     }
 
     // Replace primal graph with k graph.
     auto k_vnode = NewValueNode(k_graph_);
     primal_user->set_input(0, k_vnode);
-    primal_user->set_abstract(j_user->abstract());
-    // If both inputs are same except monads, we copy primal monad args to k graph
-    // so that they can be combined in CSE (common subexpression elimination) pass.
-    const bool has_monad = CopyMonadArguments(primal_user, j_user);
-    // Remove the UpdateState nodes after primal_user if need.
-    if (has_monad) {
-      RemovePrimalUpdateStates(manager, primal_user);
+    if (j_users.empty()) {
+      MS_LOG(EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
+                        << " should be used by at least one other node.";
     }
-
+    primal_user->set_abstract(j_users[0]->abstract());
     // Insert tuple_getitem after primal user cnode.
     auto construct_wrapper = primal_user->func_graph();
     auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
@@ -252,3 +252,112 @@ def test_limit_lift_fv_scope():
     grad_net = GradNet(net)
     grad_net.add_flags_recursive(defer_inline=True)
     grad_net(x, y)
+
+
+def test_same_primal_used_by_multi_j():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x):
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.grad = ops.GradOperation()
+
+        def construct(self, x):
+            out = self.net(x)
+            gout = self.grad(self.net)(x)
+            gout1 = self.grad(self.net)(x)
+            return out, gout, gout1
+
+    x = Tensor(np.array([1.0], dtype=np.float32))
+    net = Net()
+    grad = GradNet(net)
+    grad(x)
+
+
+def test_same_primal_used_by_multi_j_with_monad1():
+    class AdamNet(nn.Cell):
+        def __init__(self, var, m, v):
+            super(AdamNet, self).__init__()
+            self.apply_adam = P.Adam()
+            self.var = Parameter(var, name="var")
+            self.m = Parameter(m, name="m")
+            self.v = Parameter(v, name="v")
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            return self.var
+
+    class AdamGradNet(nn.Cell):
+        def __init__(self, network):
+            super(AdamGradNet, self).__init__()
+            self.grad_fn = ops.GradOperation(sens_param=True)
+            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
+            self.network = network
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            gout1 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
+            gout2 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
+            return out, gout1, gout2
+
+    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
+    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
+    lr = Tensor(np.array([0.001], dtype=np.float32))
+    beta1 = Tensor(np.array([0.9], dtype=np.float32))
+    beta2 = Tensor(np.array([0.999], dtype=np.float32))
+    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
+    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    net = AdamNet(var, m, v)
+    grad_net = AdamGradNet(net)
+    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+
+
+def test_same_primal_used_by_multi_j_with_monad2():
+    class AdamNet(nn.Cell):
+        def __init__(self, var, m, v):
+            super(AdamNet, self).__init__()
+            self.apply_adam = P.Adam()
+            self.var = Parameter(var, name="var")
+            self.m = Parameter(m, name="m")
+            self.v = Parameter(v, name="v")
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            return self.var
+
+    class AdamGradNet(nn.Cell):
+        def __init__(self, network):
+            super(AdamGradNet, self).__init__()
+            self.grad = ops.GradOperation(sens_param=True)
+            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
+            self.network = network
+
+        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+            grad_fn = self.grad(self.network)
+            gout1 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
+            gout2 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
+            return out, gout1, gout2
+
+    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
+    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
+    lr = Tensor(np.array([0.001], dtype=np.float32))
+    beta1 = Tensor(np.array([0.9], dtype=np.float32))
+    beta2 = Tensor(np.array([0.999], dtype=np.float32))
+    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
+    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    net = AdamNet(var, m, v)
+    grad_net = AdamGradNet(net)
+    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)