!12314 [auto-monad] Change backend execution order sorting policy
From: @hwhewei Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
commit 1239a4a848
@@ -201,21 +201,17 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr>
 }
 
 void KernelGraph::SetExecOrderByDefault() {
-  std::queue<AnfNodePtr> seed_nodes;
-  UpdateNodeEdgeList(&seed_nodes);
+  std::queue<AnfNodePtr> zero_input_nodes;
+  UpdateNodeEdgeList(&zero_input_nodes);
   execution_order_.clear();
   std::unordered_set<AnfNodePtr> visited_nodes;
-  std::queue<AnfNodePtr> zero_input_nodes;
   AnfNodePtr last_communication_node = nullptr;
   std::queue<AnfNodePtr> communication_descendants;
-  while (!seed_nodes.empty() || last_communication_node != nullptr) {
+  while (!zero_input_nodes.empty() || last_communication_node != nullptr) {
     // seed nodes first, then visit last all reduce node descendant
-    if (seed_nodes.empty()) {
+    if (last_communication_node != nullptr) {
       VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
       last_communication_node = nullptr;
-    } else {
-      zero_input_nodes.push(seed_nodes.front());
-      seed_nodes.pop();
     }
     // all reduce node descendant first, then common queue
     while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
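The hunk above is the heart of the policy change: execution order is a breadth-first topological sort over the kernel graph, and once a communication node (e.g. AllReduce) has been emitted, its ready descendants are drained before the common queue so the communication kernel can overlap with the computation that follows it. The Python sketch below restates that structure; the Node class, the is_communication flag and the inner-loop body (which the diff does not show) are illustrative assumptions, not MindSpore code.

from collections import deque

class Node:
    # Illustrative stand-in for a kernel-graph node; not a MindSpore class.
    def __init__(self, name, inputs=(), is_communication=False):
        self.name = name
        self.inputs = list(inputs)
        self.is_communication = is_communication
        self.users = []

def set_exec_order(nodes):
    # Count unsatisfied inputs and record user (consumer) edges.
    pending = {n: len(n.inputs) for n in nodes}
    for n in nodes:
        for i in n.inputs:
            i.users.append(n)

    def visit_descendants(node, queue):
        # Enqueue users whose inputs have now all been emitted.
        for user in node.users:
            pending[user] -= 1
            if pending[user] == 0:
                queue.append(user)

    zero_input_nodes = deque(n for n in nodes if pending[n] == 0)
    communication_descendants = deque()
    last_communication_node = None
    order = []
    while zero_input_nodes or last_communication_node is not None:
        if last_communication_node is not None:
            # Release the deferred descendants of the last communication node.
            visit_descendants(last_communication_node, communication_descendants)
            last_communication_node = None
        # Communication-node descendants first, then the common queue.
        while zero_input_nodes or communication_descendants:
            queue = communication_descendants if communication_descendants else zero_input_nodes
            node = queue.popleft()
            order.append(node.name)
            if node.is_communication:
                # Defer this node's descendants so they stay grouped; flush the
                # previous communication node first so none of its users are lost.
                if last_communication_node is not None:
                    visit_descendants(last_communication_node, communication_descendants)
                last_communication_node = node
            else:
                visit_descendants(node, zero_input_nodes)
    return order

a, c = Node("a"), Node("c")
allreduce = Node("allreduce", inputs=[a], is_communication=True)
b = Node("b", inputs=[allreduce])
print(set_exec_order([a, c, allreduce, b]))  # ['a', 'c', 'allreduce', 'b']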
@@ -901,14 +897,11 @@ void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
       seed_nodes->push(node);
       continue;
     }
-    auto cnode = dyn_cast<CNode>(node);
+    auto cnode = node->cast<CNodePtr>();
     if (cnode == nullptr) {
       continue;
     }
-    auto &inputs = cnode->inputs();
-    // We push inputs from right to left, so that them can be evaluated from left to right.
-    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
-      auto &input = *iter;
+    for (auto &input : cnode->inputs()) {
       PushNoVisitedNode(input, &que, &visited_nodes);
       AddDependEdge(node, input, 1);
     }
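UpdateNodeEdgeList walks the graph from its outputs back toward parameters and value nodes, pushing each visited node's inputs onto the worklist and recording a depend edge per input; the direction in which sibling inputs are pushed therefore fixes their relative order in the queue, which is the knob the removed comment was describing. The toy sketch below only demonstrates that mechanism (queue order follows push order) with made-up helpers, not MindSpore's actual data structures; the full execution order is then rebuilt from these seeds by SetExecOrderByDefault, so the final left-to-right behaviour also depends on that second pass.

from collections import deque

# Toy graph: a node is (name, list_of_input_nodes). collect_zero_input_order is a
# made-up helper, not a MindSpore API.
def collect_zero_input_order(output, push_right_to_left):
    que, seen, zero_input = deque([output]), {id(output)}, []
    while que:
        name, inputs = que.popleft()
        if not inputs:
            zero_input.append(name)   # a node with no inputs becomes a seed
            continue
        ordered = reversed(inputs) if push_right_to_left else inputs
        for child in ordered:
            if id(child) not in seen:
                seen.add(id(child))
                que.append(child)
    return zero_input

a, b, c = ("a", []), ("b", []), ("c", [])
root = ("root", [a, b, c])
print(collect_zero_input_order(root, push_right_to_left=False))  # ['a', 'b', 'c']
print(collect_zero_input_order(root, push_right_to_left=True))   # ['c', 'b', 'a']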
@@ -1427,3 +1427,43 @@ def test_if_cast():
     r1 = net(beta1, beta2)
     expect = Tensor(np.array([3]).astype(np.float32))
     np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_multi_add_assign():
+    class Net(Cell):
+        def __init__(self, i1):
+            super(Net, self).__init__()
+            self.add = P.Add()
+            self.sub = P.Sub()
+            self.mul = P.Mul()
+            self.assign = P.Assign()
+            self.p = Parameter(i1, name='para')
+
+        def construct(self, a, d, e):
+            res1 = self.add(self.add(self.add(self.p, a), a), a)
+            mul = self.mul(d, e)
+            self.assign(self.p, mul)
+            res2 = self.sub(self.p, e)
+            return res2, res1
+
+    def numpy_out(p, a, d, e):
+        res1 = p + a + a + a
+        res_as = d * e
+        res2 = d * e - e
+        return res2, res1, res_as
+
+    p = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+    i0 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+    i1 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+    i2 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+
+    net = Net(Tensor(p))
+    r2, r1 = net(Tensor(i0), Tensor(i1), Tensor(i2))
+
+    outputs = [r2.asnumpy(), r1.asnumpy(), net.p.data.asnumpy()]
+    expects = numpy_out(p, i0, i1, i2)
+    np.testing.assert_array_equal(outputs, expects)
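What the new test pins down is the read/write ordering around Assign that auto-monad must preserve: the three Add ops read para before it is overwritten with d * e, and the Sub reads it afterwards, which is exactly what the numpy_out reference encodes. Restated with an explicit mutation step (an illustration of the intended semantics, not of MindSpore internals):

import numpy as np

def expected(p, a, d, e):
    res1 = p + a + a + a     # the adds read para before the Assign
    para = d * e             # Assign(para, d * e) overwrites the parameter
    res2 = para - e          # the Sub reads para after the Assign
    return res2, res1, para  # matches numpy_out's (res2, res1, res_as)

p = np.ones(3, np.float32)
print(expected(p, p, p, p))  # res2 == 0, res1 == 4, res_as == 1 (element-wise)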
@@ -229,7 +229,7 @@ def test_bert_performance():
 
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [11.3660, 11.3265, 11.3264]
+    expect_loss_value = [11.3246, 11.2834, 11.2833]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 
@@ -229,8 +229,8 @@ def test_bert_precision(enable_graph_kernel=False):
         expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
                              12.185522, 12.386192]
     else:
-        expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
-                             12.407923, 12.631133]
+        expect_loss_value = [12.206587, 11.940709, 11.930911, 11.937369, 11.932178, 12.556069, 12.130172, 12.783402,
+                             12.359581, 12.578078]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 
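Both BERT tests compare against golden values with np.allclose(loss_value, expect_loss_value, 0, 0.0005), i.e. zero relative tolerance and an absolute tolerance of 5e-4 per step, which is presumably why the golden lists are updated here alongside the execution-order change. A quick reminder of how those positional arguments behave:

import numpy as np

# np.allclose(actual, expected, rtol, atol) passes when
# |actual - expected| <= atol + rtol * |expected|, element-wise.
print(np.allclose([11.3246], [11.3250], 0, 0.0005))  # True: off by 4e-4, within 5e-4
print(np.allclose([11.3246], [11.3260], 0, 0.0005))  # False: off by 1.4e-3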