mirror of https://github.com/tracel-ai/burn.git
Fix/devices api (#990)
This commit is contained in:
parent
3d6c738776
commit
630044e96b
|
@ -84,9 +84,14 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
/// Type to save and load the module.
|
||||
type Record: Record;
|
||||
|
||||
/// Collects devices in the given vector and returns it with the devices found in the module
|
||||
/// structure without duplicates.
|
||||
fn devices(&self, devices: Devices<B>) -> Devices<B>;
|
||||
/// Return all the devices found in the underlying module tree, added to the given vector
|
||||
/// without duplicates.
|
||||
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
|
||||
|
||||
/// Return all the devices found in the underlying module tree, without duplicates.
|
||||
fn devices(&self) -> Devices<B> {
|
||||
self.collect_devices(Devices::<B>::new())
|
||||
}
|
||||
|
||||
/// Fork the module and all of its sub-modules to the given device.
|
||||
///
|
||||
|
|
|
@ -75,7 +75,7 @@ macro_rules! constant {
|
|||
self
|
||||
}
|
||||
|
||||
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
|
||||
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
|
||||
devices
|
||||
}
|
||||
};
|
||||
|
@ -147,7 +147,7 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
|
|||
self.to_device(device)
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Devices<B>) -> Devices<B> {
|
||||
fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
|
@ -195,7 +195,7 @@ impl<B: Backend> Module<B> for PhantomData<B> {
|
|||
self
|
||||
}
|
||||
|
||||
fn devices(&self, devices: Devices<B>) -> Devices<B> {
|
||||
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,9 +37,9 @@ where
|
|||
self.map(|module| module.fork(device))
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
if let Some(module) = self.as_ref() {
|
||||
devices = module.devices(devices);
|
||||
devices = module.collect_devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
|
@ -105,9 +105,9 @@ where
|
|||
self.into_iter().map(|module| module.fork(device)).collect()
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
for module in self.iter() {
|
||||
devices = module.devices(devices);
|
||||
devices = module.collect_devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
|
@ -134,9 +134,9 @@ where
|
|||
{
|
||||
type Record = [T::Record; N];
|
||||
|
||||
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
for module in self.iter() {
|
||||
devices = module.devices(devices);
|
||||
devices = module.collect_devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
|
|
|
@ -95,7 +95,10 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
|
|||
self.to_device(device) // Same thing here since no grad.
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
fn collect_devices(
|
||||
&self,
|
||||
mut devices: Vec<<B as Backend>::Device>,
|
||||
) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.value.read().unwrap().device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
|
|
|
@ -75,7 +75,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
|
|||
})
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
fn collect_devices(
|
||||
&self,
|
||||
mut devices: Vec<<B as Backend>::Device>,
|
||||
) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
|
@ -122,7 +125,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
|
|||
self.to_device(device) // Don't support autodiff.
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
fn collect_devices(
|
||||
&self,
|
||||
mut devices: Vec<<B as Backend>::Device>,
|
||||
) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
|
@ -169,7 +175,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
|
|||
self.to_device(device) // Don't support autodiff.
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
fn collect_devices(
|
||||
&self,
|
||||
mut devices: Vec<<B as Backend>::Device>,
|
||||
) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
|
|
|
@ -29,7 +29,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
|||
let num_params_fn = generator.gen_num_params();
|
||||
let visit = generator.gen_visit();
|
||||
let map_mut = generator.gen_map();
|
||||
let devices = generator.gen_devices();
|
||||
let collect_devices = generator.gen_collect_devices();
|
||||
let to_device = generator.gen_to_device();
|
||||
let fork = generator.gen_fork();
|
||||
let valid_fn = generator.gen_valid();
|
||||
|
@ -54,7 +54,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
|||
#visit
|
||||
#map_mut
|
||||
|
||||
#devices
|
||||
#collect_devices
|
||||
#to_device
|
||||
#fork
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use proc_macro2::TokenStream;
|
|||
pub(crate) trait ModuleCodegen {
|
||||
fn gen_num_params(&self) -> TokenStream;
|
||||
fn gen_visit(&self) -> TokenStream;
|
||||
fn gen_devices(&self) -> TokenStream;
|
||||
fn gen_collect_devices(&self) -> TokenStream;
|
||||
fn gen_to_device(&self) -> TokenStream;
|
||||
fn gen_fork(&self) -> TokenStream;
|
||||
fn gen_map(&self) -> TokenStream;
|
||||
|
|
|
@ -39,15 +39,15 @@ impl ModuleCodegen for StructModuleCodegen {
|
|||
}
|
||||
}
|
||||
|
||||
fn gen_devices(&self) -> TokenStream {
|
||||
fn gen_collect_devices(&self) -> TokenStream {
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
quote! {
|
||||
let devices = burn::module::Module::<B>::devices(&self.#name, devices);
|
||||
let devices = burn::module::Module::<B>::collect_devices(&self.#name, devices);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
|
||||
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
|
||||
#body
|
||||
|
||||
devices
|
||||
|
|
|
@ -88,7 +88,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
pub fn forward(&self, item: TextClassificationTrainingBatch<B>) -> ClassificationOutput<B> {
|
||||
// Get batch and sequence length, and the device
|
||||
let [batch_size, seq_length] = item.tokens.dims();
|
||||
let device = &self.embedding_token.devices(Vec::new())[0];
|
||||
let device = &self.embedding_token.devices()[0];
|
||||
|
||||
// Move tensors to the correct device
|
||||
let tokens = item.tokens.to_device(device);
|
||||
|
@ -128,7 +128,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
pub fn infer(&self, item: TextClassificationInferenceBatch<B>) -> Tensor<B, 2> {
|
||||
// Get batch and sequence length, and the device
|
||||
let [batch_size, seq_length] = item.tokens.dims();
|
||||
let device = &self.embedding_token.devices(Vec::new())[0];
|
||||
let device = &self.embedding_token.devices()[0];
|
||||
|
||||
// Move tensors to the correct device
|
||||
let tokens = item.tokens.to_device(device);
|
||||
|
|
|
@ -58,7 +58,7 @@ impl<B: Backend> TextGenerationModel<B> {
|
|||
item: TrainingTextGenerationBatch<B>,
|
||||
) -> ClassificationOutput<B> {
|
||||
let [batch_size, seq_length] = item.tokens_inputs.dims();
|
||||
let device = &self.devices(Vec::new())[0];
|
||||
let device = &self.devices()[0];
|
||||
|
||||
let inputs = item.tokens_inputs.to_device(device);
|
||||
let targets = item.targets.to_device(device);
|
||||
|
|
Loading…
Reference in New Issue