From 64ae3d697c959edb138e269525dc03809490c673 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 13 Jun 2024 08:25:24 -0400 Subject: [PATCH] Fix clippy --- crates/burn-autodiff/src/backend.rs | 1 + crates/burn-cuda/src/runtime.rs | 3 ++- crates/burn-ndarray/src/backend.rs | 1 + crates/burn-tch/src/backend.rs | 3 +-- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index e552fb7af..2bd7bafe4 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -5,6 +5,7 @@ use crate::{ tensor::AutodiffTensor, AutodiffBridge, }; +use alloc::vec::Vec; use burn_tensor::backend::{AutodiffBackend, Backend}; use core::marker::PhantomData; diff --git a/crates/burn-cuda/src/runtime.rs b/crates/burn-cuda/src/runtime.rs index ba24cab8b..13a1bc027 100644 --- a/crates/burn-cuda/src/runtime.rs +++ b/crates/burn-cuda/src/runtime.rs @@ -1,3 +1,4 @@ +use alloc::vec::Vec; use burn_common::stub::RwLock; use burn_compute::{ channel::MutexComputeChannel, @@ -24,7 +25,7 @@ impl burn_jit::JitRuntime for CudaRuntime { fn list_available_devices() -> Vec { (0..cudarc::driver::result::device::get_count().unwrap() as usize) - .map(|id| CudaDevice::new(id)) + .map(CudaDevice::new) .collect() } } diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 61dfd027e..62813de61 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -1,6 +1,7 @@ use crate::NdArrayTensor; use crate::{element::FloatNdArrayElement, PrecisionBridge}; use alloc::string::String; +use alloc::{vec, vec::Vec}; use burn_common::stub::Mutex; use burn_tensor::backend::{Backend, DeviceId, DeviceOps}; use core::marker::PhantomData; diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index d936fc22d..4851d13e8 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -135,8 +135,7 @@ impl Backend for LibTorch { let mut devices = vec![LibTorchDevice::Cpu]; if tch::utils::has_cuda() { - devices - .extend((0..tch::Cuda::device_count() as usize).map(|id| LibTorchDevice::Cuda(id))); + devices.extend((0..tch::Cuda::device_count() as usize).map(LibTorchDevice::Cuda)); } if tch::utils::has_mps() {