This commit is contained in:
nathaniel 2023-08-23 14:05:36 -04:00
parent d18d1b0bb9
commit 89eb248b88
7 changed files with 56 additions and 8 deletions

View File

@ -13,6 +13,8 @@ pub struct Mutex<T> {
inner: MutexImported<T>,
}
// SAFETY: this impl asserts that `&Mutex<T>` may be shared between threads whenever
// `T: Sync`.
// NOTE(review): `std::sync::Mutex<T>` is `Sync` for `T: Send` rather than `T: Sync`,
// because a lock hands out exclusive access to the inner value; confirm that
// `T: Sync` is the bound intended for the wrapped `MutexImported<T>`.
unsafe impl<T> Sync for Mutex<T> where T: Sync {}
impl<T> Mutex<T> {
/// Creates a new mutex in an unlocked state ready for use.
#[inline(always)]

View File

@ -15,16 +15,23 @@ default = ["async"]
async = []
# Still experimental
autotune = []
std = [
"rand/std",
"burn-tensor/std",
"burn-common/std",
]
[dependencies]
burn-common = { path = "../burn-common", version = "0.9.0" }
burn-tensor = { path = "../burn-tensor", version = "0.9.0" }
burn-common = { path = "../burn-common", version = "0.9.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.9.0", default-features = false }
bytemuck = { workspace = true }
derive-new = { workspace = true }
log = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
spin = { workspace = true }
dashmap = {workspace = true}
# WGPU stuff
futures-intrusive = { workspace = true }
@ -34,6 +41,7 @@ wgpu = { workspace = true }
# Template
serde = { workspace = true }
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
lazy_static = { version = "1.4.0", default-features = false, features = ["spin_no_std"] }
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.9.0", default-features = false, features = [

View File

@ -2,9 +2,9 @@ use super::client::ContextClient;
use crate::{
context::server::ContextServer,
kernel::{DynamicKernel, StaticKernel},
tune::Tuner,
GraphicsApi, WgpuDevice,
};
use burn_common::id::IdGenerator;
use spin::Mutex;
use std::{any::TypeId, borrow::Cow, collections::HashMap, sync::Arc};
@ -32,11 +32,15 @@ pub struct Context {
device_wgpu: Arc<wgpu::Device>,
cache: Mutex<HashMap<TemplateKey, Arc<ComputePipeline>>>,
client: ContextClientImpl,
pub(crate) tuner: Tuner,
#[cfg(feature = "autotune")]
pub(crate) tuner: tune::Tuner,
pub(crate) device: WgpuDevice,
pub(crate) info: wgpu::AdapterInfo,
}
// SAFETY: NOTE(review) — these impls unconditionally assert that `Context` can be
// moved to, and shared between, threads, overriding whatever its fields implement.
// The justification is not visible here; confirm it for every field of `Context`.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
#[derive(Debug, Hash, PartialOrd, PartialEq, Eq)]
enum TemplateKey {
Static(TypeId),
@ -71,7 +75,8 @@ impl Context {
device,
client,
cache: Mutex::new(HashMap::new()),
tuner: Tuner::new(),
#[cfg(feature = "autotune")]
tuner: crate::Tuner::new(),
info,
}
}
@ -231,7 +236,7 @@ impl PartialEq for Context {
async fn select_device<G: GraphicsApi>(
device: &WgpuDevice,
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
let adapter = select_adapter::<G>(device);
let adapter = select_adapter::<G>(device).await;
let limits = adapter.limits();
let (device, queue) = adapter
@ -256,12 +261,35 @@ async fn select_device<G: GraphicsApi>(
(device, queue, adapter.get_info())
}
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
async fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
let instance = wgpu::Instance::default();
let mut adapters_other = Vec::new();
let mut adapters = Vec::new();
#[cfg(target_arch = "wasm32")]
{
let power_preference = match device {
WgpuDevice::DiscreteGpu(_) => wgpu::PowerPreference::HighPerformance,
WgpuDevice::IntegratedGpu(_) => wgpu::PowerPreference::LowPower,
WgpuDevice::VirtualGpu(_) => wgpu::PowerPreference::None,
WgpuDevice::Cpu => wgpu::PowerPreference::None,
WgpuDevice::BestAvailable => wgpu::PowerPreference::HighPerformance,
};
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptionsBase {
power_preference,
..Default::default()
})
.await;
if let Some(adapter) = adapter {
adapters.push(adapter);
}
}
#[cfg(not(target_arch = "wasm32"))]
instance
.enumerate_adapters(G::backend().into())
.for_each(|adapter| {

View File

@ -3,9 +3,13 @@ pub(crate) mod utils;
mod mem_coalescing;
mod naive;
mod tiling2d;
mod tune;
pub use mem_coalescing::*;
pub use naive::*;
pub use tiling2d::*;
#[cfg(feature = "autotune")]
mod tune;
#[cfg(feature = "autotune")]
pub use tune::*;

View File

@ -16,6 +16,8 @@ pub(crate) mod context;
pub(crate) mod element;
pub(crate) mod pool;
pub(crate) mod tensor;
#[cfg(feature = "autotune")]
pub(crate) mod tune;
mod device;

View File

@ -14,6 +14,9 @@ pub struct WgpuTensor<E: WgpuElement, const D: usize> {
elem: PhantomData<E>,
}
// SAFETY: NOTE(review) — these impls assert that `WgpuTensor<E, D>` is `Send` and
// `Sync` for every `E: WgpuElement`, independently of what its fields implement.
// The justification is not visible here; confirm it against the tensor's fields.
unsafe impl<E, const D: usize> Send for WgpuTensor<E, D> where E: WgpuElement {}
unsafe impl<E, const D: usize> Sync for WgpuTensor<E, D> where E: WgpuElement {}
#[derive(Debug, Clone)]
pub struct WgpuTensorDyn<E: WgpuElement> {
pub(crate) context: Arc<Context>,

View File

@ -190,6 +190,7 @@ fn no_std_checks() {
// Run checks for the following crates
build_and_test_no_std("burn");
build_and_test_no_std("burn-wgpu");
build_and_test_no_std("burn-core");
build_and_test_no_std("burn-common");
build_and_test_no_std("burn-tensor");