Fix execution order problem: the users of MakeTuple(Load, ...) do not

attach UpdateState
This commit is contained in:
Margaret_wangrui 2021-07-26 16:23:56 +08:00
parent 022c1c4583
commit 16466f3453
2 changed files with 140 additions and 6 deletions

View File

@ -37,7 +37,13 @@ class OrderEnforcer {
// Entry point: visits every node returned by MakeTopoSortMap() and dispatches
// UpdateState nodes and MakeTuple nodes to their dedicated handlers.
void Run() {
auto nodes = MakeTopoSortMap();
for (auto &node : nodes) {
// NOTE(review): this call looks like the pre-change line of the diff (the
// UpdateState-only filtering it performed is now done by the branch below);
// confirm it is absent from the final file.
HandleNode(node);
if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
HandleUpdateState(node);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
// op(MakeTuple(Load, ...)) sometimes does not attach update_state,
// so special treatment is needed to ensure the exec_order of MakeTuple users.
HandleMakeTupleUsers(node);
}
}
}
@ -50,11 +56,7 @@ class OrderEnforcer {
return nodes;
}
void HandleNode(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
// Skip nodes other than UpdateState.
return;
}
void HandleUpdateState(const AnfNodePtr &node) {
auto update_state = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(update_state);
const size_t update_state_inputs_size = 3;
@ -74,6 +76,98 @@ class OrderEnforcer {
}
}
// Reports whether any input of `cnode` at index >= 1 is a Load CNode.
// Index 0 is never inspected (presumably the primitive slot of the CNode).
bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
  const size_t input_count = cnode->inputs().size();
  size_t pos = 1;
  while (pos < input_count) {
    if (IsPrimitiveCNode(cnode->input(pos), prim::kPrimLoad)) {
      return true;
    }
    ++pos;
  }
  return false;
}
std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(cnode);
if (iter == node_users.end()) {
return {};
}
std::vector<AnfNodePtr> update_states;
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
update_states.emplace_back(user_node);
} else if (IsPrimitiveCNode(user_node, prim::kPrimMakeTuple)) {
auto make_tuple_users = FindUpdateStateUsers(user_node->cast<CNodePtr>());
for (auto make_tuple_user : make_tuple_users) {
if (IsPrimitiveCNode(make_tuple_user, prim::kPrimUpdateState)) {
update_states.emplace_back(make_tuple_user);
}
}
}
}
return update_states;
}
// Finds, among all UpdateState nodes that consume the Load inputs of `cnode`
// (directly or through MakeTuple), the one with the greatest topo_sort_map_
// order, i.e. the one treated as executing last.
//
// Returns nullptr when no such UpdateState exists. When several candidates
// share the greatest order, the one found latest is returned.
AnfNodePtr FindLastUpdateState(const CNodePtr &cnode) {
  std::vector<AnfNodePtr> all_update_states;
  const size_t input_count = cnode->inputs().size();
  for (size_t index = 1; index < input_count; index++) {
    auto input = cnode->input(index);
    if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
      std::vector<AnfNodePtr> update_states = FindUpdateStateUsers(input->cast<CNodePtr>());
      all_update_states.insert(all_update_states.end(), update_states.begin(), update_states.end());
    }
  }
  if (all_update_states.empty()) {
    return nullptr;
  }
  // Keep a running maximum. Comparing only adjacent pairs (as was done before)
  // can miss the true maximum, or yield nullptr when the sequence is
  // decreasing, even though candidates exist.
  AnfNodePtr last_update_state = all_update_states[0];
  for (size_t i = 1; i < all_update_states.size(); i++) {
    const auto &candidate = all_update_states[i];
    if (topo_sort_map_[last_update_state] <= topo_sort_map_[candidate]) {
      last_update_state = candidate;
    }
  }
  return last_update_state;
}
// Convert:
// load1 = Load(para1, u1)
// load2 = Load(para2, u2)
// maketuple1 = MakeTuple(inputs, load1, load2)
// addn = AddN(maketuple1) or other-op
// maketuple2 = MakeTuple(load1, load2)
// u3 = UpdateState(u', maketuple2)
// assign = Assign(para2, inputs, u3)
// To:
// load1 = Load(para1, u1)
// load2 = Load(para2, u2)
// maketuple1 = MakeTuple(inputs, load1, load2)
// addn = AddN(maketuple1) or other-op
// maketuple2 = MakeTuple(load1, load2)
// u3 = UpdateState(u', maketuple2, addn) # need put addn or other-op into u3 inputs
// assign = Assign(para2, inputs, u3)
void HandleMakeTupleUsers(const AnfNodePtr &node) {
  auto maketuple = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(maketuple);
  // Only a MakeTuple that carries at least one Load input is of interest.
  if (!CheckMakeTupleHaveLoad(maketuple)) {
    return;
  }
  // Nothing to attach to when no UpdateState consumes those Loads.
  auto last_state = FindLastUpdateState(maketuple);
  if (last_state == nullptr) {
    return;
  }
  // Add the real users of the MakeTuple as extra inputs of that UpdateState.
  std::unordered_set<AnfNodePtr> real_users = GetSpecialOperatorRealUsers(maketuple);
  auto last_state_cnode = last_state->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(last_state_cnode);
  AddInputEdges(last_state_cnode, real_users);
}
bool IsRef(const AnfNodePtr &node) {
auto &abs = node->abstract();
return abs != nullptr && abs->isa<abstract::AbstractRef>();

View File

@ -81,3 +81,43 @@ def test_auto_monad_addn_adam():
allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)
# Network under test: interleaves writes to two parameters (Assign) with AddN
# calls whose tuple argument reads those same parameters.
class AutoMonadTwoAssignTwoAddnDependencyNet(Cell):
def __init__(self):
super().__init__()
self.parameter1 = ms.Parameter(Tensor([1.0], ms.float32), name="parameter1")
self.parameter2 = ms.Parameter(Tensor([3.0], ms.float32), name="parameter2")
self.assign = P.Assign()
self.addN = P.AddN()
def construct(self, inputs):
# Write parameter1 before the first AddN reads it.
self.assign(self.parameter1, inputs)
# First AddN reads parameter1 and parameter2; it appears before the write to
# parameter2 below.
out = self.addN((inputs, self.parameter1, self.parameter2))
# Write parameter2 after the first AddN and before the second one.
self.assign(self.parameter2, inputs)
# Second AddN reads both parameters again, after both writes.
out = self.addN((out, self.parameter1, self.parameter2))
return out
# Reference network without any Assign: each parameter read that follows an
# assignment in AutoMonadTwoAssignTwoAddnDependencyNet is replaced by `inputs`.
class AutoMonadTwoAssignTwoAddnDependencyBenchmarkNet(Cell):
def __init__(self):
super().__init__()
self.parameter2 = ms.Parameter(Tensor([3.0], ms.float32), name="parameter2")
self.addN = P.AddN()
def construct(self, inputs):
# parameter2 is still read here with its initial value.
out = self.addN((inputs, inputs, self.parameter2))
# Both parameter slots are replaced by `inputs` here.
out = self.addN((out, inputs, inputs))
return out
# Runs the Assign/AddN network and the Assign-free reference network on the
# same input and checks that their outputs match.
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_auto_monad_read_dependency_two_assign_two_addn():
net = AutoMonadTwoAssignTwoAddnDependencyNet()
benchmarknet = AutoMonadTwoAssignTwoAddnDependencyBenchmarkNet()
out1 = net(Tensor([9.0], ms.float32))
out2 = benchmarknet(Tensor([9.0], ms.float32))
# The two trailing 0.001 arguments are presumably tolerances - confirm against
# allclose_nparray's definition.
allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001)