forked from mindspore-Ecosystem/mindspore
Only replace the primal user with tuple_getitem once
This commit is contained in:
parent
41b89ad60c
commit
cf4121126a
|
@ -909,9 +909,9 @@ CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std:
|
|||
return primal_user;
|
||||
}
|
||||
|
||||
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
|
||||
static std::unordered_map<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
|
||||
const FuncGraphPtr &primal_graph) {
|
||||
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
|
||||
std::vector<CNodePtr> j_users;
|
||||
std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
|
||||
const auto &node_user_map = manager->node_users();
|
||||
// Search primal graph user cnodes.
|
||||
|
@ -930,20 +930,22 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
|
|||
primal_map[fg] = {cnode};
|
||||
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
|
||||
// To find J user.
|
||||
auto j_user = GetJUser(node_user_map, cnode, index);
|
||||
(void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
|
||||
j_users.emplace_back(GetJUser(node_user_map, cnode, index));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &[primal_user, j_user] : primal_j_pair) {
|
||||
std::unordered_map<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
|
||||
for (const auto &j_user : j_users) {
|
||||
MS_EXCEPTION_IF_NULL(j_user);
|
||||
auto primal = GetPrimalUser(j_user, primal_map);
|
||||
if (primal != nullptr) {
|
||||
if (primal == nullptr) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
|
||||
<< " and J user is: " << j_user->DebugString();
|
||||
primal_user = primal;
|
||||
primal_user_to_j_users[primal].emplace_back(j_user);
|
||||
}
|
||||
}
|
||||
return primal_j_pair;
|
||||
return primal_user_to_j_users;
|
||||
}
|
||||
|
||||
static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
|
||||
|
@ -1007,26 +1009,32 @@ void DFunctor::EliminatePrimalGraph() {
|
|||
// Find primal user and paired J user cnodes.
|
||||
auto manager = primal_graph_->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto prim_j_pair = FindPrimalJPair(manager, primal_graph_);
|
||||
for (auto &[primal_user, j_user] : prim_j_pair) {
|
||||
if (primal_user == nullptr || j_user == nullptr) {
|
||||
// Skip if one of them not found.
|
||||
return;
|
||||
auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
|
||||
for (const auto &iter : primal_user_to_j_users) {
|
||||
auto primal_user = iter.first;
|
||||
auto &j_users = iter.second;
|
||||
MS_EXCEPTION_IF_NULL(primal_user);
|
||||
if (j_users.size() == 1) {
|
||||
// If both inputs are the same except for monads, we copy primal monad args to k graph
|
||||
// so that they can be combined in CSE (common subexpression elimination) pass.
|
||||
// Only do this when the size of j_users is 1 in order to keep the execution order.
|
||||
const bool has_monad = CopyMonadArguments(primal_user, j_users[0]);
|
||||
// Remove the UpdateState nodes after primal_user if needed.
|
||||
if (has_monad) {
|
||||
RemovePrimalUpdateStates(manager, primal_user);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
|
||||
}
|
||||
|
||||
// Replace primal graph with k graph.
|
||||
auto k_vnode = NewValueNode(k_graph_);
|
||||
primal_user->set_input(0, k_vnode);
|
||||
primal_user->set_abstract(j_user->abstract());
|
||||
|
||||
// If both inputs are the same except for monads, we copy primal monad args to k graph
|
||||
// so that they can be combined in CSE (common subexpression elimination) pass.
|
||||
const bool has_monad = CopyMonadArguments(primal_user, j_user);
|
||||
// Remove the UpdateState nodes after primal_user if needed.
|
||||
if (has_monad) {
|
||||
RemovePrimalUpdateStates(manager, primal_user);
|
||||
if (j_users.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
|
||||
<< " should be used by at least one other node.";
|
||||
}
|
||||
|
||||
primal_user->set_abstract(j_users[0]->abstract());
|
||||
// Insert tuple_getitem after primal user cnode.
|
||||
auto construct_wrapper = primal_user->func_graph();
|
||||
auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
|
||||
|
|
|
@ -252,3 +252,112 @@ def test_limit_lift_fv_scope():
|
|||
grad_net = GradNet(net)
|
||||
grad_net.add_flags_recursive(defer_inline=True)
|
||||
grad_net(x, y)
|
||||
|
||||
|
||||
def test_same_primal_used_by_multi_j():
    """
    Smoke test: one network is called directly and differentiated twice in a single graph.

    GradNet.construct runs `self.net(x)` once and evaluates `self.grad(self.net)(x)` in two
    separate statements, so the plain forward call shares its network with two gradient calls.
    No value is asserted; the test passes when the call finishes without raising.
    """

    class Net(nn.Cell):
        """Identity network: construct hands its input straight back."""

        def __init__(self):
            super().__init__()

        def construct(self, x):
            return x

    class GradNet(nn.Cell):
        """Wraps a network and returns its output together with two gradient results."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.grad = ops.GradOperation()

        def construct(self, x):
            # The gradient function is intentionally rebuilt for each call rather than
            # being shared through a local variable; do not merge these two statements.
            out = self.net(x)
            gout = self.grad(self.net)(x)
            gout1 = self.grad(self.net)(x)
            return out, gout, gout1

    input_x = Tensor(np.array([1.0], dtype=np.float32))
    grad_net = GradNet(Net())
    grad_net(input_x)
|
||||
|
||||
|
||||
def test_same_primal_used_by_multi_j_with_monad1():
    """
    Smoke test: a Parameter-updating network is called once and differentiated twice,
    with the gradient function rebuilt for every gradient call.

    AdamNet applies P.Adam to its own `var`/`m`/`v` Parameters and returns `self.var`.
    AdamGradNet.construct calls the network directly and then evaluates
    `self.grad_fn(self.network)(...)` in two separate statements, each with its own
    sensitivity tensor. No value is asserted; the test passes when the call finishes
    without raising.
    """

    def ones_cube():
        # A fresh 3x3x3 float32 tensor of ones on every call (no buffer is shared).
        return Tensor(np.ones([3, 3, 3]).astype(np.float32))

    def scalar(value):
        # One-element float32 tensor holding `value`.
        return Tensor(np.array([value], dtype=np.float32))

    class AdamNet(nn.Cell):
        """Holds var/m/v as Parameters and applies one Adam step in construct."""

        def __init__(self, var, m, v):
            super().__init__()
            self.apply_adam = P.Adam()
            self.var = Parameter(var, name="var")
            self.m = Parameter(m, name="m")
            self.v = Parameter(v, name="v")

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            return self.var

    class AdamGradNet(nn.Cell):
        """Returns the network output plus two gradient results, one per sens tensor."""

        def __init__(self, network):
            super().__init__()
            self.grad_fn = ops.GradOperation(sens_param=True)
            self.sens = [ones_cube(), ones_cube()]
            self.network = network

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            # `self.grad_fn(self.network)` is deliberately evaluated twice; do not hoist it
            # into a shared local.
            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            gout1 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
            gout2 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
            return out, gout1, gout2

    # Optimizer state for the network under test.
    init_var = ones_cube()
    init_m = ones_cube()
    init_v = ones_cube()

    # Positional inputs, in the order construct expects them.
    adam_inputs = (
        scalar(0.9),    # beta1_power
        scalar(0.999),  # beta2_power
        scalar(0.001),  # lr
        scalar(0.9),    # beta1
        scalar(0.999),  # beta2
        scalar(1e-8),   # epsilon
        Tensor(np.random.rand(3, 3, 3).astype(np.float32)),  # grad
    )

    adam_net = AdamNet(init_var, init_m, init_v)
    adam_grad_net = AdamGradNet(adam_net)
    adam_grad_net(*adam_inputs)
|
||||
|
||||
|
||||
def test_same_primal_used_by_multi_j_with_monad2():
    """
    Smoke test: a Parameter-updating network is called once and differentiated twice
    through a single shared gradient function.

    AdamNet applies P.Adam to its own `var`/`m`/`v` Parameters and returns `self.var`.
    AdamGradNet.construct calls the network directly, builds `grad_fn = self.grad(self.network)`
    once, and invokes that one `grad_fn` twice, each time with its own sensitivity tensor.
    No value is asserted; the test passes when the call finishes without raising.
    """

    def ones_cube():
        # A fresh 3x3x3 float32 tensor of ones on every call (no buffer is shared).
        return Tensor(np.ones([3, 3, 3]).astype(np.float32))

    def scalar(value):
        # One-element float32 tensor holding `value`.
        return Tensor(np.array([value], dtype=np.float32))

    class AdamNet(nn.Cell):
        """Holds var/m/v as Parameters and applies one Adam step in construct."""

        def __init__(self, var, m, v):
            super().__init__()
            self.apply_adam = P.Adam()
            self.var = Parameter(var, name="var")
            self.m = Parameter(m, name="m")
            self.v = Parameter(v, name="v")

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            return self.var

    class AdamGradNet(nn.Cell):
        """Returns the network output plus two gradient results, one per sens tensor."""

        def __init__(self, network):
            super().__init__()
            self.grad = ops.GradOperation(sens_param=True)
            self.sens = [ones_cube(), ones_cube()]
            self.network = network

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            # The gradient function is built once and reused for both calls; do not
            # inline `self.grad(self.network)` into the two call sites.
            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            grad_fn = self.grad(self.network)
            gout1 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
            gout2 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
            return out, gout1, gout2

    # Optimizer state for the network under test.
    init_var = ones_cube()
    init_m = ones_cube()
    init_v = ones_cube()

    # Positional inputs, in the order construct expects them.
    adam_inputs = (
        scalar(0.9),    # beta1_power
        scalar(0.999),  # beta2_power
        scalar(0.001),  # lr
        scalar(0.9),    # beta1
        scalar(0.999),  # beta2
        scalar(1e-8),   # epsilon
        Tensor(np.random.rand(3, 3, 3).astype(np.float32)),  # grad
    )

    adam_net = AdamNet(init_var, init_m, init_v)
    adam_grad_net = AdamGradNet(adam_net)
    adam_grad_net(*adam_inputs)
|
||||
|
|
Loading…
Reference in New Issue