mirror of https://github.com/vllm-project/vllm
[doc][distributed] add both gloo and nccl tests (#5834)
This commit is contained in:
parent
67882dbb44
commit
c18ebfdd71
|
@ -28,8 +28,8 @@ If it crashes, and the error trace shows somewhere around ``self.graph.replay()`
|
|||
|
||||
Here are some common issues that can cause hangs:
|
||||
|
||||
- **Incorrect network setup**: The vLLM instance cannot get the correct IP address. You can find the log such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``. The IP address should be the correct one. If not, override the IP address by setting the environment variable ``export VLLM_HOST_IP=your_ip_address``.
|
||||
- **Incorrect hardware/driver**: GPU communication cannot be established. You can run the following sanity check script to see if the GPU communication is working correctly.
|
||||
- **Incorrect network setup**: The vLLM instance cannot get the correct IP address if you have complicated network config. You can find the log such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``. The IP address should be the correct one. If not, override the IP address by setting the environment variable ``export VLLM_HOST_IP=your_ip_address``. You might also need to set ``export NCCL_SOCKET_IFNAME=your_network_interface`` and ``export GLOO_SOCKET_IFNAME=your_network_interface`` to specify the network interface for the IP address.
|
||||
- **Incorrect hardware/driver**: GPU/CPU communication cannot be established. You can run the following sanity check script to see if the GPU/CPU communication is working correctly.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -41,7 +41,14 @@ Here are some common issues that can cause hangs:
|
|||
dist.all_reduce(data, op=dist.ReduceOp.SUM)
|
||||
torch.cuda.synchronize()
|
||||
value = data.mean().item()
|
||||
assert value == dist.get_world_size()
|
||||
world_size = dist.get_world_size()
|
||||
assert value == world_size, f"Expected {world_size}, got {value}"
|
||||
|
||||
gloo_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
|
||||
cpu_data = torch.FloatTensor([1,] * 128)
|
||||
dist.all_reduce(cpu_data, op=dist.ReduceOp.SUM, group=gloo_group)
|
||||
value = cpu_data.mean().item()
|
||||
assert value == world_size, f"Expected {world_size}, got {value}"
|
||||
|
||||
.. tip::
|
||||
|
||||
|
|
Loading…
Reference in New Issue