[Core][Distributed] use existing torch.cuda.device (#4318)

[Core][Distributed] use existing torch.cuda.device context manager (#4318)
This commit is contained in:
youkaichao 2024-04-24 09:00:20 -07:00 committed by GitHub
parent 468d761b32
commit 3cd9b5bb2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 5 deletions

View File

@ -250,15 +250,13 @@ class NCCLCommunicator:
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
current_device = torch.cuda.current_device()
try:
torch.cuda.set_device(device)
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
finally:
torch.cuda.set_device(current_device)
def all_reduce(self,
tensor: torch.Tensor,