Fix result error when calling AllReduce serially.

parent c8f69f5db2
commit d9bcdac3dc
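Read together, the three C++ hunks below change AssignGpuStream so that AllReduce kernels are moved onto a separate communication stream only when FindAllReduceStreamSwitchPos (now returning bool instead of void) finds a Recv position for every AllReduce node. Previously a missing position was only logged with a warning (one that wrongly mentioned the send node) and skipped, so serially chained AllReduce calls could end up on the comm stream without the Send/Recv synchronization they need, producing the result error named in the commit title. A condensed sketch of the post-fix control flow follows; it is not a compilable excerpt: the kernel-collection step and the std::transform body are summarized as comments, and all types come from the surrounding stream-assign source and header.

    void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
      // allreduce_kernels is gathered from kernel_graph->execution_order()
      // above this hunk (not shown in the diff).
      if (allreduce_kernels.size() > 1) {
        // Assign multiple streams only when there's Recv node for AllReduce.
        std::vector<SendRecvPair> send_recv_pairs;
        if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) {
          DeviceStream comm_stream = nullptr;
          GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
          // std::transform(...) tags every AllReduce kernel with comm_stream
          // via AnfAlgo::SetNodeAttr("stream_id", ...).
          InsertStreamSwitchNode(kernel_graph, send_recv_pairs);
        } else {
          return;  // No switch position found: leave everything on the default stream.
        }
      }
    }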
@@ -40,6 +40,9 @@ void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph)
     }
   }
   if (allreduce_kernels.size() > 1) {
+    // Assign multiple streams only when there's Recv node for AllReduce.
+    std::vector<SendRecvPair> send_recv_pairs;
+    if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) {
     DeviceStream comm_stream = nullptr;
     GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
     std::transform(
@@ -47,14 +50,14 @@ void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph)
       AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), allreduce_kernel);
       return allreduce_kernel;
     });
-
-    std::vector<SendRecvPair> send_recv_pairs;
-    FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs);
     InsertStreamSwitchNode(kernel_graph, send_recv_pairs);
+    } else {
+      return;
+    }
   }
 }
 
-void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                                   std::vector<SendRecvPair> *send_recv_pairs) {
   auto execution_kernels = kernel_graph->execution_order();
   std::vector<CNodePtr>::iterator iter, iter_begin;
@@ -77,14 +80,15 @@ void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
       std::vector<CNodePtr>::iterator mock_recv_node_iter =
         FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
       if (mock_recv_node_iter == iter_end) {
-        MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
-        continue;
+        MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
+        return false;
       }
       SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
                             IntToSize(mock_recv_node_iter - iter_begin)};
       send_recv_pairs->push_back(pair2);
     }
   }
+  return true;
 }
 
 std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
@@ -48,7 +48,7 @@ struct StreamSwitchNode {
   }
 };
 void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph);
-void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                                   std::vector<SendRecvPair> *send_recv_pairs);
 // Find Send node position according to "mock" recv node.
 // "mock" recv node is a gpu kernel node after a real Recv node, e.g. AllReduce node.
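The new test hunk below exercises exactly the serial pattern this commit fixes: three AllReduce(sum) ops chained back to back. Here x, size, and NCCL_WORLD_COMM_GROUP come from the test module's existing setup; the expect0 loop implies each rank's input x is a [3, 1, 3, 3] tensor filled with 0.01 * (rank + 1). The expected values then follow from sum reduction over size ranks: expect0 = sum over i of 0.01 * (i + 1), and since the reduced tensor is identical on every rank, each further all-reduce just multiplies it by size, giving expect1 = expect0 * size and expect2 = expect0 * size^2.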
@@ -75,3 +75,49 @@ def test_AllReduce():
     error2 = np.ones(shape=expect2.shape) * 1.0e-5
     assert np.all(diff2 < error2)
     assert output[2].shape() == expect2.shape
+
+
+class Net2(nn.Cell):
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1')
+
+        self.op0 = "sum"
+        self.op1 = "sum"
+        self.op2 = "sum"
+
+        self.all_reduce1 = P.AllReduce(self.op0, group=NCCL_WORLD_COMM_GROUP)
+        self.all_reduce2 = P.AllReduce(self.op1, group=NCCL_WORLD_COMM_GROUP)
+        self.all_reduce3 = P.AllReduce(self.op2, group=NCCL_WORLD_COMM_GROUP)
+
+    def construct(self):
+        x = self.all_reduce1(self.x1)
+        y = self.all_reduce2(x)
+        z = self.all_reduce3(y)
+        return (x, y, z)
+
+
+def test_AllReduce2():
+    all_reduce = Net2()
+    output = all_reduce()
+
+    expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 0
+    for i in range(size):
+        part = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 1)
+        expect0 += part
+    diff0 = abs(output[0].asnumpy() - expect0)
+    error0 = np.ones(shape=expect0.shape) * 1.0e-5
+    assert np.all(diff0 < error0)
+    assert output[0].shape() == expect0.shape
+
+    expect1 = expect0 * size
+    diff1 = abs(output[1].asnumpy() - expect1)
+    error1 = np.ones(shape=expect1.shape) * 1.0e-5
+    assert np.all(diff1 < error1)
+    assert output[1].shape() == expect1.shape
+
+    expect2 = expect1 * size
+    diff2 = abs(output[2].asnumpy() - expect2)
+    error2 = np.ones(shape=expect2.shape) * 1.0e-5
+    assert np.all(diff2 < error2)
+    assert output[2].shape() == expect2.shape