Fix execution order problem: users of MakeTuple(Load, ...) are not attached to UpdateState
parent 022c1c4583
commit 16466f3453
@@ -37,7 +37,13 @@ class OrderEnforcer {
   void Run() {
     auto nodes = MakeTopoSortMap();
     for (auto &node : nodes) {
-      HandleNode(node);
+      if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
+        HandleUpdateState(node);
+      } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
+        // op(MakeTuple(Load, ...)) sometimes does not attach an update_state,
+        // so special treatment is needed to ensure the exec_order of MakeTuple users.
+        HandleMakeTupleUsers(node);
+      }
     }
   }
 
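Note: MakeTopoSortMap() is called above but its body lies outside this hunk; it is what fills the topo_sort_map_ consulted later by FindLastUpdateState. A minimal sketch of its presumed shape, assuming MindSpore's TopoSort IR helper, a func_graph_ member, and a topo_sort_map_ that maps each node to its topological index:

    std::vector<AnfNodePtr> MakeTopoSortMap() {
      // Topologically sort all nodes reachable from the graph's return node.
      auto nodes = TopoSort(func_graph_->get_return());
      // Remember each node's position so later code can compare execution order.
      for (size_t i = 0; i < nodes.size(); ++i) {
        (void)topo_sort_map_.emplace(nodes[i], i);
      }
      return nodes;
    }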
@@ -50,11 +56,7 @@ class OrderEnforcer {
     return nodes;
   }
 
-  void HandleNode(const AnfNodePtr &node) {
-    if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
-      // Skip nodes other than UpdateState.
-      return;
-    }
+  void HandleUpdateState(const AnfNodePtr &node) {
     auto update_state = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(update_state);
     const size_t update_state_inputs_size = 3;
@@ -74,6 +76,98 @@ class OrderEnforcer {
     }
   }
 
+  bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
+    auto inputs = cnode->inputs();
+    for (size_t index = 1; index < inputs.size(); index++) {
+      auto input = cnode->input(index);
+      if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
+    auto &node_users = manager_->node_users();
+    auto iter = node_users.find(cnode);
+    if (iter == node_users.end()) {
+      return {};
+    }
+    std::vector<AnfNodePtr> update_states;
+    auto &users = iter->second;
+    for (auto &user : users) {
+      auto &user_node = user.first;
+      if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
+        update_states.emplace_back(user_node);
+      } else if (IsPrimitiveCNode(user_node, prim::kPrimMakeTuple)) {
+        auto make_tuple_users = FindUpdateStateUsers(user_node->cast<CNodePtr>());
+        for (auto make_tuple_user : make_tuple_users) {
+          if (IsPrimitiveCNode(make_tuple_user, prim::kPrimUpdateState)) {
+            update_states.emplace_back(make_tuple_user);
+          }
+        }
+      }
+    }
+    return update_states;
+  }
+
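For reference, manager_->node_users() (the FuncGraphManager user index consulted above) maps every node to the set of (user node, input index) pairs that reference it, so user.first is the consuming CNode. A small hypothetical lookup, with load_node standing in for one of the Load inputs:

    auto &node_users = manager_->node_users();
    auto it = node_users.find(load_node);  // load_node: hypothetical Load CNode
    if (it != node_users.end()) {
      for (auto &user : it->second) {
        // user.first is the node that consumes load_node,
        // user.second is the index of load_node in that node's inputs.
        MS_LOG(DEBUG) << user.first->DebugString() << " uses it at input " << user.second;
      }
    }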
+  AnfNodePtr FindLastUpdateState(const CNodePtr &cnode) {
+    auto inputs = cnode->inputs();
+    std::vector<AnfNodePtr> all_update_states;
+    for (size_t index = 1; index < inputs.size(); index++) {
+      auto input = cnode->input(index);
+      if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
+        std::vector<AnfNodePtr> update_states = FindUpdateStateUsers(input->cast<CNodePtr>());
+        std::copy(update_states.begin(), update_states.end(), std::back_inserter(all_update_states));
+      }
+    }
+    AnfNodePtr last_update_state = nullptr;
+    if (all_update_states.empty()) {
+      return last_update_state;
+    }
+    if (all_update_states.size() == 1) {
+      return all_update_states[0];
+    }
+    for (size_t i = 0; i < all_update_states.size() - 1; i++) {
+      auto cur_update_state = all_update_states[i];
+      auto next_update_state = all_update_states[i + 1];
+      if (topo_sort_map_[cur_update_state] <= topo_sort_map_[next_update_state]) {
+        last_update_state = next_update_state;
+      }
+    }
+    return last_update_state;
+  }
+
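The pairwise loop above only compares neighbouring entries (and can return nullptr when the collected UpdateState nodes happen to be in decreasing topological order), so it expresses "pick the UpdateState that comes last in topological order" somewhat indirectly. An equivalent, more direct selection, sketched under the assumption that topo_sort_map_ maps each node to its topological index (requires <algorithm>):

    // Sketch: choose the user UpdateState with the largest topological index.
    auto last = *std::max_element(all_update_states.begin(), all_update_states.end(),
                                  [this](const AnfNodePtr &a, const AnfNodePtr &b) {
                                    return topo_sort_map_[a] < topo_sort_map_[b];
                                  });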
+  // Convert:
+  // load1 = Load(para1, u1)
+  // load2 = Load(para2, u2)
+  // maketuple1 = MakeTuple(inputs, load1, load2)
+  // addn = AddN(maketuple1) or other-op
+  // maketuple2 = MakeTuple(load1, load2)
+  // u3 = UpdateState(u', maketuple2)
+  // assign = Assign(para2, inputs, u3)
+  // To:
+  // load1 = Load(para1, u1)
+  // load2 = Load(para2, u2)
+  // maketuple1 = MakeTuple(inputs, load1, load2)
+  // addn = AddN(maketuple1) or other-op
+  // maketuple2 = MakeTuple(load1, load2)
+  // u3 = UpdateState(u', maketuple2, addn)  # addn (or the other op) needs to be added to u3's inputs
+  // assign = Assign(para2, inputs, u3)
+  void HandleMakeTupleUsers(const AnfNodePtr &node) {
+    auto maketuple = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(maketuple);
+    if (CheckMakeTupleHaveLoad(maketuple)) {
+      auto update_state = FindLastUpdateState(maketuple);
+      if (update_state != nullptr) {
+        std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
+        auto update_state_cnode = update_state->cast<CNodePtr>();
+        MS_EXCEPTION_IF_NULL(update_state_cnode);
+        AddInputEdges(update_state_cnode, maketuple_users);
+      }
+    }
+  }
+
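GetSpecialOperatorRealUsers and AddInputEdges are defined elsewhere in this pass and are not part of this hunk. The essence of AddInputEdges, as the conversion comment above describes it, is to append each real user of the MakeTuple (addn or other-op) as an extra input of the chosen UpdateState. A simplified sketch, assuming the FuncGraphManager AddEdge API and ignoring any de-duplication or dependency checks the real helper may perform:

    // Sketch: make update_state depend on every real user of MakeTuple(Load, ...),
    // so those users are scheduled before update_state and before the Assign it guards.
    void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &users) {
      for (auto &user : users) {
        manager_->AddEdge(update_state, user);  // appends `user` as an input edge of `update_state`
      }
    }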
   bool IsRef(const AnfNodePtr &node) {
     auto &abs = node->abstract();
     return abs != nullptr && abs->isa<abstract::AbstractRef>();
@@ -81,3 +81,43 @@ def test_auto_monad_addn_adam():
     allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
     allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
     allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)
+
+
+class AutoMonadTwoAssignTwoAddnDependencyNet(Cell):
+    def __init__(self):
+        super().__init__()
+        self.parameter1 = ms.Parameter(Tensor([1.0], ms.float32), name="parameter1")
+        self.parameter2 = ms.Parameter(Tensor([3.0], ms.float32), name="parameter2")
+        self.assign = P.Assign()
+        self.addN = P.AddN()
+
+    def construct(self, inputs):
+        self.assign(self.parameter1, inputs)
+        out = self.addN((inputs, self.parameter1, self.parameter2))
+        self.assign(self.parameter2, inputs)
+        out = self.addN((out, self.parameter1, self.parameter2))
+        return out
+
+
+class AutoMonadTwoAssignTwoAddnDependencyBenchmarkNet(Cell):
+    def __init__(self):
+        super().__init__()
+        self.parameter2 = ms.Parameter(Tensor([3.0], ms.float32), name="parameter2")
+        self.addN = P.AddN()
+
+    def construct(self, inputs):
+        out = self.addN((inputs, inputs, self.parameter2))
+        out = self.addN((out, inputs, inputs))
+        return out
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_auto_monad_read_dependency_two_assign_two_addn():
+    net = AutoMonadTwoAssignTwoAddnDependencyNet()
+    benchmarknet = AutoMonadTwoAssignTwoAddnDependencyBenchmarkNet()
+    out1 = net(Tensor([9.0], ms.float32))
+    out2 = benchmarknet(Tensor([9.0], ms.float32))
+    allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001)