mirror of https://github.com/vllm-project/vllm
[Core][Distributed] use existing torch.cuda.device context manager (#4318)
commit 3cd9b5bb2d
parent 468d761b32
@@ -250,15 +250,13 @@ class NCCLCommunicator:
         assert isinstance(device, torch.device)
         self.device = device
         # nccl communicator and stream will use this device
-        current_device = torch.cuda.current_device()
-        try:
-            torch.cuda.set_device(device)
+        # `torch.cuda.device` is a context manager that changes the
+        # current cuda device to the specified one
+        with torch.cuda.device(device):
             NCCL_CHECK(
                 _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank))
             self.stream = torch.cuda.Stream()
-        finally:
-            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
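Note (not part of the diff): the with-block can replace the removed try/finally because torch.cuda.device saves the calling thread's current device on entry and restores it on exit. A minimal sketch of that behavior, assuming a machine with at least two visible CUDA devices:

import torch

# `torch.cuda.device` records the current device on __enter__ and restores it
# on __exit__, which is what the removed try/finally did by hand.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    before = torch.cuda.current_device()
    with torch.cuda.device(1):
        # Objects created here (streams, NCCL communicators) are bound to device 1.
        assert torch.cuda.current_device() == 1
    # The previous device is restored automatically, even if the block raised.
    assert torch.cuda.current_device() == before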