JIT Migration: PRNG (#1433)

* wip bernoulli

* wip

* bernoulli works

* uniform works

* done

* remove old

* refactor prng traits

* forgot to save file

* allow

* clippy

* clippy

* scalar commutativity

* array instead of vec
Louis Fortier-Dubois 2024-03-11 11:40:27 -04:00 committed by GitHub
parent 3f7e6bd5bc
commit 093cbd397d
22 changed files with 938 additions and 404 deletions

View File

@ -29,6 +29,10 @@ macro_rules! gpu {
($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => {
gpu!($scope, $out = mul($lhs, $rhs))
};
// out *= input
($scope:expr, $out:ident *= $input:ident) => {
gpu!($scope, $out = mul($out, $input))
};
// out = mul(lhs, rhs)
($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Mul(
@ -55,10 +59,6 @@ macro_rules! gpu {
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs ^ rhs
($scope:expr, $out:ident = $lhs:ident ^ $rhs:expr) => {
gpu!($scope, $out = powf($lhs, $rhs))
};
// out = powf(lhs, rhs)
($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Powf(
@ -95,6 +95,46 @@ macro_rules! gpu {
gpu!(unary $input, $out)
));
};
// out = lhs & rhs
($scope:expr, $out: ident = $lhs:ident & $rhs:ident) => {
gpu!($scope, $out = bitwise_and($lhs, $rhs))
};
// out = bitwise_and(lhs, rhs)
($scope:expr, $out:ident = bitwise_and($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::BitwiseAnd(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs ^ rhs
($scope:expr, $out: ident = $lhs:ident ^ $rhs:ident) => {
gpu!($scope, $out = bitwise_xor($lhs, $rhs))
};
// out = bitwise_xor(lhs, rhs)
($scope:expr, $out:ident = bitwise_xor($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::BitwiseXor(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs << rhs
($scope:expr, $out: ident = $lhs:ident << $rhs:ident) => {
gpu!($scope, $out = shift_left($lhs, $rhs))
};
// out = shift_left(lhs, rhs)
($scope:expr, $out:ident = shift_left($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::ShiftLeft(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs >> rhs
($scope:expr, $out: ident = $lhs:ident >> $rhs:ident) => {
gpu!($scope, $out = shift_right($lhs, $rhs))
};
// out = shift_right(lhs, rhs)
($scope:expr, $out:ident = shift_right($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::ShiftRight(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs == rhs
($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
gpu!($scope, $out = equal($lhs, $rhs))
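The new arms follow the macro's existing pattern: an operator-syntax arm forwards to a named arm that registers the corresponding dialect operator. A minimal, self-contained sketch of that forwarding style, using a toy `Op` enum and `toy_gpu!` macro as stand-ins (not burn's real types):

#[derive(Debug)]
enum Op {
    BitwiseAnd(u32, u32),
    ShiftLeft(u32, u32),
}

macro_rules! toy_gpu {
    // operator-syntax arm forwards to the named arm below
    ($ops:expr, $lhs:ident & $rhs:ident) => {
        toy_gpu!($ops, bitwise_and($lhs, $rhs))
    };
    ($ops:expr, bitwise_and($lhs:expr, $rhs:expr)) => {
        $ops.push(Op::BitwiseAnd($lhs, $rhs))
    };
    ($ops:expr, $lhs:ident << $rhs:ident) => {
        toy_gpu!($ops, shift_left($lhs, $rhs))
    };
    ($ops:expr, shift_left($lhs:expr, $rhs:expr)) => {
        $ops.push(Op::ShiftLeft($lhs, $rhs))
    };
}

fn main() {
    let mut ops = Vec::new();
    let (z, s) = (0b1010u32, 2u32);
    toy_gpu!(ops, z & s);
    toy_gpu!(ops, z << s);
    println!("{ops:?}"); // [BitwiseAnd(10, 2), ShiftLeft(10, 2)]
}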

View File

@ -53,6 +53,10 @@ pub enum Operator {
Not(UnaryOperator),
Max(BinaryOperator),
Min(BinaryOperator),
BitwiseAnd(BinaryOperator),
BitwiseXor(BinaryOperator),
ShiftLeft(BinaryOperator),
ShiftRight(BinaryOperator),
}
/// All metadata that can be accessed in a shader.

View File

@ -1,3 +1,5 @@
use crate::JitElement;
use super::{
gpu, processing::ScopeProcessing, Elem, IndexOffsetGlobalWithLayout, Item, Operation, Operator,
Procedure, ReadGlobal, ReadGlobalWithLayout, UnaryOperator, Variable, Vectorization,
@ -64,6 +66,18 @@ impl Scope {
local
}
/// Create a variable initialized at some value.
pub(crate) fn create_with_value<E: JitElement, I: Into<Item> + Copy>(
&mut self,
value: E,
item: I,
) -> Variable {
let local = self.create_local(item);
let value = Variable::ConstantScalar(value.to_f64().unwrap(), item.into().elem());
gpu!(self, local = value);
local
}
/// Create a local variable of the given [item type](Item).
pub(crate) fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
let item = item.into();

View File

@ -74,6 +74,10 @@ impl Operator {
Operator::And(op) => Operator::And(op.vectorize(vectorization)),
Operator::Or(op) => Operator::Or(op.vectorize(vectorization)),
Operator::Not(op) => Operator::Not(op.vectorize(vectorization)),
Operator::BitwiseAnd(op) => Operator::BitwiseAnd(op.vectorize(vectorization)),
Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)),
Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)),
}
}
}

View File

@ -1,5 +1,6 @@
use crate::compute::{DynamicKernel, Kernel, StaticKernel, WorkGroup};
use crate::element::JitElement;
use crate::gpu::Elem;
use crate::kernel::{
elemwise_workgroup, DynamicKernelSource, StaticKernelSource, WORKGROUP_DEFAULT,
};
@ -22,10 +23,6 @@ pub enum WorkgroupLaunch {
}
/// Execute a static kernel.
///
///
/// The limitation from this method is that you can't launch a kernel with multiple types of
/// scalar.
pub fn execute_static<R, K, E>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
@ -37,12 +34,32 @@ pub fn execute_static<R, K, E>(
R: Runtime,
E: JitElement,
{
let settings = execute_settings(inputs, outputs, scalar_elems, launch, &client);
execute_static_::<R, K, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, client)
}
fn execute_static_<R, K, E1, E2, E3>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
launch: WorkgroupLaunch,
client: ComputeClient<R::Server, R::Channel>,
) where
K: StaticKernelSource + 'static,
R: Runtime,
E1: JitElement,
E2: JitElement,
E3: JitElement,
{
let settings = execute_settings(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
);
let mut handles = settings.handles_tensors;
let workgroup = settings.workgroup;
handles.push(&settings.handle_info);
if let Some(handle) = settings.handle_scalars.as_ref() {
for handle in settings.handles_scalars.iter() {
handles.push(handle);
}
@ -50,11 +67,180 @@ pub fn execute_static<R, K, E>(
client.execute(kernel, &handles);
}
pub struct Execution<'h, K, R: Runtime, Scalars> {
scalars: Scalars,
client: ComputeClient<R::Server, R::Channel>,
kernel: K,
inputs: &'h [EagerHandle<'h, R>],
outputs: &'h [EagerHandle<'h, R>],
}
impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
pub fn start(
kernel: K,
client: ComputeClient<R::Server, R::Channel>,
) -> Execution<'h, K, R, ()> {
Execution {
scalars: (),
client,
kernel,
inputs: &[],
outputs: &[],
}
}
#[allow(unused)]
pub fn inputs(self, inputs: &'h [EagerHandle<'h, R>]) -> Execution<'h, K, R, ()> {
Execution {
scalars: self.scalars,
client: self.client,
kernel: self.kernel,
inputs,
outputs: self.outputs,
}
}
pub fn outputs(self, outputs: &'h [EagerHandle<'h, R>]) -> Execution<'h, K, R, ()> {
Execution {
scalars: self.scalars,
client: self.client,
kernel: self.kernel,
inputs: self.inputs,
outputs,
}
}
}
impl<'h, K, R> Execution<'h, K, R, ()>
where
K: DynamicKernelSource + 'static,
R: Runtime,
{
pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
Execution {
scalars: (scalars,),
client: self.client,
kernel: self.kernel,
inputs: self.inputs,
outputs: self.outputs,
}
}
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, f32, f32, f32>(
self.inputs,
self.outputs,
None,
None,
None,
self.kernel,
launch,
self.client,
)
}
}
impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
where
K: DynamicKernelSource + 'static,
R: Runtime,
E: JitElement,
{
pub fn with_scalars<'b, E2>(
self,
scalars: &'b [E2],
) -> Execution<'h, K, R, (&'a [E], &'b [E2])> {
Execution {
scalars: (self.scalars.0, scalars),
client: self.client,
kernel: self.kernel,
inputs: self.inputs,
outputs: self.outputs,
}
}
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, E, f32, f32>(
self.inputs,
self.outputs,
Some(self.scalars.0),
None,
None,
self.kernel,
launch,
self.client,
)
}
}
impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
where
K: DynamicKernelSource + 'static,
R: Runtime,
E1: JitElement,
E2: JitElement,
{
#[allow(unused, clippy::type_complexity)]
pub fn with_scalars<'c, E3>(
self,
scalars: &'c [E3],
) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> {
Execution {
scalars: (self.scalars.0, self.scalars.1, scalars),
client: self.client,
kernel: self.kernel,
inputs: self.inputs,
outputs: self.outputs,
}
}
/// Execute a dynamic kernel.
#[allow(clippy::too_many_arguments)]
pub fn execute(self, launch: WorkgroupLaunch)
where
K: DynamicKernelSource + 'static,
R: Runtime,
{
execute_dynamic_::<R, K, E1, E2, f32>(
self.inputs,
self.outputs,
Some(self.scalars.0),
Some(self.scalars.1),
None,
self.kernel,
launch,
self.client,
)
}
}
impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])>
where
K: DynamicKernelSource + 'static,
R: Runtime,
E1: JitElement,
E2: JitElement,
E3: JitElement,
{
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, E1, E2, E3>(
self.inputs,
self.outputs,
Some(self.scalars.0),
Some(self.scalars.1),
Some(self.scalars.2),
self.kernel,
launch,
self.client,
)
}
}
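The `Execution` struct above is a typestate builder: each `with_scalars` call moves the scalar slice into the type parameter, so `execute` can pick the matching `execute_dynamic_` instantiation at compile time. A minimal sketch of the same idea with toy types (not burn's real API; kernel, client and launch arguments are omitted):

struct Exec<S> {
    scalars: S,
}

impl Exec<()> {
    fn start() -> Self {
        Exec { scalars: () }
    }
    // the first call records one scalar slice in the type parameter
    fn with_scalars<'a, E>(self, s: &'a [E]) -> Exec<(&'a [E],)> {
        Exec { scalars: (s,) }
    }
}

impl<'a, E> Exec<(&'a [E],)> {
    // a second call upgrades the tuple, so the compiler knows how many groups were given
    fn with_scalars<'b, F>(self, s: &'b [F]) -> Exec<(&'a [E], &'b [F])> {
        Exec { scalars: (self.scalars.0, s) }
    }
    fn execute(self) {
        println!("one scalar group of len {}", self.scalars.0.len());
    }
}

impl<'a, 'b, E, F> Exec<(&'a [E], &'b [F])> {
    fn execute(self) {
        println!(
            "two scalar groups of len {} and {}",
            self.scalars.0.len(),
            self.scalars.1.len()
        );
    }
}

fn main() {
    let seeds = [1u32, 2, 3, 4];
    let args = [0.5f32];
    Exec::start().with_scalars(&seeds).execute();
    // mirrors `Execution::start(kernel, client).with_scalars(&seeds).with_scalars(&args).execute(launch)`
    Exec::start().with_scalars(&seeds).with_scalars(&args).execute();
}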
/// Execute a dynamic kernel.
///
///
/// The limitation from this method is that you can't launch a kernel with multiple types of
/// scalar.
pub fn execute_dynamic<R, K, E>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
@ -67,12 +253,43 @@ pub fn execute_dynamic<R, K, E>(
R: Runtime,
E: JitElement,
{
let settings = execute_settings(inputs, outputs, scalar_elems, launch, &client);
execute_dynamic_::<R, K, E, E, E>(
inputs,
outputs,
scalar_elems,
None,
None,
kernel,
launch,
client,
)
}
#[allow(clippy::too_many_arguments)]
fn execute_dynamic_<R, K, E1, E2, E3>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
kernel: K,
launch: WorkgroupLaunch,
client: ComputeClient<R::Server, R::Channel>,
) where
K: DynamicKernelSource + 'static,
R: Runtime,
E1: JitElement,
E2: JitElement,
E3: JitElement,
{
let settings = execute_settings(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
);
let mut handles = settings.handles_tensors;
let workgroup = settings.workgroup;
handles.push(&settings.handle_info);
if let Some(handle) = settings.handle_scalars.as_ref() {
for handle in settings.handles_scalars.iter() {
handles.push(handle);
}
@ -84,14 +301,16 @@ pub fn execute_dynamic<R, K, E>(
struct ExecuteSettings<'a, R: Runtime> {
handles_tensors: Vec<&'a Handle<R::Server>>,
handle_info: Handle<R::Server>,
handle_scalars: Option<Handle<R::Server>>,
handles_scalars: Vec<Handle<R::Server>>,
workgroup: WorkGroup,
}
fn execute_settings<'a, R: Runtime, E: JitElement>(
fn execute_settings<'a, R: Runtime, E1: JitElement, E2: JitElement, E3: JitElement>(
inputs: &'a [EagerHandle<R>],
outputs: &'a [EagerHandle<R>],
scalar_elems: Option<&[E]>,
scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
launch: WorkgroupLaunch,
client: &ComputeClient<R::Server, R::Channel>,
) -> ExecuteSettings<'a, R> {
@ -139,10 +358,8 @@ fn execute_settings<'a, R: Runtime, E: JitElement>(
let info = client.create(bytemuck::cast_slice(&info));
// Finally we finish with the named bindings.
let mut scalars = None;
if let Some(values) = &scalar_elems {
scalars = Some(client.create(bytemuck::cast_slice(values)));
}
let handles_scalars =
create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
let workgroup = match launch {
WorkgroupLaunch::Custom(workgroup) => workgroup,
@ -152,11 +369,54 @@ fn execute_settings<'a, R: Runtime, E: JitElement>(
ExecuteSettings {
handles_tensors: handles,
handle_info: info,
handle_scalars: scalars,
handles_scalars,
workgroup,
}
}
fn create_scalar_handles<R: Runtime, E1: JitElement, E2: JitElement, E3: JitElement>(
scalars_0: Option<&[E1]>,
scalars_1: Option<&[E2]>,
scalars_2: Option<&[E3]>,
client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<Handle<R::Server>> {
// It is crucial that scalars follow this order: float, int, uint
let element_priority = |elem: Elem| match elem {
Elem::Float => 0,
Elem::Int => 1,
Elem::UInt => 2,
Elem::Bool => panic!("Bool scalars are not supported"),
};
let scalar_priorities: [usize; 3] = [
element_priority(E1::gpu_elem()),
element_priority(E2::gpu_elem()),
element_priority(E3::gpu_elem()),
];
let mut handles_scalars = Vec::new();
for i in 0..3 {
for (j, scalar_priority) in scalar_priorities.iter().enumerate() {
if scalar_priority == &i {
if j == 0 {
if let Some(values) = &scalars_0 {
handles_scalars.push(client.create(bytemuck::cast_slice(values)));
}
} else if j == 1 {
if let Some(values) = &scalars_1 {
handles_scalars.push(client.create(bytemuck::cast_slice(values)));
}
} else if j == 2 {
if let Some(values) = &scalars_2 {
handles_scalars.push(client.create(bytemuck::cast_slice(values)));
}
}
}
}
}
handles_scalars
}
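A standalone sketch of the ordering rule enforced above: whatever order the caller supplies the scalar groups in (the PRNG kernel passes uint seeds before float args), the handles must be bound float, int, uint. Toy types only, and the "extra" int group is a hypothetical third caller:

#[derive(Debug, Clone, Copy, PartialEq)]
enum Elem {
    Float,
    Int,
    UInt,
}

fn priority(elem: Elem) -> usize {
    match elem {
        Elem::Float => 0,
        Elem::Int => 1,
        Elem::UInt => 2,
    }
}

fn main() {
    // caller order: uint seeds first, then float args, then a hypothetical int group
    let mut groups = vec![(Elem::UInt, "seeds"), (Elem::Float, "args"), (Elem::Int, "extra")];
    // handles are emitted in float, int, uint order regardless of caller order
    groups.sort_by_key(|(elem, _)| priority(*elem));
    assert_eq!(
        groups,
        vec![(Elem::Float, "args"), (Elem::Int, "extra"), (Elem::UInt, "seeds")]
    );
}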
pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
let mut num_elems = 1;
for i in shape.iter() {

View File

@ -326,6 +326,26 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::BitwiseAnd(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::BitwiseXor(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::ShiftLeft(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::ShiftRight(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
},
Operation::Procedure(proc) => {
match proc {

View File

@ -181,20 +181,6 @@ pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> Wor
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
}
pub(crate) fn prng_workgroup(
num_elems: usize,
workgroup_size: usize,
n_values_per_thread: usize,
) -> WorkGroup {
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
let num_elem_per_invocation = workgroup_size * workgroup_size;
let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,11 +1,123 @@
use crate::{element::JitElement, kernel_wgsl, Runtime, SEED};
use std::marker::PhantomData;
use crate::{
codegen::{
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
compute::WorkGroup,
gpu::{gpu, Elem, Scope, Variable},
kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT},
tensor::JitTensor,
Compiler, JitElement, Runtime, SEED,
};
use burn_common::rand::get_seeded_rng;
use burn_compute::{client::ComputeClient, server::Handle};
use burn_tensor::Shape;
use rand::Rng;
kernel_wgsl!(Prng, "../../template/prng/prng.wgsl");
pub(crate) const N_VALUES_PER_THREAD: usize = 128;
pub(crate) fn get_seeds() -> Vec<u32> {
/// Pseudo-random generator
pub(crate) fn random<P: Prng<E>, R: Runtime, E: JitElement, const D: usize>(
shape: Shape<D>,
device: &R::Device,
prng: P,
) -> JitTensor<R, E, D> {
let client = R::client(device);
let kernel: PrngEagerKernel<P, R, E> = PrngEagerKernel::new();
let num_elems = shape.num_elements();
let buffer = client.empty(num_elems * core::mem::size_of::<E>());
let output = JitTensor::new(client.clone(), device.clone(), shape.clone(), buffer);
let seeds = get_seeds();
Execution::start(kernel, client)
.outputs(&[EagerHandle::<R>::new(
&output.handle,
&output.strides,
&output.shape.dims,
)])
.with_scalars(&seeds)
.with_scalars(&prng.args())
.execute(WorkgroupLaunch::Custom(prng_workgroup(
num_elems,
WORKGROUP_DEFAULT,
N_VALUES_PER_THREAD,
)));
output
}
fn prng_workgroup(
num_elems: usize,
workgroup_size: usize,
n_values_per_thread: usize,
) -> WorkGroup {
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
let num_elem_per_invocation = workgroup_size * workgroup_size;
let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
}
impl<P: Prng<E>, R: Runtime, E: JitElement> DynamicKernelSource for PrngEagerKernel<P, R, E> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let output = Variable::GlobalOutputArray(0, item);
let seed0 = Variable::GlobalScalar(0, Elem::UInt);
let seed1 = Variable::GlobalScalar(1, Elem::UInt);
let seed2 = Variable::GlobalScalar(2, Elem::UInt);
let seed3 = Variable::GlobalScalar(3, Elem::UInt);
let seeds = [seed0, seed1, seed2, seed3];
let mut args = Vec::<Variable>::new();
for i in 0..P::args_length() {
args.push(Variable::GlobalScalar(i as u16, item.elem()));
}
PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);
scope.write_global_custom(output);
let args = InputInfo::Scalar {
elem: E::gpu_elem(),
size: P::args_length(),
};
let seeds = InputInfo::Scalar {
elem: Elem::UInt,
size: 4,
};
let out = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![args, seeds],
outputs: vec![out],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<Self>(),)
}
}
#[derive(new)]
pub(crate) struct PrngEagerKernel<P: Prng<E>, R: Runtime, E: JitElement> {
_prng: PhantomData<P>,
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
pub(crate) fn get_seeds() -> [u32; 4] {
let mut seed = SEED.lock().unwrap();
let mut rng = match seed.as_ref() {
Some(rng_seeded) => rng_seeded.clone(),
@ -16,23 +128,168 @@ pub(crate) fn get_seeds() -> Vec<u32> {
seeds.push(rng.gen());
}
*seed = Some(rng);
seeds
seeds.try_into().unwrap()
}
pub(crate) fn make_info_buffer<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
pub(crate) trait Prng<E>: Send + Sync + 'static {
fn args(self) -> Vec<E>;
fn args_length() -> usize;
#[allow(clippy::too_many_arguments)]
fn inner_loop(
scope: &mut Scope,
args: Vec<Variable>,
write_index_base: Variable,
n_invocations: Variable,
n_values_per_thread: usize,
state_0: Variable,
state_1: Variable,
state_2: Variable,
state_3: Variable,
output: Variable,
);
}
#[derive(new)]
pub(crate) struct PrngShader<P: Prng<E>, E: JitElement> {
output: Variable,
n_values_per_thread: usize,
) -> Handle<R::Server> {
let mut info = get_seeds();
info.insert(0, n_values_per_thread as u32);
client.create(bytemuck::cast_slice(&info))
seeds: [Variable; 4],
args: Vec<Variable>,
_prng: PhantomData<P>,
_elem: PhantomData<E>,
}
pub(crate) fn make_args_buffer<R: Runtime, E: JitElement>(
client: ComputeClient<R::Server, R::Channel>,
args: &[E],
) -> Handle<R::Server> {
client.create(E::as_bytes(args))
impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
pub(crate) fn expand(self, scope: &mut Scope) {
let output = self.output;
let [seed_0, seed_1, seed_2, seed_3] = self.seeds;
let n_values_per_thread: Variable = self.n_values_per_thread.into();
let args = self.args;
let workgroup_size_x = Variable::WorkgroupSizeX;
let workgroup_size_y = Variable::WorkgroupSizeY;
let workgroup_id_x = Variable::WorkgroupIdX;
let workgroup_id_y = Variable::WorkgroupIdY;
let num_workgroups_y = Variable::NumWorkgroupsY;
let local_index = Variable::LocalInvocationIndex;
let n_invocations = scope.create_local(Elem::UInt);
gpu!(scope, n_invocations = workgroup_size_x);
gpu!(scope, n_invocations *= workgroup_size_y);
let workgroup_offset = scope.create_local(Elem::UInt);
gpu!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y);
gpu!(scope, workgroup_offset += workgroup_id_y);
gpu!(scope, workgroup_offset *= n_invocations);
let write_index_base = scope.create_local(Elem::UInt);
gpu!(scope, write_index_base = workgroup_offset);
gpu!(scope, write_index_base *= n_values_per_thread);
gpu!(scope, write_index_base += local_index);
// Set state with unique seeds
let thread_seed = scope.create_local(Elem::UInt);
gpu!(scope, thread_seed = cast(1000000007));
let thread_seed_index = scope.create_local(Elem::UInt);
gpu!(scope, thread_seed_index = workgroup_offset + local_index);
gpu!(scope, thread_seed *= thread_seed_index);
let state_0 = scope.create_local(Elem::UInt);
gpu!(scope, state_0 = thread_seed);
gpu!(scope, state_0 += seed_0);
let state_1 = scope.create_local(Elem::UInt);
gpu!(scope, state_1 = thread_seed);
gpu!(scope, state_1 += seed_1);
let state_2 = scope.create_local(Elem::UInt);
gpu!(scope, state_2 = thread_seed);
gpu!(scope, state_2 += seed_2);
let state_3 = scope.create_local(Elem::UInt);
gpu!(scope, state_3 = thread_seed);
gpu!(scope, state_3 += seed_3);
// Creation of n_values_per_thread values, specific to the distribution
P::inner_loop(
scope,
args,
write_index_base,
n_invocations,
self.n_values_per_thread,
state_0,
state_1,
state_2,
state_3,
output,
);
}
}
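A small sketch of the write pattern this sets up, with toy sizes: each invocation writes `n_values_per_thread` values strided by the workgroup's invocation count, so neighbouring threads write neighbouring addresses on every iteration, and per-workgroup blocks stay disjoint because `workgroup_offset` already includes the invocation count.

fn main() {
    let n_invocations: usize = 4; // workgroup_size_x * workgroup_size_y, toy value
    let n_values_per_thread: usize = 3;
    let workgroup_offset: usize = 0; // flat workgroup index * n_invocations; 0 for the first workgroup
    for local_index in 0..n_invocations {
        let write_index_base = workgroup_offset * n_values_per_thread + local_index;
        let writes: Vec<usize> = (0..n_values_per_thread)
            .map(|i| i * n_invocations + write_index_base)
            .collect();
        println!("thread {local_index}: {writes:?}");
    }
    // thread 0: [0, 4, 8]
    // thread 1: [1, 5, 9]
    // thread 2: [2, 6, 10]
    // thread 3: [3, 7, 11]
}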
pub(crate) fn taus_step_0(scope: &mut Scope, z: Variable) {
taus_step(
scope,
z,
13u32.into(),
19u32.into(),
12u32.into(),
4294967294u32.into(),
);
}
pub(crate) fn taus_step_1(scope: &mut Scope, z: Variable) {
taus_step(
scope,
z,
2u32.into(),
25u32.into(),
4u32.into(),
4294967288u32.into(),
);
}
pub(crate) fn taus_step_2(scope: &mut Scope, z: Variable) {
taus_step(
scope,
z,
3u32.into(),
11u32.into(),
17u32.into(),
4294967280u32.into(),
);
}
fn taus_step(
scope: &mut Scope,
z: Variable,
s1: Variable,
s2: Variable,
s3: Variable,
m: Variable,
) {
let b = scope.create_local(Elem::UInt);
gpu!(scope, b = z << s1);
gpu!(scope, b = b ^ z);
gpu!(scope, b = b >> s2);
gpu!(scope, z = z & m);
gpu!(scope, z = z << s3);
gpu!(scope, z = z ^ b);
}
pub(crate) fn lcg_step(scope: &mut Scope, z: Variable) {
let a: Variable = 1664525u32.into();
let b: Variable = 1013904223u32.into();
gpu!(scope, z *= a);
gpu!(scope, z += b);
}
pub(crate) fn cast_uint_to_float(scope: &mut Scope, int_random: Variable, float_random: Variable) {
let tmp: Variable = 2.328_306_4e-10.into();
gpu!(scope, float_random = cast(int_random));
gpu!(scope, float_random *= tmp);
}
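For reference, a plain-Rust CPU sketch of the same hybrid Tausworthe + LCG generator the shader code above assembles, with the same shift/mask/LCG constants and the 2^-32 scale; the seed values here are arbitrary placeholders, not what the kernel derives from `get_seeds()`.

fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
    let b = ((z << s1) ^ z) >> s2;
    ((z & m) << s3) ^ b
}

fn lcg_step(z: u32) -> u32 {
    z.wrapping_mul(1664525).wrapping_add(1013904223)
}

/// One step of the combined generator: advance the four streams, XOR them, scale by 2^-32.
fn next_uniform(state: &mut [u32; 4]) -> f32 {
    state[0] = taus_step(state[0], 13, 19, 12, 4294967294);
    state[1] = taus_step(state[1], 2, 25, 4, 4294967288);
    state[2] = taus_step(state[2], 3, 11, 17, 4294967280);
    state[3] = lcg_step(state[3]);
    let bits = state[0] ^ state[1] ^ state[2] ^ state[3];
    bits as f32 * 2.328_306_4e-10
}

fn main() {
    // placeholder seeds; the kernel derives per-thread state from `get_seeds()`
    let mut state = [0x9E3779B9, 0x243F6A88, 0xB7E15162, 0x23456789];
    for _ in 0..4 {
        println!("{}", next_uniform(&mut state));
    }
}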
#[cfg(feature = "export_tests")]
@ -72,21 +329,4 @@ pub mod tests_utils {
output[current_runs].n_runs += 1;
output
}
#[test]
fn test_count_bins() {
let numbers = vec![0., 1., 1.5, 2., 2.5, 3., 2.5, 1.5, 3.5];
let number_of_bins = 4;
let low = 0.;
let high = 4.;
let stats = calculate_bin_stats(numbers, number_of_bins, low, high);
assert_eq!(stats[0].count, 1);
assert_eq!(stats[0].n_runs, 1);
assert_eq!(stats[1].count, 3);
assert_eq!(stats[1].n_runs, 2);
assert_eq!(stats[2].count, 3);
assert_eq!(stats[2].n_runs, 2);
assert_eq!(stats[3].count, 2);
assert_eq!(stats[3].n_runs, 2);
}
}

View File

@ -1,53 +1,73 @@
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
};
use burn_tensor::Shape;
use super::base::Prng;
use crate::{
gpu::{gpu, Elem, Scope, Variable},
kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
tensor::JitTensor,
JitElement, Runtime,
};
struct BernoulliPrng;
use super::{random, Prng};
impl StaticKernelSource for BernoulliPrng {
fn source() -> SourceTemplate {
Prng::source()
.register("num_args", "1")
.register(
"prng_loop",
include_str!("../../template/prng/bernoulli_inner_loop.wgsl"),
)
.add_template("fn cast_elem(e: bool) -> {{ elem }} {return {{elem}}(e);}")
pub(crate) struct Bernoulli<E> {
probability: E,
}
impl<E: JitElement> Prng<E> for Bernoulli<E> {
fn args(self) -> Vec<E> {
vec![self.probability]
}
fn inner_loop(
scope: &mut Scope,
args: Vec<Variable>,
write_index_base: Variable,
n_invocations: Variable,
n_values_per_thread: usize,
state_0: Variable,
state_1: Variable,
state_2: Variable,
state_3: Variable,
output: Variable,
) {
let prob = args[0];
gpu!(
scope,
range(0u32, n_values_per_thread).for_each(|i, scope| {
taus_step_0(scope, state_0);
taus_step_1(scope, state_1);
taus_step_2(scope, state_2);
lcg_step(scope, state_3);
let int_random = scope.create_local(Elem::UInt);
gpu!(scope, int_random = state_0 ^ state_1);
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let float_random = scope.create_local(Elem::Float);
cast_uint_to_float(scope, int_random, float_random);
let bernoulli = scope.create_local(Elem::Bool);
gpu!(scope, bernoulli = float_random < prob);
let write_index = scope.create_local(Elem::UInt);
gpu!(scope, write_index = i * n_invocations);
gpu!(scope, write_index += write_index_base);
gpu!(scope, output[write_index] = bernoulli);
})
);
}
fn args_length() -> usize {
1
}
}
/// Pseudo-random generator for bernoulli
/// Pseudo-random generator with bernoulli distribution
pub fn random_bernoulli<R: Runtime, E: JitElement, const D: usize>(
shape: Shape<D>,
device: &R::Device,
prob: E,
probability: E,
) -> JitTensor<R, E, D> {
const N_VALUES_PER_THREAD: usize = 128;
let client = R::client(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer::<R, E>(client.clone(), &[prob]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<
KernelSettings<BernoulliPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
client.execute(
Box::new(kernel),
&[&output.handle, &info_handle, &args_handle],
);
output
random(shape, device, Bernoulli { probability })
}
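The core of the Bernoulli inner loop is a single threshold against the probability argument. A CPU sketch of that step, using the `rand` crate only as a stand-in uniform source and a hypothetical probability value:

use rand::Rng;

fn main() {
    let probability = 0.3_f32; // hypothetical argument, as passed to `random_bernoulli`
    let mut rng = rand::thread_rng();
    let samples: Vec<bool> = (0..10_000)
        .map(|_| rng.gen::<f32>() < probability) // same comparison as `bernoulli = float_random < prob`
        .collect();
    let hit_rate = samples.iter().filter(|&&b| b).count() as f32 / samples.len() as f32;
    println!("empirical p = {hit_rate:.3}"); // should hover around 0.3
}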

View File

@ -1,57 +1,122 @@
use std::f32::consts::PI;
use burn_tensor::Shape;
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
},
ops::numeric::empty_device,
gpu::{gpu, Elem, Scope, Variable},
kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
tensor::JitTensor,
Runtime,
JitElement, Runtime,
};
use super::base::Prng;
use super::{random, Prng};
struct NormalPrng;
pub(crate) struct Normal<E> {
mean: E,
std: E,
}
impl StaticKernelSource for NormalPrng {
fn source() -> SourceTemplate {
Prng::source()
.register("num_args", "2")
.register(
"prng_loop",
include_str!("../../template/prng/normal_inner_loop.wgsl"),
)
.add_template(include_str!(
"../../template/prng/box_muller_transform.wgsl"
))
impl<E: JitElement> Prng<E> for Normal<E> {
fn args(self) -> Vec<E> {
vec![self.mean, self.std]
}
fn inner_loop(
scope: &mut Scope,
args: Vec<Variable>,
write_index_base: Variable,
n_invocations: Variable,
n_values_per_thread: usize,
state_0: Variable,
state_1: Variable,
state_2: Variable,
state_3: Variable,
output: Variable,
) {
let item = output.item();
let mean = args[0];
let std = args[1];
let two_pi = scope.create_with_value(2. * PI, Elem::Float);
let t_neg = scope.create_with_value(-2.0, item);
let two: Variable = 2u32.into();
gpu!(
scope,
range(0u32, n_values_per_thread / 2).for_each(|i, scope| {
let int_random = scope.create_local(Elem::UInt);
// First random uniform integer
taus_step_0(scope, state_0);
taus_step_1(scope, state_1);
taus_step_2(scope, state_2);
lcg_step(scope, state_3);
gpu!(scope, int_random = state_0 ^ state_1);
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let unit_0 = scope.create_local(Elem::Float);
cast_uint_to_float(scope, int_random, unit_0);
// Second random uniform integer
taus_step_0(scope, state_0);
taus_step_1(scope, state_1);
taus_step_2(scope, state_2);
lcg_step(scope, state_3);
gpu!(scope, int_random = state_0 ^ state_1);
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let unit_1 = scope.create_local(Elem::Float);
cast_uint_to_float(scope, int_random, unit_1);
// Box-Muller transform
let coeff = scope.create_local(item);
gpu!(scope, coeff = log(unit_0));
gpu!(scope, coeff *= t_neg);
gpu!(scope, coeff = sqrt(coeff));
gpu!(scope, coeff *= std);
let trigo_arg = scope.create_local(item);
gpu!(scope, trigo_arg = two_pi * unit_1);
let normal_0 = scope.create_local(item);
let normal_1 = scope.create_local(item);
gpu!(scope, normal_0 = cos(trigo_arg));
gpu!(scope, normal_0 *= coeff);
gpu!(scope, normal_0 += mean);
gpu!(scope, normal_1 = sin(trigo_arg));
gpu!(scope, normal_1 *= coeff);
gpu!(scope, normal_1 += mean);
// Write to output
let write_index_0 = scope.create_local(Elem::UInt);
let write_index_1 = scope.create_local(Elem::UInt);
let iteration_offset = scope.create_local(Elem::UInt);
gpu!(scope, write_index_0 = write_index_base);
gpu!(scope, iteration_offset = two * i);
gpu!(scope, iteration_offset *= n_invocations);
gpu!(scope, write_index_0 += iteration_offset);
gpu!(scope, write_index_1 = write_index_0 + n_invocations);
gpu!(scope, output[write_index_0] = normal_0);
gpu!(scope, output[write_index_1] = normal_1);
})
);
}
fn args_length() -> usize {
2
}
}
/// Pseudo-random generaJitBackendl distribution
/// Pseudo-random generator with normal distribution
pub fn random_normal<R: Runtime, E: JitElement, const D: usize>(
shape: Shape<D>,
device: &R::Device,
mean: E,
std: E,
) -> JitTensor<R, E, D> {
const N_VALUES_PER_THREAD: usize = 128; // must be even
let client = R::client(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer::<R, E>(client.clone(), &[mean, std]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<
KernelSettings<NormalPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
client.execute(
Box::new(kernel),
&[&output.handle, &info_handle, &args_handle],
);
output
random(shape, device, Normal { mean, std })
}
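The inner loop consumes two uniforms per iteration and applies the Box-Muller transform, which is why `n_values_per_thread` must be even. A CPU sketch of that transform, assuming `unit_0` lies in (0, 1] so the logarithm is finite:

use std::f32::consts::PI;

/// Two independent uniforms become two independent normal samples.
fn box_muller(unit_0: f32, unit_1: f32, mean: f32, std: f32) -> (f32, f32) {
    let coeff = std * (-2.0 * unit_0.ln()).sqrt();
    let trigo_arg = 2.0 * PI * unit_1;
    (coeff * trigo_arg.cos() + mean, coeff * trigo_arg.sin() + mean)
}

fn main() {
    // placeholder uniforms; the kernel gets them from the Tausworthe/LCG state
    let (normal_0, normal_1) = box_muller(0.42, 0.77, 10.0, 1.0);
    println!("{normal_0} {normal_1}");
}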

View File

@ -1,130 +1,102 @@
use burn_compute::client::ComputeClient;
use burn_tensor::Shape;
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
},
ops::numeric::empty_device,
gpu::{gpu, Elem, Scope, Variable},
kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
tensor::JitTensor,
IntElement, Runtime,
JitElement, Runtime,
};
use super::base::Prng;
use super::{random, Prng};
struct UniformPrng;
struct UniformIntPrng;
pub(crate) struct Uniform<E> {
lower_bound: E,
upper_bound: E,
}
impl StaticKernelSource for UniformPrng {
fn source() -> SourceTemplate {
Prng::source().register("num_args", "2").register(
"prng_loop",
include_str!("../../template/prng/uniform_inner_loop.wgsl"),
)
impl<E: JitElement> Prng<E> for Uniform<E> {
fn args(self) -> Vec<E> {
vec![self.lower_bound, self.upper_bound]
}
fn inner_loop(
scope: &mut Scope,
args: Vec<Variable>,
write_index_base: Variable,
n_invocations: Variable,
n_values_per_thread: usize,
state_0: Variable,
state_1: Variable,
state_2: Variable,
state_3: Variable,
output: Variable,
) {
let item = output.item();
let lower_bound = args[0];
let upper_bound = args[1];
let scale = scope.create_local(item);
gpu!(scope, scale = upper_bound - lower_bound);
gpu!(
scope,
range(0u32, n_values_per_thread).for_each(|i, scope| {
taus_step_0(scope, state_0);
taus_step_1(scope, state_1);
taus_step_2(scope, state_2);
lcg_step(scope, state_3);
let int_random = scope.create_local(Elem::UInt);
gpu!(scope, int_random = state_0 ^ state_1);
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let float_random = scope.create_local(Elem::Float);
cast_uint_to_float(scope, int_random, float_random);
let uniform = scope.create_local(item);
gpu!(scope, uniform = float_random * scale);
gpu!(scope, uniform += lower_bound);
let write_index = scope.create_local(Elem::UInt);
gpu!(scope, write_index = i * n_invocations);
gpu!(scope, write_index += write_index_base);
gpu!(scope, output[write_index] = uniform);
})
);
}
fn args_length() -> usize {
2
}
}
impl StaticKernelSource for UniformIntPrng {
fn source() -> SourceTemplate {
Prng::source().register("num_args", "2").register(
"prng_loop",
include_str!("../../template/prng/uniform_int_inner_loop.wgsl"),
)
}
}
/// Pseudo-random generator for the uniform distribution.
/// Pseudo-random generator with uniform distribution
pub fn random_uniform<R: Runtime, E: JitElement, const D: usize>(
shape: Shape<D>,
device: &R::Device,
low: E,
high: E,
lower_bound: E,
upper_bound: E,
) -> JitTensor<R, E, D> {
let client = R::client(device);
uniform_kernel(client, device, &shape, low, high)
random(
shape,
device,
Uniform {
lower_bound,
upper_bound,
},
)
}
/// Pseudo-random generator for uniform distribution, based on
/// another tensor.
pub fn random_like_uniform<R: Runtime, E: JitElement, const D: usize>(
tensor: &JitTensor<R, E, D>,
low: E,
high: E,
lower_bound: E,
upper_bound: E,
) -> JitTensor<R, E, D> {
uniform_kernel(
tensor.client.clone(),
random_uniform(
tensor.shape.clone(),
&tensor.device,
&tensor.shape,
low,
high,
lower_bound,
upper_bound,
)
}
/// Pseudo-random generator for uniform int distribution, based on
/// another tensor's client, device and shape.
pub fn random_like_uniform_int<R: Runtime, E: IntElement, const D: usize>(
tensor: &JitTensor<R, E, D>,
low: E,
high: E,
) -> JitTensor<R, E, D> {
uniform_int_kernel(
tensor.client.clone(),
&tensor.device,
&tensor.shape,
low,
high,
)
}
fn uniform_kernel<R: Runtime, E: JitElement, const D: usize>(
client: ComputeClient<R::Server, R::Channel>,
device: &R::Device,
shape: &Shape<D>,
low: E,
high: E,
) -> JitTensor<R, E, D> {
const N_VALUES_PER_THREAD: usize = 128;
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer::<R, E>(client.clone(), &[low, high]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<
KernelSettings<UniformPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
client.execute(
Box::new(kernel),
&[&output.handle, &info_handle, &args_handle],
);
output
}
fn uniform_int_kernel<R: Runtime, E: IntElement, const D: usize>(
client: ComputeClient<R::Server, R::Channel>,
device: &R::Device,
shape: &Shape<D>,
low: E,
high: E,
) -> JitTensor<R, E, D> {
const N_VALUES_PER_THREAD: usize = 128;
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer::<R, E>(client.clone(), &[low, high]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<
KernelSettings<UniformIntPrng, u32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
client.execute(
Box::new(kernel),
&[&output.handle, &info_handle, &args_handle],
);
output
}
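The uniform inner loop above only rescales a [0, 1) sample into [lower_bound, upper_bound), and `random_like_uniform` now simply forwards the source tensor's shape and device to `random_uniform`. A trivial CPU sketch of the rescaling, with hypothetical bounds:

fn scale_uniform(unit: f32, lower_bound: f32, upper_bound: f32) -> f32 {
    unit * (upper_bound - lower_bound) + lower_bound
}

fn main() {
    assert_eq!(scale_uniform(0.0, -2.0, 3.0), -2.0);
    assert_eq!(scale_uniform(0.5, -2.0, 3.0), 0.5); // 0.5 * 5.0 - 2.0
}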

View File

@ -1,13 +0,0 @@
let prob = args[0];
for (var i = 0u; i < n_values_per_thread; i++) {
let write_index = write_index_base + i * n_threads_per_workgroup;
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let float = cast_u32_to_float(random_u32);
output[write_index] = cast_elem(float < prob);
}

View File

@ -1,12 +0,0 @@
fn box_muller_transform(unit_1: {{ elem }}, unit_2: {{ elem }}) -> array<{{ elem }}, 2> {
let mean = args[0];
let stdev = args[1];
let coeff = stdev * sqrt(-2.0 * log(unit_1));
let pi = 3.141592653589793238;
let trigo_arg = 2.0 * pi * unit_2;
let cos_ = cos(trigo_arg);
let sin_ = sin(trigo_arg);
return array(coeff * cos_ + mean, coeff * sin_ + mean);
}

View File

@ -1,23 +0,0 @@
for (var i = 0u; i < n_values_per_thread / 2u; i++) {
let write_index_0 = write_index_base + (2u * i) * n_threads_per_workgroup;
let write_index_1 = write_index_0 + n_threads_per_workgroup;
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_1_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let random_1 = cast_u32_to_float(random_1_u32);
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_2_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let random_2 = cast_u32_to_float(random_2_u32);
let transformed = box_muller_transform(random_1, random_2);
output[write_index_0] = transformed[0];
output[write_index_1] = transformed[1];
}

View File

@ -1,61 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> info: array<u32, 5>;
@group(0)
@binding(2)
var<storage, read> args: array<{{ elem }}, {{ num_args }}>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(local_invocation_index) local_id: u32,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
// Thread preparation
let n_threads_per_workgroup = {{ workgroup_size }}u;
let workgroup_offset = (workgroup_id.x * num_workgroups.y + workgroup_id.y) * n_threads_per_workgroup;
let n_values_per_thread = info[0u];
let write_index_base = workgroup_offset * n_values_per_thread + local_id;
// Set state with unique seeds
let thread_seed = 1000000007u * (workgroup_offset + local_id);
var state: array<u32, 4u>;
for (var i = 0u; i < 4u; i++) {
state[i] = info[i + 1u] + thread_seed;
}
// Creation of n_values_per_thread values, specific to the distribution
{{ prng_loop }}
}
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
let b = ((z << s1) ^ z) >> s2;
return ((z & m) << s3) ^ b;
}
fn taus_step_0(z: u32) -> u32 {
return taus_step(z, 13u, 19u, 12u, 4294967294u);
}
fn taus_step_1(z: u32) -> u32 {
return taus_step(z, 2u, 25u, 4u, 4294967288u);
}
fn taus_step_2(z: u32) -> u32 {
return taus_step(z, 3u, 11u, 17u, 4294967280u);
}
fn lcg_step(z: u32) -> u32 {
return (1664525u * z + 1013904223u);
}
fn cast_u32_to_float(number: u32) -> {{ elem }} {
let tmp = 2.3283064365387e-10 * f32(number);
return {{ elem }}(tmp);
}

View File

@ -1,16 +0,0 @@
let low = args[0];
let high = args[1];
let scale = high - low;
let bias = low;
for (var i = 0u; i < n_values_per_thread; i++) {
let write_index = write_index_base + i * n_threads_per_workgroup;
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let float = cast_u32_to_float(random_u32);
output[write_index] = float * scale + bias;
}

View File

@ -1,20 +0,0 @@
let low = u32(args[0]);
let high = u32(args[1]);
let range = high - low;
let safe_range = max(range, 1u); // Ensure range is not zero to avoid division by 0 in % op
for (var i = 0u; i < n_values_per_thread; i++) {
let write_index = write_index_base + i * n_threads_per_workgroup;
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
// Modulus operation to fit within the range
let mod_result: u32 = u32(random_u32 % safe_range);
output[write_index] = u32(mod_result + low);
}

View File

@ -25,7 +25,7 @@ mod tests {
#[serial]
fn empirical_mean_close_to_expectation() {
TestBackend::seed(0);
let shape = [128, 128];
let shape = [100, 100];
let device = Default::default();
let mean = 10.;
let tensor =
@ -44,6 +44,7 @@ mod tests {
let s = 1.;
let tensor =
Tensor::<TestBackend, 2>::random(shape.clone(), Distribution::Normal(mu, s), &device);
let stats = calculate_bin_stats(
tensor.into_data().value,
6,

View File

@ -19,7 +19,7 @@ mod reduction {
));
let val_ref = tensor_ref.sum_dim(1);
val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
val_ref.into_data().assert_approx_eq(&val.into_data(), 2);
}
#[test]

View File

@ -507,6 +507,26 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::BitwiseAnd(op) => wgsl::Instruction::BitwiseAnd {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::BitwiseXor(op) => wgsl::Instruction::BitwiseXor {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::ShiftLeft(op) => wgsl::Instruction::ShiftLeft {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::ShiftRight(op) => wgsl::Instruction::ShiftRight {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
}
}

View File

@ -185,6 +185,26 @@ pub enum Instruction {
Loop {
instructions: Vec<Instruction>,
},
BitwiseAnd {
lhs: Variable,
rhs: Variable,
out: Variable,
},
BitwiseXor {
lhs: Variable,
rhs: Variable,
out: Variable,
},
ShiftLeft {
lhs: Variable,
rhs: Variable,
out: Variable,
},
ShiftRight {
lhs: Variable,
rhs: Variable,
out: Variable,
},
}
impl Display for Instruction {
@ -423,6 +443,18 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
}
f.write_str("}\n")
}
Instruction::BitwiseAnd { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} & {rhs};\n"))
}
Instruction::BitwiseXor { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} ^ {rhs};\n"))
}
Instruction::ShiftLeft { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} << {rhs};\n"))
}
Instruction::ShiftRight { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
}
}
}
}

View File

@ -145,6 +145,7 @@ where
}
let source = kernel.source().complete();
println!("{}", source);
let pipeline = self.compile_source(&source);
self.pipelines.insert(kernel_id.clone(), pipeline.clone());