mirror of https://github.com/tracel-ai/burn.git
wip
This commit is contained in:
parent
d18d1b0bb9
commit
89eb248b88
|
@ -13,6 +13,8 @@ pub struct Mutex<T> {
|
|||
inner: MutexImported<T>,
|
||||
}
|
||||
|
||||
// Manually marks `Mutex<T>` as shareable between threads whenever `T: Sync`.
// SAFETY: not established in this view - soundness relies on the wrapped
// `MutexImported<T>` serialising every access to the inner value; TODO confirm.
// NOTE(review): `std::sync::Mutex<T>` is `Sync` for `T: Send` (not `T: Sync`);
// verify that the `T: Sync` bound used here is the intended one.
unsafe impl<T: Sync> Sync for Mutex<T> {}
|
||||
|
||||
impl<T> Mutex<T> {
|
||||
/// Creates a new mutex in an unlocked state ready for use.
|
||||
#[inline(always)]
|
||||
|
|
|
@ -15,16 +15,23 @@ default = ["async"]
|
|||
async = []
|
||||
# Still experimental
|
||||
autotune = []
|
||||
std = [
|
||||
"rand/std",
|
||||
"burn-tensor/std",
|
||||
"burn-common/std",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
burn-common = { path = "../burn-common", version = "0.9.0" }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.9.0" }
|
||||
burn-common = { path = "../burn-common", version = "0.9.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.9.0", default-features = false }
|
||||
bytemuck = { workspace = true }
|
||||
|
||||
derive-new = { workspace = true }
|
||||
log = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
dashmap = {workspace = true}
|
||||
|
||||
# WGPU stuff
|
||||
futures-intrusive = { workspace = true }
|
||||
|
@ -34,6 +41,7 @@ wgpu = { workspace = true }
|
|||
# Template
|
||||
serde = { workspace = true }
|
||||
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
|
||||
lazy_static = { version = "1.4.0", default-features = false, features = ["spin_no_std"] }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "0.9.0", default-features = false, features = [
|
||||
|
|
|
@ -2,9 +2,9 @@ use super::client::ContextClient;
|
|||
use crate::{
|
||||
context::server::ContextServer,
|
||||
kernel::{DynamicKernel, StaticKernel},
|
||||
tune::Tuner,
|
||||
GraphicsApi, WgpuDevice,
|
||||
};
|
||||
|
||||
use burn_common::id::IdGenerator;
|
||||
use spin::Mutex;
|
||||
use std::{any::TypeId, borrow::Cow, collections::HashMap, sync::Arc};
|
||||
|
@ -32,11 +32,15 @@ pub struct Context {
|
|||
device_wgpu: Arc<wgpu::Device>,
|
||||
cache: Mutex<HashMap<TemplateKey, Arc<ComputePipeline>>>,
|
||||
client: ContextClientImpl,
|
||||
pub(crate) tuner: Tuner,
|
||||
#[cfg(feature = "autotune")]
|
||||
pub(crate) tuner: tune::Tuner,
|
||||
pub(crate) device: WgpuDevice,
|
||||
pub(crate) info: wgpu::AdapterInfo,
|
||||
}
|
||||
|
||||
// Manually asserts that a `Context` may be moved to another thread.
// SAFETY: no justification is visible here - with this impl the compiler no
// longer checks that every field of `Context` is `Send`; TODO confirm each field.
unsafe impl Send for Context {}
|
||||
// Manually asserts that `&Context` may be shared between threads.
// SAFETY: no justification is visible here - with this impl the compiler no
// longer checks that every field of `Context` is `Sync`; TODO confirm each field.
unsafe impl Sync for Context {}
|
||||
|
||||
#[derive(Debug, Hash, PartialOrd, PartialEq, Eq)]
|
||||
enum TemplateKey {
|
||||
Static(TypeId),
|
||||
|
@ -71,7 +75,8 @@ impl Context {
|
|||
device,
|
||||
client,
|
||||
cache: Mutex::new(HashMap::new()),
|
||||
tuner: Tuner::new(),
|
||||
#[cfg(feature = "autotune")]
|
||||
tuner: crate::Tuner::new(),
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
@ -231,7 +236,7 @@ impl PartialEq for Context {
|
|||
async fn select_device<G: GraphicsApi>(
|
||||
device: &WgpuDevice,
|
||||
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
|
||||
let adapter = select_adapter::<G>(device);
|
||||
let adapter = select_adapter::<G>(device).await;
|
||||
let limits = adapter.limits();
|
||||
|
||||
let (device, queue) = adapter
|
||||
|
@ -256,12 +261,35 @@ async fn select_device<G: GraphicsApi>(
|
|||
(device, queue, adapter.get_info())
|
||||
}
|
||||
|
||||
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
|
||||
async fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
|
||||
let instance = wgpu::Instance::default();
|
||||
|
||||
let mut adapters_other = Vec::new();
|
||||
let mut adapters = Vec::new();
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
{
|
||||
let power_preference = match device {
|
||||
WgpuDevice::DiscreteGpu(_) => wgpu::PowerPreference::HighPerformance,
|
||||
WgpuDevice::IntegratedGpu(_) => wgpu::PowerPreference::LowPower,
|
||||
WgpuDevice::VirtualGpu(_) => wgpu::PowerPreference::None,
|
||||
WgpuDevice::Cpu => wgpu::PowerPreference::None,
|
||||
WgpuDevice::BestAvailable => wgpu::PowerPreference::HighPerformance,
|
||||
};
|
||||
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptionsBase {
|
||||
power_preference,
|
||||
..Default::default()
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(adapter) = adapter {
|
||||
adapters.push(adapter);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
instance
|
||||
.enumerate_adapters(G::backend().into())
|
||||
.for_each(|adapter| {
|
||||
|
|
|
@ -3,9 +3,13 @@ pub(crate) mod utils;
|
|||
mod mem_coalescing;
|
||||
mod naive;
|
||||
mod tiling2d;
|
||||
mod tune;
|
||||
|
||||
pub use mem_coalescing::*;
|
||||
pub use naive::*;
|
||||
pub use tiling2d::*;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
mod tune;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub use tune::*;
|
||||
|
|
|
@ -16,6 +16,8 @@ pub(crate) mod context;
|
|||
pub(crate) mod element;
|
||||
pub(crate) mod pool;
|
||||
pub(crate) mod tensor;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub(crate) mod tune;
|
||||
|
||||
mod device;
|
||||
|
|
|
@ -14,6 +14,9 @@ pub struct WgpuTensor<E: WgpuElement, const D: usize> {
|
|||
elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
// Manually asserts that a `WgpuTensor<E, D>` may be moved to another thread,
// for every element type `E: WgpuElement` and rank `D`.
// SAFETY: no justification is visible here - the impl bypasses the compiler's
// per-field `Send` check; TODO confirm every field is safe to send.
unsafe impl<E: WgpuElement, const D: usize> Send for WgpuTensor<E, D> {}
|
||||
// Manually asserts that `&WgpuTensor<E, D>` may be shared between threads,
// for every element type `E: WgpuElement` and rank `D`.
// SAFETY: no justification is visible here - the impl bypasses the compiler's
// per-field `Sync` check; TODO confirm every field tolerates shared access.
unsafe impl<E: WgpuElement, const D: usize> Sync for WgpuTensor<E, D> {}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WgpuTensorDyn<E: WgpuElement> {
|
||||
pub(crate) context: Arc<Context>,
|
||||
|
|
|
@ -190,6 +190,7 @@ fn no_std_checks() {
|
|||
|
||||
// Run checks for the following crates
|
||||
build_and_test_no_std("burn");
|
||||
build_and_test_no_std("burn-wgpu");
|
||||
build_and_test_no_std("burn-core");
|
||||
build_and_test_no_std("burn-common");
|
||||
build_and_test_no_std("burn-tensor");
|
||||
|
|
Loading…
Reference in New Issue