mirror of https://github.com/tracel-ai/burn.git
Autotune: fix inputs (#926)
commit a0297530ea
parent 6548f1a730
@@ -1,10 +1,11 @@
 use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
-use burn_tensor::Element;
+use burn_tensor::{Element, ElementConversion};
 
 use crate::{
     compute::WgpuAutotuneKey,
     element::WgpuElement,
-    kernel::matmul::{tune::utils::autotune_tensors, utils::init_matmul_output},
+    kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform},
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
 
@@ -37,32 +38,42 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet<WgpuAutotune
     }
 
     fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
-        let lhs = autotune_tensors(&self.lhs);
-        let rhs = autotune_tensors(&self.rhs);
-        let out = autotune_tensors(&self.out);
+        let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
+        let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1);
+        let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1);
 
+        let out = empty_device(
+            self.out.client.clone(),
+            self.out.device.clone(),
+            self.out.shape.clone(),
+        );
+
         vec![
-            Box::new(MemoryCoalescingMatmulDefault::<E, 3>::new(
+            Box::new(MemoryCoalescingMatmulDefault::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(MemoryCoalescingMatmulW16x16::<E, 3>::new(
+            Box::new(MemoryCoalescingMatmulW16x16::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(Vec4TilingMatmulDefault::<E, 3>::new(
+            Box::new(Vec4TilingMatmulDefault::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(Vec4TilingMatmulUnpaddedDefault::<E, 3>::new(
+            Box::new(Vec4TilingMatmulUnpaddedDefault::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, 3>::new(lhs, rhs, out)),
+            Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
+                lhs.clone(),
+                rhs.clone(),
+                out.clone(),
+            )),
         ]
     }
 
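Note on the hunks above: the autotune candidates are now benchmarked on inputs that keep the operands' actual rank D (instead of the hard-coded rank 3) and carry random values rather than ones. A minimal sketch of that input-preparation pattern, assuming the `random_like_uniform` and `empty_device` signatures introduced in this commit; the helper name `autotune_matmul_inputs` is made up for illustration:

    use burn_tensor::{Element, ElementConversion};
    use crate::{
        element::WgpuElement, kernel::prng::random_like_uniform, ops::numeric::empty_device,
        tensor::WgpuTensor,
    };

    /// Hypothetical helper: build benchmark inputs that mirror the real matmul operands.
    fn autotune_matmul_inputs<E: WgpuElement + Element, const D: usize>(
        lhs: &WgpuTensor<E, D>,
        rhs: &WgpuTensor<E, D>,
        out: &WgpuTensor<E, D>,
    ) -> (WgpuTensor<E, D>, WgpuTensor<E, D>, WgpuTensor<E, D>) {
        // Random data in [-10, 10] with the same shapes as the real tensors.
        let bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
        let lhs = random_like_uniform(lhs, bounds.0, bounds.1);
        let rhs = random_like_uniform(rhs, bounds.0, bounds.1);
        // The output only needs a matching client, device and shape; its contents are overwritten.
        let out = empty_device(out.client.clone(), out.device.clone(), out.shape.clone());
        (lhs, rhs, out)
    }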
@@ -1,6 +1,5 @@
 mod base;
 mod key;
-mod utils;
 
 pub use base::*;
 pub use key::*;
@@ -1,19 +0,0 @@
-use burn_tensor::Element;
-
-use crate::{element::WgpuElement, ops::numeric::ones_device, tensor::WgpuTensor};
-
-pub(crate) fn autotune_tensors<E: WgpuElement + Element, const D: usize>(
-    tensor: &WgpuTensor<E, D>,
-) -> WgpuTensor<E, 3> {
-    let n_batches = 2;
-    ones_device(
-        tensor.client.clone(),
-        tensor.device.clone(),
-        [
-            n_batches,
-            tensor.shape.dims[D - 2],
-            tensor.shape.dims[D - 1],
-        ]
-        .into(),
-    )
-}
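The file deleted above was the old input helper: it returned an all-ones, fixed rank-3 tensor with a hard-coded batch size of 2, so the candidate kernels were tuned on data that did not match the real operands, which is presumably why it is replaced by the random, shape-preserving inputs prepared in the first hunks.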
@@ -1,7 +1,7 @@
 use burn_tensor::Shape;
 
 use crate::{
-    compute::{compute_client, StaticKernel},
+    compute::{compute_client, StaticKernel, WgpuComputeClient},
     element::WgpuElement,
     kernel::{
         prng::base::{make_args_buffer, make_info_buffer},
@@ -31,10 +31,36 @@ pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
     device: &WgpuDevice,
     low: E,
     high: E,
 ) -> WgpuTensor<E, D> {
+    let client = compute_client::<G>(device);
+    uniform_kernel(client, device, &shape, low, high)
+}
+
+/// Pseudo-random generator for uniform distribution, based on
+/// another tensor's client, device and shape
+pub fn random_like_uniform<E: WgpuElement, const D: usize>(
+    tensor: &WgpuTensor<E, D>,
+    low: E,
+    high: E,
+) -> WgpuTensor<E, D> {
+    uniform_kernel(
+        tensor.client.clone(),
+        &tensor.device,
+        &tensor.shape,
+        low,
+        high,
+    )
+}
+
+fn uniform_kernel<E: WgpuElement, const D: usize>(
+    client: WgpuComputeClient,
+    device: &WgpuDevice,
+    shape: &Shape<D>,
+    low: E,
+    high: E,
+) -> WgpuTensor<E, D> {
     const N_VALUES_PER_THREAD: usize = 128;
 
-    let client = compute_client::<G>(device);
     let output = empty_device(client.clone(), device.clone(), shape.clone());
     let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
     let args_handle = make_args_buffer(client.clone(), &[low, high]);
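The hunk above splits the uniform PRNG into a private `uniform_kernel` that takes an existing `WgpuComputeClient`, so `random_uniform` (which resolves a client from a device) and the new `random_like_uniform` (which reuses an existing tensor's client, device and shape) share one code path. A usage sketch, assuming the code lives inside the same crate and that f32 is a valid element type here; the function name `noise_like` is made up for illustration:

    use crate::{kernel::prng::random_like_uniform, tensor::WgpuTensor};

    /// Hypothetical example: sample uniform noise with the same client,
    /// device and shape as an existing tensor.
    fn noise_like(existing: &WgpuTensor<f32, 4>) -> WgpuTensor<f32, 4> {
        random_like_uniform(existing, -1.0, 1.0)
    }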