mirror of https://github.com/tracel-ai/burn.git
JIT Migration: PRNG (#1433)
* wip bernoulli * wip * bernoulli works * uniform works * done * remove old * refactor prng traits * forgot to save file * allow * clippy * clippy * scalar commutativity * array instead of vec
This commit is contained in:
parent
3f7e6bd5bc
commit
093cbd397d
|
@ -29,6 +29,10 @@ macro_rules! gpu {
|
|||
($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => {
|
||||
gpu!($scope, $out = mul($lhs, $rhs))
|
||||
};
|
||||
// out *= input
|
||||
($scope:expr, $out:ident *= $input:ident) => {
|
||||
gpu!($scope, $out = mul($out, $input))
|
||||
};
|
||||
// out = mul(lhs, rhs)
|
||||
($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Mul(
|
||||
|
@ -55,10 +59,6 @@ macro_rules! gpu {
|
|||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs ^ rhs
|
||||
($scope:expr, $out:ident = $lhs:ident ^ $rhs:expr) => {
|
||||
gpu!($scope, $out = powf($lhs, $rhs))
|
||||
};
|
||||
// out = powf(lhs, rhs)
|
||||
($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Powf(
|
||||
|
@ -95,6 +95,46 @@ macro_rules! gpu {
|
|||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs & rhs
|
||||
($scope:expr, $out: ident = $lhs:ident & $rhs:ident) => {
|
||||
gpu!($scope, $out = bitwise_and($lhs, $rhs))
|
||||
};
|
||||
// out = bitwise_and(lhs, rhs)
|
||||
($scope:expr, $out:ident = bitwise_and($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::BitwiseAnd(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs ^ rhs
|
||||
($scope:expr, $out: ident = $lhs:ident ^ $rhs:ident) => {
|
||||
gpu!($scope, $out = bitwise_xor($lhs, $rhs))
|
||||
};
|
||||
// out = bitwise_xor(lhs, rhs)
|
||||
($scope:expr, $out:ident = bitwise_xor($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::BitwiseXor(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs << rhs
|
||||
($scope:expr, $out: ident = $lhs:ident << $rhs:ident) => {
|
||||
gpu!($scope, $out = shift_left($lhs, $rhs))
|
||||
};
|
||||
// out = shift_left(lhs, rhs)
|
||||
($scope:expr, $out:ident = shift_left($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::ShiftLeft(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs >> rhs
|
||||
($scope:expr, $out: ident = $lhs:ident >> $rhs:ident) => {
|
||||
gpu!($scope, $out = shift_right($lhs, $rhs))
|
||||
};
|
||||
// out = shift_right(lhs, rhs)
|
||||
($scope:expr, $out:ident = shift_right($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::ShiftRight(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs == rhs
|
||||
($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
|
||||
gpu!($scope, $out = equal($lhs, $rhs))
|
||||
|
|
|
@ -53,6 +53,10 @@ pub enum Operator {
|
|||
Not(UnaryOperator),
|
||||
Max(BinaryOperator),
|
||||
Min(BinaryOperator),
|
||||
BitwiseAnd(BinaryOperator),
|
||||
BitwiseXor(BinaryOperator),
|
||||
ShiftLeft(BinaryOperator),
|
||||
ShiftRight(BinaryOperator),
|
||||
}
|
||||
|
||||
/// All metadata that can be accessed in a shader.
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use crate::JitElement;
|
||||
|
||||
use super::{
|
||||
gpu, processing::ScopeProcessing, Elem, IndexOffsetGlobalWithLayout, Item, Operation, Operator,
|
||||
Procedure, ReadGlobal, ReadGlobalWithLayout, UnaryOperator, Variable, Vectorization,
|
||||
|
@ -64,6 +66,18 @@ impl Scope {
|
|||
local
|
||||
}
|
||||
|
||||
/// Create a variable initialized at some value.
|
||||
pub(crate) fn create_with_value<E: JitElement, I: Into<Item> + Copy>(
|
||||
&mut self,
|
||||
value: E,
|
||||
item: I,
|
||||
) -> Variable {
|
||||
let local = self.create_local(item);
|
||||
let value = Variable::ConstantScalar(value.to_f64().unwrap(), item.into().elem());
|
||||
gpu!(self, local = value);
|
||||
local
|
||||
}
|
||||
|
||||
/// Create a local variable of the given [item type](Item).
|
||||
pub(crate) fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
|
||||
let item = item.into();
|
||||
|
|
|
@ -74,6 +74,10 @@ impl Operator {
|
|||
Operator::And(op) => Operator::And(op.vectorize(vectorization)),
|
||||
Operator::Or(op) => Operator::Or(op.vectorize(vectorization)),
|
||||
Operator::Not(op) => Operator::Not(op.vectorize(vectorization)),
|
||||
Operator::BitwiseAnd(op) => Operator::BitwiseAnd(op.vectorize(vectorization)),
|
||||
Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)),
|
||||
Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
|
||||
Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::compute::{DynamicKernel, Kernel, StaticKernel, WorkGroup};
|
||||
use crate::element::JitElement;
|
||||
use crate::gpu::Elem;
|
||||
use crate::kernel::{
|
||||
elemwise_workgroup, DynamicKernelSource, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
};
|
||||
|
@ -22,10 +23,6 @@ pub enum WorkgroupLaunch {
|
|||
}
|
||||
|
||||
/// Execute a static kernel.
|
||||
///
|
||||
///
|
||||
/// The limitation from this method is that you can't launch a kernel with multiple types of
|
||||
/// scalar.
|
||||
pub fn execute_static<R, K, E>(
|
||||
inputs: &[EagerHandle<R>],
|
||||
outputs: &[EagerHandle<R>],
|
||||
|
@ -37,12 +34,32 @@ pub fn execute_static<R, K, E>(
|
|||
R: Runtime,
|
||||
E: JitElement,
|
||||
{
|
||||
let settings = execute_settings(inputs, outputs, scalar_elems, launch, &client);
|
||||
execute_static_::<R, K, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, client)
|
||||
}
|
||||
|
||||
fn execute_static_<R, K, E1, E2, E3>(
|
||||
inputs: &[EagerHandle<R>],
|
||||
outputs: &[EagerHandle<R>],
|
||||
scalars_1: Option<&[E1]>,
|
||||
scalars_2: Option<&[E2]>,
|
||||
scalars_3: Option<&[E3]>,
|
||||
launch: WorkgroupLaunch,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) where
|
||||
K: StaticKernelSource + 'static,
|
||||
R: Runtime,
|
||||
E1: JitElement,
|
||||
E2: JitElement,
|
||||
E3: JitElement,
|
||||
{
|
||||
let settings = execute_settings(
|
||||
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
|
||||
);
|
||||
let mut handles = settings.handles_tensors;
|
||||
let workgroup = settings.workgroup;
|
||||
|
||||
handles.push(&settings.handle_info);
|
||||
if let Some(handle) = settings.handle_scalars.as_ref() {
|
||||
for handle in settings.handles_scalars.iter() {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
|
@ -50,11 +67,180 @@ pub fn execute_static<R, K, E>(
|
|||
client.execute(kernel, &handles);
|
||||
}
|
||||
|
||||
/// Builder gathering everything a dynamic kernel launch needs.
///
/// The `Scalars` type parameter tracks how many scalar groups were attached:
/// it starts as `()` and grows to a one-, two- or three-element tuple of slices
/// through successive `with_scalars` calls. Each state provides an `execute`
/// method that forwards to `execute_dynamic_`.
pub struct Execution<'h, K, R: Runtime, Scalars> {
|
||||
/// Scalar groups attached so far (`()` when none).
scalars: Scalars,
|
||||
/// Client used to launch the kernel.
client: ComputeClient<R::Server, R::Channel>,
|
||||
/// Kernel to execute.
kernel: K,
|
||||
/// Handles of the input tensors.
inputs: &'h [EagerHandle<'h, R>],
|
||||
/// Handles of the output tensors.
outputs: &'h [EagerHandle<'h, R>],
|
||||
}
|
||||
|
||||
impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
|
||||
pub fn start(
|
||||
kernel: K,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) -> Execution<'h, K, R, ()> {
|
||||
Execution {
|
||||
scalars: (),
|
||||
client,
|
||||
kernel,
|
||||
inputs: &[],
|
||||
outputs: &[],
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn inputs(self, inputs: &'h [EagerHandle<'h, R>]) -> Execution<'h, K, R, ()> {
|
||||
Execution {
|
||||
scalars: self.scalars,
|
||||
client: self.client,
|
||||
kernel: self.kernel,
|
||||
inputs,
|
||||
outputs: self.outputs,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn outputs(self, outputs: &'h [EagerHandle<'h, R>]) -> Execution<'h, K, R, ()> {
|
||||
Execution {
|
||||
scalars: self.scalars,
|
||||
client: self.client,
|
||||
kernel: self.kernel,
|
||||
inputs: self.inputs,
|
||||
outputs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'h, K, R> Execution<'h, K, R, ()>
|
||||
where
|
||||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
{
|
||||
pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
|
||||
Execution {
|
||||
scalars: (scalars,),
|
||||
client: self.client,
|
||||
kernel: self.kernel,
|
||||
inputs: self.inputs,
|
||||
outputs: self.outputs,
|
||||
}
|
||||
}
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: WorkgroupLaunch) {
|
||||
execute_dynamic_::<R, K, f32, f32, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self.kernel,
|
||||
launch,
|
||||
self.client,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
|
||||
where
|
||||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
E: JitElement,
|
||||
{
|
||||
pub fn with_scalars<'b, E2>(
|
||||
self,
|
||||
scalars: &'b [E2],
|
||||
) -> Execution<'h, K, R, (&'a [E], &'b [E2])> {
|
||||
Execution {
|
||||
scalars: (self.scalars.0, scalars),
|
||||
client: self.client,
|
||||
kernel: self.kernel,
|
||||
inputs: self.inputs,
|
||||
outputs: self.outputs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: WorkgroupLaunch) {
|
||||
execute_dynamic_::<R, K, E, f32, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
Some(self.scalars.0),
|
||||
None,
|
||||
None,
|
||||
self.kernel,
|
||||
launch,
|
||||
self.client,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
|
||||
where
|
||||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
E1: JitElement,
|
||||
E2: JitElement,
|
||||
{
|
||||
#[allow(unused, clippy::type_complexity)]
|
||||
pub fn with_scalars<'c, E3>(
|
||||
self,
|
||||
scalars: &'c [E3],
|
||||
) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> {
|
||||
Execution {
|
||||
scalars: (self.scalars.0, self.scalars.1, scalars),
|
||||
client: self.client,
|
||||
kernel: self.kernel,
|
||||
inputs: self.inputs,
|
||||
outputs: self.outputs,
|
||||
}
|
||||
}
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn execute(self, launch: WorkgroupLaunch)
|
||||
where
|
||||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
{
|
||||
execute_dynamic_::<R, K, E1, E2, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
Some(self.scalars.0),
|
||||
Some(self.scalars.1),
|
||||
None,
|
||||
self.kernel,
|
||||
launch,
|
||||
self.client,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])>
|
||||
where
|
||||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
E1: JitElement,
|
||||
E2: JitElement,
|
||||
E3: JitElement,
|
||||
{
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: WorkgroupLaunch) {
|
||||
execute_dynamic_::<R, K, E1, E2, E3>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
Some(self.scalars.0),
|
||||
Some(self.scalars.1),
|
||||
Some(self.scalars.2),
|
||||
self.kernel,
|
||||
launch,
|
||||
self.client,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a dynamic kernel.
|
||||
///
|
||||
///
|
||||
/// The limitation from this method is that you can't launch a kernel with multiple types of
|
||||
/// scalar.
|
||||
pub fn execute_dynamic<R, K, E>(
|
||||
inputs: &[EagerHandle<R>],
|
||||
outputs: &[EagerHandle<R>],
|
||||
|
@ -67,12 +253,43 @@ pub fn execute_dynamic<R, K, E>(
|
|||
R: Runtime,
|
||||
E: JitElement,
|
||||
{
|
||||
let settings = execute_settings(inputs, outputs, scalar_elems, launch, &client);
|
||||
execute_dynamic_::<R, K, E, E, E>(
|
||||
inputs,
|
||||
outputs,
|
||||
scalar_elems,
|
||||
None,
|
||||
None,
|
||||
kernel,
|
||||
launch,
|
||||
client,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn execute_dynamic_<R, K, E1, E2, E3>(
|
||||
inputs: &[EagerHandle<R>],
|
||||
outputs: &[EagerHandle<R>],
|
||||
scalars_1: Option<&[E1]>,
|
||||
scalars_2: Option<&[E2]>,
|
||||
scalars_3: Option<&[E3]>,
|
||||
kernel: K,
|
||||
launch: WorkgroupLaunch,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) where
|
||||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
E1: JitElement,
|
||||
E2: JitElement,
|
||||
E3: JitElement,
|
||||
{
|
||||
let settings = execute_settings(
|
||||
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
|
||||
);
|
||||
let mut handles = settings.handles_tensors;
|
||||
let workgroup = settings.workgroup;
|
||||
|
||||
handles.push(&settings.handle_info);
|
||||
if let Some(handle) = settings.handle_scalars.as_ref() {
|
||||
for handle in settings.handles_scalars.iter() {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
|
@ -84,14 +301,16 @@ pub fn execute_dynamic<R, K, E>(
|
|||
struct ExecuteSettings<'a, R: Runtime> {
|
||||
handles_tensors: Vec<&'a Handle<R::Server>>,
|
||||
handle_info: Handle<R::Server>,
|
||||
handle_scalars: Option<Handle<R::Server>>,
|
||||
handles_scalars: Vec<Handle<R::Server>>,
|
||||
workgroup: WorkGroup,
|
||||
}
|
||||
|
||||
fn execute_settings<'a, R: Runtime, E: JitElement>(
|
||||
fn execute_settings<'a, R: Runtime, E1: JitElement, E2: JitElement, E3: JitElement>(
|
||||
inputs: &'a [EagerHandle<R>],
|
||||
outputs: &'a [EagerHandle<R>],
|
||||
scalar_elems: Option<&[E]>,
|
||||
scalars_1: Option<&[E1]>,
|
||||
scalars_2: Option<&[E2]>,
|
||||
scalars_3: Option<&[E3]>,
|
||||
launch: WorkgroupLaunch,
|
||||
client: &ComputeClient<R::Server, R::Channel>,
|
||||
) -> ExecuteSettings<'a, R> {
|
||||
|
@ -139,10 +358,8 @@ fn execute_settings<'a, R: Runtime, E: JitElement>(
|
|||
let info = client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
// Finally we finish with the named bindings.
|
||||
let mut scalars = None;
|
||||
if let Some(values) = &scalar_elems {
|
||||
scalars = Some(client.create(bytemuck::cast_slice(values)));
|
||||
}
|
||||
let handles_scalars =
|
||||
create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
|
||||
|
||||
let workgroup = match launch {
|
||||
WorkgroupLaunch::Custom(workgroup) => workgroup,
|
||||
|
@ -152,11 +369,54 @@ fn execute_settings<'a, R: Runtime, E: JitElement>(
|
|||
ExecuteSettings {
|
||||
handles_tensors: handles,
|
||||
handle_info: info,
|
||||
handle_scalars: scalars,
|
||||
handles_scalars,
|
||||
workgroup,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_scalar_handles<R: Runtime, E1: JitElement, E2: JitElement, E3: JitElement>(
|
||||
scalars_0: Option<&[E1]>,
|
||||
scalars_1: Option<&[E2]>,
|
||||
scalars_2: Option<&[E3]>,
|
||||
client: &ComputeClient<R::Server, R::Channel>,
|
||||
) -> Vec<Handle<R::Server>> {
|
||||
// It is crucial that scalars follow this order: float, int, uint
|
||||
let element_priority = |elem: Elem| match elem {
|
||||
Elem::Float => 0,
|
||||
Elem::Int => 1,
|
||||
Elem::UInt => 2,
|
||||
Elem::Bool => panic!("Bool scalars are not supported"),
|
||||
};
|
||||
let scalar_priorities: [usize; 3] = [
|
||||
element_priority(E1::gpu_elem()),
|
||||
element_priority(E2::gpu_elem()),
|
||||
element_priority(E3::gpu_elem()),
|
||||
];
|
||||
|
||||
let mut handles_scalars = Vec::new();
|
||||
for i in 0..3 {
|
||||
for (j, scalar_priority) in scalar_priorities.iter().enumerate() {
|
||||
if scalar_priority == &i {
|
||||
if j == 0 {
|
||||
if let Some(values) = &scalars_0 {
|
||||
handles_scalars.push(client.create(bytemuck::cast_slice(values)));
|
||||
}
|
||||
} else if j == 1 {
|
||||
if let Some(values) = &scalars_1 {
|
||||
handles_scalars.push(client.create(bytemuck::cast_slice(values)));
|
||||
}
|
||||
} else if j == 2 {
|
||||
if let Some(values) = &scalars_2 {
|
||||
handles_scalars.push(client.create(bytemuck::cast_slice(values)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handles_scalars
|
||||
}
|
||||
|
||||
pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
|
||||
let mut num_elems = 1;
|
||||
for i in shape.iter() {
|
||||
|
|
|
@ -326,6 +326,26 @@ impl TraceBuilder {
|
|||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::BitwiseAnd(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::BitwiseXor(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::ShiftLeft(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::ShiftRight(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
},
|
||||
Operation::Procedure(proc) => {
|
||||
match proc {
|
||||
|
|
|
@ -181,20 +181,6 @@ pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> Wor
|
|||
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
|
||||
}
|
||||
|
||||
pub(crate) fn prng_workgroup(
|
||||
num_elems: usize,
|
||||
workgroup_size: usize,
|
||||
n_values_per_thread: usize,
|
||||
) -> WorkGroup {
|
||||
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
|
||||
let num_elem_per_invocation = workgroup_size * workgroup_size;
|
||||
let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32);
|
||||
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
|
||||
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
|
||||
|
||||
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -1,11 +1,123 @@
|
|||
use crate::{element::JitElement, kernel_wgsl, Runtime, SEED};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
compute::WorkGroup,
|
||||
gpu::{gpu, Elem, Scope, Variable},
|
||||
kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT},
|
||||
tensor::JitTensor,
|
||||
Compiler, JitElement, Runtime, SEED,
|
||||
};
|
||||
use burn_common::rand::get_seeded_rng;
|
||||
use burn_compute::{client::ComputeClient, server::Handle};
|
||||
use burn_tensor::Shape;
|
||||
use rand::Rng;
|
||||
|
||||
kernel_wgsl!(Prng, "../../template/prng/prng.wgsl");
|
||||
pub(crate) const N_VALUES_PER_THREAD: usize = 128;
|
||||
|
||||
pub(crate) fn get_seeds() -> Vec<u32> {
|
||||
/// Pseudo-random generator
|
||||
pub(crate) fn random<P: Prng<E>, R: Runtime, E: JitElement, const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &R::Device,
|
||||
prng: P,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let client = R::client(device);
|
||||
let kernel: PrngEagerKernel<P, R, E> = PrngEagerKernel::new();
|
||||
let num_elems = shape.num_elements();
|
||||
let buffer = client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = JitTensor::new(client.clone(), device.clone(), shape.clone(), buffer);
|
||||
let seeds = get_seeds();
|
||||
|
||||
Execution::start(kernel, client)
|
||||
.outputs(&[EagerHandle::<R>::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)])
|
||||
.with_scalars(&seeds)
|
||||
.with_scalars(&prng.args())
|
||||
.execute(WorkgroupLaunch::Custom(prng_workgroup(
|
||||
num_elems,
|
||||
WORKGROUP_DEFAULT,
|
||||
N_VALUES_PER_THREAD,
|
||||
)));
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn prng_workgroup(
|
||||
num_elems: usize,
|
||||
workgroup_size: usize,
|
||||
n_values_per_thread: usize,
|
||||
) -> WorkGroup {
|
||||
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
|
||||
let num_elem_per_invocation = workgroup_size * workgroup_size;
|
||||
let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32);
|
||||
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
|
||||
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
|
||||
|
||||
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
|
||||
}
|
||||
|
||||
impl<P: Prng<E>, R: Runtime, E: JitElement> DynamicKernelSource for PrngEagerKernel<P, R, E> {
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
let seed0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let seed1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let seed2 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let seed3 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let seeds = [seed0, seed1, seed2, seed3];
|
||||
|
||||
let mut args = Vec::<Variable>::new();
|
||||
for i in 0..P::args_length() {
|
||||
args.push(Variable::GlobalScalar(i as u16, item.elem()));
|
||||
}
|
||||
|
||||
PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
let args = InputInfo::Scalar {
|
||||
elem: E::gpu_elem(),
|
||||
size: P::args_length(),
|
||||
};
|
||||
let seeds = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: 4,
|
||||
};
|
||||
let out = OutputInfo::Array { item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![args, seeds],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", core::any::TypeId::of::<Self>(),)
|
||||
}
|
||||
}
|
||||
|
||||
/// Zero-sized kernel marker tying together a distribution `P`, a runtime `R`
/// and an element type `E`.
///
/// It holds no data; the type parameters alone select the shader generated by
/// its `DynamicKernelSource` implementation.
#[derive(new)]
|
||||
pub(crate) struct PrngEagerKernel<P: Prng<E>, R: Runtime, E: JitElement> {
|
||||
_prng: PhantomData<P>,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
pub(crate) fn get_seeds() -> [u32; 4] {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
let mut rng = match seed.as_ref() {
|
||||
Some(rng_seeded) => rng_seeded.clone(),
|
||||
|
@ -16,23 +128,168 @@ pub(crate) fn get_seeds() -> Vec<u32> {
|
|||
seeds.push(rng.gen());
|
||||
}
|
||||
*seed = Some(rng);
|
||||
seeds
|
||||
|
||||
seeds.try_into().unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn make_info_buffer<R: Runtime>(
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
/// A random distribution that can be expanded into a PRNG shader.
///
/// Implementors provide the host-side scalar arguments of the distribution and
/// the shader code that turns the four generator states into output values.
pub(crate) trait Prng<E>: Send + Sync + 'static {
|
||||
/// Scalar arguments of the distribution, passed to the kernel as one scalar group.
fn args(self) -> Vec<E>;
|
||||
|
||||
/// Number of scalar arguments the shader reads; used to declare the
/// kernel's scalar inputs.
fn args_length() -> usize;
|
||||
|
||||
/// Register in `scope` the operations that generate and write the values
/// of one thread.
///
/// `args` are the shader variables for the distribution arguments,
/// `state_0` to `state_3` are the per-thread generator states, and
/// `output` is the array written to. `write_index_base`, `n_invocations`
/// and `n_values_per_thread` are supplied by `PrngShader::expand`.
#[allow(clippy::too_many_arguments)]
|
||||
fn inner_loop(
|
||||
scope: &mut Scope,
|
||||
args: Vec<Variable>,
|
||||
write_index_base: Variable,
|
||||
n_invocations: Variable,
|
||||
n_values_per_thread: usize,
|
||||
state_0: Variable,
|
||||
state_1: Variable,
|
||||
state_2: Variable,
|
||||
state_3: Variable,
|
||||
output: Variable,
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct PrngShader<P: Prng<E>, E: JitElement> {
|
||||
output: Variable,
|
||||
n_values_per_thread: usize,
|
||||
) -> Handle<R::Server> {
|
||||
let mut info = get_seeds();
|
||||
info.insert(0, n_values_per_thread as u32);
|
||||
client.create(bytemuck::cast_slice(&info))
|
||||
seeds: [Variable; 4],
|
||||
args: Vec<Variable>,
|
||||
_prng: PhantomData<P>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
pub(crate) fn make_args_buffer<R: Runtime, E: JitElement>(
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
args: &[E],
|
||||
) -> Handle<R::Server> {
|
||||
client.create(E::as_bytes(args))
|
||||
impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let output = self.output;
|
||||
let [seed_0, seed_1, seed_2, seed_3] = self.seeds;
|
||||
let n_values_per_thread: Variable = self.n_values_per_thread.into();
|
||||
let args = self.args;
|
||||
|
||||
let workgroup_size_x = Variable::WorkgroupSizeX;
|
||||
let workgroup_size_y = Variable::WorkgroupSizeY;
|
||||
let workgroup_id_x = Variable::WorkgroupIdX;
|
||||
let workgroup_id_y = Variable::WorkgroupIdY;
|
||||
let num_workgroups_y = Variable::NumWorkgroupsY;
|
||||
let local_index = Variable::LocalInvocationIndex;
|
||||
|
||||
let n_invocations = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, n_invocations = workgroup_size_x);
|
||||
gpu!(scope, n_invocations *= workgroup_size_y);
|
||||
|
||||
let workgroup_offset = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y);
|
||||
gpu!(scope, workgroup_offset += workgroup_id_y);
|
||||
gpu!(scope, workgroup_offset *= n_invocations);
|
||||
|
||||
let write_index_base = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, write_index_base = workgroup_offset);
|
||||
gpu!(scope, write_index_base *= n_values_per_thread);
|
||||
gpu!(scope, write_index_base += local_index);
|
||||
|
||||
// Set state with unique seeds
|
||||
let thread_seed = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, thread_seed = cast(1000000007));
|
||||
let thread_seed_index = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, thread_seed_index = workgroup_offset + local_index);
|
||||
gpu!(scope, thread_seed *= thread_seed_index);
|
||||
|
||||
let state_0 = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, state_0 = thread_seed);
|
||||
gpu!(scope, state_0 += seed_0);
|
||||
|
||||
let state_1 = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, state_1 = thread_seed);
|
||||
gpu!(scope, state_1 += seed_1);
|
||||
|
||||
let state_2 = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, state_2 = thread_seed);
|
||||
gpu!(scope, state_2 += seed_2);
|
||||
|
||||
let state_3 = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, state_3 = thread_seed);
|
||||
gpu!(scope, state_3 += seed_3);
|
||||
|
||||
// Creation of n_values_per_thread values, specific to the distribution
|
||||
P::inner_loop(
|
||||
scope,
|
||||
args,
|
||||
write_index_base,
|
||||
n_invocations,
|
||||
self.n_values_per_thread,
|
||||
state_0,
|
||||
state_1,
|
||||
state_2,
|
||||
state_3,
|
||||
output,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn taus_step_0(scope: &mut Scope, z: Variable) {
|
||||
taus_step(
|
||||
scope,
|
||||
z,
|
||||
13u32.into(),
|
||||
19u32.into(),
|
||||
12u32.into(),
|
||||
4294967294u32.into(),
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn taus_step_1(scope: &mut Scope, z: Variable) {
|
||||
taus_step(
|
||||
scope,
|
||||
z,
|
||||
2u32.into(),
|
||||
25u32.into(),
|
||||
4u32.into(),
|
||||
4294967288u32.into(),
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn taus_step_2(scope: &mut Scope, z: Variable) {
|
||||
taus_step(
|
||||
scope,
|
||||
z,
|
||||
3u32.into(),
|
||||
11u32.into(),
|
||||
17u32.into(),
|
||||
4294967280u32.into(),
|
||||
);
|
||||
}
|
||||
|
||||
fn taus_step(
|
||||
scope: &mut Scope,
|
||||
z: Variable,
|
||||
s1: Variable,
|
||||
s2: Variable,
|
||||
s3: Variable,
|
||||
m: Variable,
|
||||
) {
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, b = z << s1);
|
||||
gpu!(scope, b = b ^ z);
|
||||
gpu!(scope, b = b >> s2);
|
||||
gpu!(scope, z = z & m);
|
||||
gpu!(scope, z = z << s3);
|
||||
gpu!(scope, z = z ^ b);
|
||||
}
|
||||
|
||||
pub(crate) fn lcg_step(scope: &mut Scope, z: Variable) {
|
||||
let a: Variable = 1664525u32.into();
|
||||
let b: Variable = 1013904223u32.into();
|
||||
gpu!(scope, z *= a);
|
||||
gpu!(scope, z += b);
|
||||
}
|
||||
|
||||
pub(crate) fn cast_uint_to_float(scope: &mut Scope, int_random: Variable, float_random: Variable) {
|
||||
let tmp: Variable = 2.328_306_4e-10.into();
|
||||
gpu!(scope, float_random = cast(int_random));
|
||||
gpu!(scope, float_random *= tmp);
|
||||
}
|
||||
|
||||
#[cfg(feature = "export_tests")]
|
||||
|
@ -72,21 +329,4 @@ pub mod tests_utils {
|
|||
output[current_runs].n_runs += 1;
|
||||
output
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_bins() {
|
||||
let numbers = vec![0., 1., 1.5, 2., 2.5, 3., 2.5, 1.5, 3.5];
|
||||
let number_of_bins = 4;
|
||||
let low = 0.;
|
||||
let high = 4.;
|
||||
let stats = calculate_bin_stats(numbers, number_of_bins, low, high);
|
||||
assert_eq!(stats[0].count, 1);
|
||||
assert_eq!(stats[0].n_runs, 1);
|
||||
assert_eq!(stats[1].count, 3);
|
||||
assert_eq!(stats[1].n_runs, 2);
|
||||
assert_eq!(stats[2].count, 3);
|
||||
assert_eq!(stats[2].n_runs, 2);
|
||||
assert_eq!(stats[3].count, 2);
|
||||
assert_eq!(stats[3].n_runs, 2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,53 +1,73 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::JitElement,
|
||||
kernel::{
|
||||
prng::base::{make_args_buffer, make_info_buffer},
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
use super::base::Prng;
|
||||
use crate::{
|
||||
gpu::{gpu, Elem, Scope, Variable},
|
||||
kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
|
||||
tensor::JitTensor,
|
||||
JitElement, Runtime,
|
||||
};
|
||||
|
||||
struct BernoulliPrng;
|
||||
use super::{random, Prng};
|
||||
|
||||
impl StaticKernelSource for BernoulliPrng {
|
||||
fn source() -> SourceTemplate {
|
||||
Prng::source()
|
||||
.register("num_args", "1")
|
||||
.register(
|
||||
"prng_loop",
|
||||
include_str!("../../template/prng/bernoulli_inner_loop.wgsl"),
|
||||
)
|
||||
.add_template("fn cast_elem(e: bool) -> {{ elem }} {return {{elem}}(e);}")
|
||||
/// Bernoulli distribution, usable as a [`Prng`].
pub(crate) struct Bernoulli<E> {
|
||||
/// Threshold the generated float is compared against: the shader writes
/// `float_random < probability` to the output.
probability: E,
|
||||
}
|
||||
|
||||
impl<E: JitElement> Prng<E> for Bernoulli<E> {
|
||||
fn args(self) -> Vec<E> {
|
||||
vec![self.probability]
|
||||
}
|
||||
|
||||
fn inner_loop(
|
||||
scope: &mut Scope,
|
||||
args: Vec<Variable>,
|
||||
write_index_base: Variable,
|
||||
n_invocations: Variable,
|
||||
n_values_per_thread: usize,
|
||||
state_0: Variable,
|
||||
state_1: Variable,
|
||||
state_2: Variable,
|
||||
state_3: Variable,
|
||||
output: Variable,
|
||||
) {
|
||||
let prob = args[0];
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, n_values_per_thread).for_each(|i, scope| {
|
||||
taus_step_0(scope, state_0);
|
||||
taus_step_1(scope, state_1);
|
||||
taus_step_2(scope, state_2);
|
||||
lcg_step(scope, state_3);
|
||||
|
||||
let int_random = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, int_random = state_0 ^ state_1);
|
||||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
let float_random = scope.create_local(Elem::Float);
|
||||
cast_uint_to_float(scope, int_random, float_random);
|
||||
|
||||
let bernoulli = scope.create_local(Elem::Bool);
|
||||
gpu!(scope, bernoulli = float_random < prob);
|
||||
|
||||
let write_index = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, write_index = i * n_invocations);
|
||||
gpu!(scope, write_index += write_index_base);
|
||||
gpu!(scope, output[write_index] = bernoulli);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
fn args_length() -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
/// Pseudo-random generator for bernoulli
|
||||
/// Pseudo-random generator with bernoulli distribution
|
||||
pub fn random_bernoulli<R: Runtime, E: JitElement, const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &R::Device,
|
||||
prob: E,
|
||||
probability: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
const N_VALUES_PER_THREAD: usize = 128;
|
||||
|
||||
let client = R::client(device);
|
||||
let output = empty_device(client.clone(), device.clone(), shape.clone());
|
||||
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
|
||||
let args_handle = make_args_buffer::<R, E>(client.clone(), &[prob]);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<BernoulliPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
&[&output.handle, &info_handle, &args_handle],
|
||||
);
|
||||
|
||||
output
|
||||
random(shape, device, Bernoulli { probability })
|
||||
}
|
||||
|
|
|
@ -1,57 +1,122 @@
|
|||
use std::f32::consts::PI;
|
||||
|
||||
use burn_tensor::Shape;
|
||||
|
||||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::JitElement,
|
||||
kernel::{
|
||||
prng::base::{make_args_buffer, make_info_buffer},
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
ops::numeric::empty_device,
|
||||
gpu::{gpu, Elem, Scope, Variable},
|
||||
kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
JitElement, Runtime,
|
||||
};
|
||||
|
||||
use super::base::Prng;
|
||||
use super::{random, Prng};
|
||||
|
||||
struct NormalPrng;
|
||||
/// Parameters of the normal distribution used by the `Prng` implementation
/// for this type.
pub(crate) struct Normal<E> {
|
||||
// Mean of the distribution; added to every generated sample.
mean: E,
|
||||
// Standard deviation; scales the Box-Muller coefficient.
std: E,
|
||||
}
|
||||
|
||||
impl StaticKernelSource for NormalPrng {
|
||||
fn source() -> SourceTemplate {
|
||||
Prng::source()
|
||||
.register("num_args", "2")
|
||||
.register(
|
||||
"prng_loop",
|
||||
include_str!("../../template/prng/normal_inner_loop.wgsl"),
|
||||
)
|
||||
.add_template(include_str!(
|
||||
"../../template/prng/box_muller_transform.wgsl"
|
||||
))
|
||||
impl<E: JitElement> Prng<E> for Normal<E> {
|
||||
/// Consumes the distribution and returns its parameters as `[mean, std]`.
///
/// `inner_loop` reads index 0 as the mean and index 1 as the standard
/// deviation, so the order here presumably has to stay in sync with it.
fn args(self) -> Vec<E> {
|
||||
vec![self.mean, self.std]
|
||||
}
|
||||
|
||||
/// Registers in `scope` the per-thread loop that produces normally
/// distributed samples with the Box-Muller transform.
///
/// Every iteration draws two random floats and turns them into two output
/// values, so the loop runs `n_values_per_thread / 2` times (integer
/// division) and writes `2 * (n_values_per_thread / 2)` values per thread.
///
/// * `args` - kernel arguments: `args[0]` is the mean, `args[1]` the standard deviation.
/// * `write_index_base` - offset added to every write index.
/// * `n_invocations` - stride between two consecutive values of one thread.
/// * `n_values_per_thread` - number of values requested per thread.
/// * `state_0..state_3` - generator state passed to the step helpers.
/// * `output` - variable indexed by the computed write indices; its item
///   type is used for the intermediate locals.
fn inner_loop(
|
||||
scope: &mut Scope,
|
||||
args: Vec<Variable>,
|
||||
write_index_base: Variable,
|
||||
n_invocations: Variable,
|
||||
n_values_per_thread: usize,
|
||||
state_0: Variable,
|
||||
state_1: Variable,
|
||||
state_2: Variable,
|
||||
state_3: Variable,
|
||||
output: Variable,
|
||||
) {
|
||||
// Intermediate values use the same item type as the output.
let item = output.item();
|
||||
let mean = args[0];
|
||||
let std = args[1];
|
||||
// Constants used inside the loop: 2*pi, -2 and the integer 2.
let two_pi = scope.create_with_value(2. * PI, Elem::Float);
|
||||
let t_neg = scope.create_with_value(-2.0, item);
|
||||
let two: Variable = 2u32.into();
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, n_values_per_thread / 2).for_each(|i, scope| {
|
||||
// One integer local is reused for both draws of this iteration.
let int_random = scope.create_local(Elem::UInt);
|
||||
|
||||
// First random uniform integer
|
||||
// NOTE(review): the helpers' return values are unused, so they
// presumably update the state variables in place - confirm.
taus_step_0(scope, state_0);
|
||||
taus_step_1(scope, state_1);
|
||||
taus_step_2(scope, state_2);
|
||||
lcg_step(scope, state_3);
|
||||
|
||||
// XOR the four state words into one random integer.
gpu!(scope, int_random = state_0 ^ state_1);
|
||||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
// NOTE(review): presumably a value in [0, 1) - confirm in cast_uint_to_float.
let unit_0 = scope.create_local(Elem::Float);
|
||||
cast_uint_to_float(scope, int_random, unit_0);
|
||||
|
||||
// Second random uniform integer
|
||||
taus_step_0(scope, state_0);
|
||||
taus_step_1(scope, state_1);
|
||||
taus_step_2(scope, state_2);
|
||||
lcg_step(scope, state_3);
|
||||
|
||||
gpu!(scope, int_random = state_0 ^ state_1);
|
||||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
let unit_1 = scope.create_local(Elem::Float);
|
||||
cast_uint_to_float(scope, int_random, unit_1);
|
||||
|
||||
// Box-Muller transform
|
||||
// coeff = std * sqrt(-2 * log(unit_0))
let coeff = scope.create_local(item);
|
||||
gpu!(scope, coeff = log(unit_0));
|
||||
gpu!(scope, coeff *= t_neg);
|
||||
gpu!(scope, coeff = sqrt(coeff));
|
||||
gpu!(scope, coeff *= std);
|
||||
|
||||
// trigo_arg = 2 * pi * unit_1
let trigo_arg = scope.create_local(item);
|
||||
gpu!(scope, trigo_arg = two_pi * unit_1);
|
||||
|
||||
// normal_0 = coeff * cos(trigo_arg) + mean
// normal_1 = coeff * sin(trigo_arg) + mean
let normal_0 = scope.create_local(item);
|
||||
let normal_1 = scope.create_local(item);
|
||||
gpu!(scope, normal_0 = cos(trigo_arg));
|
||||
gpu!(scope, normal_0 *= coeff);
|
||||
gpu!(scope, normal_0 += mean);
|
||||
gpu!(scope, normal_1 = sin(trigo_arg));
|
||||
gpu!(scope, normal_1 *= coeff);
|
||||
gpu!(scope, normal_1 += mean);
|
||||
|
||||
// Write to output
|
||||
// write_index_0 = write_index_base + (2 * i) * n_invocations
// write_index_1 = write_index_0 + n_invocations
let write_index_0 = scope.create_local(Elem::UInt);
|
||||
let write_index_1 = scope.create_local(Elem::UInt);
|
||||
let iteration_offset = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, write_index_0 = write_index_base);
|
||||
gpu!(scope, iteration_offset = two * i);
|
||||
gpu!(scope, iteration_offset *= n_invocations);
|
||||
gpu!(scope, write_index_0 += iteration_offset);
|
||||
gpu!(scope, write_index_1 = write_index_0 + n_invocations);
|
||||
|
||||
gpu!(scope, output[write_index_0] = normal_0);
|
||||
gpu!(scope, output[write_index_1] = normal_1);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
/// Number of arguments this distribution uses: two (mean and standard
/// deviation), matching the length of the vector returned by `args`.
fn args_length() -> usize {
|
||||
2
|
||||
}
|
||||
}
|
||||
|
||||
/// Pseudo-random generator for normal distribution
|
||||
/// Pseudo-random generator with normal distribution
|
||||
pub fn random_normal<R: Runtime, E: JitElement, const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &R::Device,
|
||||
mean: E,
|
||||
std: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
const N_VALUES_PER_THREAD: usize = 128; // must be even
|
||||
|
||||
let client = R::client(device);
|
||||
let output = empty_device(client.clone(), device.clone(), shape.clone());
|
||||
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
|
||||
let args_handle = make_args_buffer::<R, E>(client.clone(), &[mean, std]);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<NormalPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
&[&output.handle, &info_handle, &args_handle],
|
||||
);
|
||||
|
||||
output
|
||||
random(shape, device, Normal { mean, std })
|
||||
}
|
||||
|
|
|
@ -1,130 +1,102 @@
|
|||
use burn_compute::client::ComputeClient;
|
||||
use burn_tensor::Shape;
|
||||
|
||||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::JitElement,
|
||||
kernel::{
|
||||
prng::base::{make_args_buffer, make_info_buffer},
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
ops::numeric::empty_device,
|
||||
gpu::{gpu, Elem, Scope, Variable},
|
||||
kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
|
||||
tensor::JitTensor,
|
||||
IntElement, Runtime,
|
||||
JitElement, Runtime,
|
||||
};
|
||||
|
||||
use super::base::Prng;
|
||||
use super::{random, Prng};
|
||||
|
||||
struct UniformPrng;
|
||||
struct UniformIntPrng;
|
||||
/// Parameters of the uniform distribution used by the `Prng` implementation
/// for this type.
pub(crate) struct Uniform<E> {
|
||||
// Lower bound; added to every scaled random value.
lower_bound: E,
|
||||
// Upper bound; `upper_bound - lower_bound` is the scale applied to the random value.
upper_bound: E,
|
||||
}
|
||||
|
||||
impl StaticKernelSource for UniformPrng {
|
||||
fn source() -> SourceTemplate {
|
||||
Prng::source().register("num_args", "2").register(
|
||||
"prng_loop",
|
||||
include_str!("../../template/prng/uniform_inner_loop.wgsl"),
|
||||
)
|
||||
impl<E: JitElement> Prng<E> for Uniform<E> {
|
||||
/// Consumes the distribution and returns its parameters as
/// `[lower_bound, upper_bound]`.
///
/// `inner_loop` reads index 0 as the lower bound and index 1 as the upper
/// bound, so the order here presumably has to stay in sync with it.
fn args(self) -> Vec<E> {
|
||||
vec![self.lower_bound, self.upper_bound]
|
||||
}
|
||||
|
||||
/// Registers in `scope` the per-thread loop that produces uniformly
/// distributed samples between the two bounds.
///
/// For each `i` in `0..n_values_per_thread` the loop advances the four state
/// variables, XORs them into one unsigned integer, converts it to a float and
/// stores `float * (upper_bound - lower_bound) + lower_bound` into
/// `output[i * n_invocations + write_index_base]`.
///
/// * `args` - kernel arguments: `args[0]` is the lower bound, `args[1]` the upper bound.
/// * `write_index_base` - offset added to every write index.
/// * `n_invocations` - stride between two consecutive values of one thread.
/// * `n_values_per_thread` - number of loop iterations (one value each).
/// * `state_0..state_3` - generator state passed to the step helpers.
/// * `output` - variable indexed by the computed write index; its item type
///   is used for the intermediate locals.
fn inner_loop(
|
||||
scope: &mut Scope,
|
||||
args: Vec<Variable>,
|
||||
write_index_base: Variable,
|
||||
n_invocations: Variable,
|
||||
n_values_per_thread: usize,
|
||||
state_0: Variable,
|
||||
state_1: Variable,
|
||||
state_2: Variable,
|
||||
state_3: Variable,
|
||||
output: Variable,
|
||||
) {
|
||||
// Intermediate values use the same item type as the output.
let item = output.item();
|
||||
let lower_bound = args[0];
|
||||
let upper_bound = args[1];
|
||||
// The width of the interval is computed once, before the loop.
let scale = scope.create_local(item);
|
||||
gpu!(scope, scale = upper_bound - lower_bound);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, n_values_per_thread).for_each(|i, scope| {
|
||||
// Advance each state component with its own step helper.
// NOTE(review): the helpers' return values are unused, so they
// presumably update the state variables in place - confirm.
taus_step_0(scope, state_0);
|
||||
taus_step_1(scope, state_1);
|
||||
taus_step_2(scope, state_2);
|
||||
lcg_step(scope, state_3);
|
||||
|
||||
// Combine the four state words with XOR into one random integer.
let int_random = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, int_random = state_0 ^ state_1);
|
||||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
// Convert the integer to a float.
// NOTE(review): presumably a value in [0, 1) - confirm in cast_uint_to_float.
let float_random = scope.create_local(Elem::Float);
|
||||
cast_uint_to_float(scope, int_random, float_random);
|
||||
|
||||
// uniform = float_random * scale + lower_bound
let uniform = scope.create_local(item);
|
||||
gpu!(scope, uniform = float_random * scale);
|
||||
gpu!(scope, uniform += lower_bound);
|
||||
|
||||
// write_index = i * n_invocations + write_index_base
let write_index = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, write_index = i * n_invocations);
|
||||
gpu!(scope, write_index += write_index_base);
|
||||
gpu!(scope, output[write_index] = uniform);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
/// Number of arguments this distribution uses: two (lower and upper bound),
/// matching the length of the vector returned by `args`.
fn args_length() -> usize {
|
||||
2
|
||||
}
|
||||
}
|
||||
|
||||
impl StaticKernelSource for UniformIntPrng {
|
||||
fn source() -> SourceTemplate {
|
||||
Prng::source().register("num_args", "2").register(
|
||||
"prng_loop",
|
||||
include_str!("../../template/prng/uniform_int_inner_loop.wgsl"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Pseudo-random generator for the uniform distribution.
|
||||
/// Pseudo-random generator with uniform distribution
|
||||
pub fn random_uniform<R: Runtime, E: JitElement, const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &R::Device,
|
||||
low: E,
|
||||
high: E,
|
||||
lower_bound: E,
|
||||
upper_bound: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let client = R::client(device);
|
||||
uniform_kernel(client, device, &shape, low, high)
|
||||
random(
|
||||
shape,
|
||||
device,
|
||||
Uniform {
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Pseudo-random generator for uniform distribution, based on
|
||||
/// another tensor.
|
||||
pub fn random_like_uniform<R: Runtime, E: JitElement, const D: usize>(
|
||||
tensor: &JitTensor<R, E, D>,
|
||||
low: E,
|
||||
high: E,
|
||||
lower_bound: E,
|
||||
upper_bound: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
uniform_kernel(
|
||||
tensor.client.clone(),
|
||||
random_uniform(
|
||||
tensor.shape.clone(),
|
||||
&tensor.device,
|
||||
&tensor.shape,
|
||||
low,
|
||||
high,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
)
|
||||
}
|
||||
|
||||
/// Pseudo-random generator for uniform int distribution, based on
|
||||
/// another tensor's client, device and shape.
|
||||
pub fn random_like_uniform_int<R: Runtime, E: IntElement, const D: usize>(
|
||||
tensor: &JitTensor<R, E, D>,
|
||||
low: E,
|
||||
high: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
uniform_int_kernel(
|
||||
tensor.client.clone(),
|
||||
&tensor.device,
|
||||
&tensor.shape,
|
||||
low,
|
||||
high,
|
||||
)
|
||||
}
|
||||
|
||||
fn uniform_kernel<R: Runtime, E: JitElement, const D: usize>(
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
device: &R::Device,
|
||||
shape: &Shape<D>,
|
||||
low: E,
|
||||
high: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
const N_VALUES_PER_THREAD: usize = 128;
|
||||
|
||||
let output = empty_device(client.clone(), device.clone(), shape.clone());
|
||||
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
|
||||
let args_handle = make_args_buffer::<R, E>(client.clone(), &[low, high]);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<UniformPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
&[&output.handle, &info_handle, &args_handle],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn uniform_int_kernel<R: Runtime, E: IntElement, const D: usize>(
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
device: &R::Device,
|
||||
shape: &Shape<D>,
|
||||
low: E,
|
||||
high: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
const N_VALUES_PER_THREAD: usize = 128;
|
||||
|
||||
let output = empty_device(client.clone(), device.clone(), shape.clone());
|
||||
let info_handle = make_info_buffer::<R>(client.clone(), N_VALUES_PER_THREAD);
|
||||
let args_handle = make_args_buffer::<R, E>(client.clone(), &[low, high]);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<UniformIntPrng, u32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
&[&output.handle, &info_handle, &args_handle],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
let prob = args[0];
|
||||
for (var i = 0u; i < n_values_per_thread; i++) {
|
||||
let write_index = write_index_base + i * n_threads_per_workgroup;
|
||||
|
||||
state[0u] = taus_step_0(state[0u]);
|
||||
state[1u] = taus_step_1(state[1u]);
|
||||
state[2u] = taus_step_2(state[2u]);
|
||||
state[3u] = lcg_step(state[3u]);
|
||||
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
|
||||
let float = cast_u32_to_float(random_u32);
|
||||
|
||||
output[write_index] = cast_elem(float < prob);
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
fn box_muller_transform(unit_1: {{ elem }}, unit_2: {{ elem }}) -> array<{{ elem }}, 2> {
|
||||
let mean = args[0];
|
||||
let stdev = args[1];
|
||||
let coeff = stdev * sqrt(-2.0 * log(unit_1));
|
||||
|
||||
let pi = 3.141592653589793238;
|
||||
let trigo_arg = 2.0 * pi * unit_2;
|
||||
let cos_ = cos(trigo_arg);
|
||||
let sin_ = sin(trigo_arg);
|
||||
|
||||
return array(coeff * cos_ + mean, coeff * sin_ + mean);
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
for (var i = 0u; i < n_values_per_thread / 2u; i++) {
|
||||
let write_index_0 = write_index_base + (2u * i) * n_threads_per_workgroup;
|
||||
let write_index_1 = write_index_0 + n_threads_per_workgroup;
|
||||
|
||||
state[0u] = taus_step_0(state[0u]);
|
||||
state[1u] = taus_step_1(state[1u]);
|
||||
state[2u] = taus_step_2(state[2u]);
|
||||
state[3u] = lcg_step(state[3u]);
|
||||
let random_1_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
|
||||
let random_1 = cast_u32_to_float(random_1_u32);
|
||||
|
||||
state[0u] = taus_step_0(state[0u]);
|
||||
state[1u] = taus_step_1(state[1u]);
|
||||
state[2u] = taus_step_2(state[2u]);
|
||||
state[3u] = lcg_step(state[3u]);
|
||||
let random_2_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
|
||||
let random_2 = cast_u32_to_float(random_2_u32);
|
||||
|
||||
let transformed = box_muller_transform(random_1, random_2);
|
||||
|
||||
output[write_index_0] = transformed[0];
|
||||
output[write_index_1] = transformed[1];
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> info: array<u32, 5>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> args: array<{{ elem }}, {{ num_args }}>;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(local_invocation_index) local_id: u32,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
) {
|
||||
// Thread preparation
|
||||
let n_threads_per_workgroup = {{ workgroup_size }}u;
|
||||
let workgroup_offset = (workgroup_id.x * num_workgroups.y + workgroup_id.y) * n_threads_per_workgroup;
|
||||
let n_values_per_thread = info[0u];
|
||||
let write_index_base = workgroup_offset * n_values_per_thread + local_id;
|
||||
|
||||
// Set state with unique seeds
|
||||
let thread_seed = 1000000007u * (workgroup_offset + local_id);
|
||||
var state: array<u32, 4u>;
|
||||
for (var i = 0u; i < 4u; i++) {
|
||||
state[i] = info[i + 1u] + thread_seed;
|
||||
}
|
||||
|
||||
// Creation of n_values_per_thread values, specific to the distribution
|
||||
{{ prng_loop }}
|
||||
}
|
||||
|
||||
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
|
||||
let b = ((z << s1) ^ z) >> s2;
|
||||
return ((z & m) << s3) ^ b;
|
||||
}
|
||||
|
||||
fn taus_step_0(z: u32) -> u32 {
|
||||
return taus_step(z, 13u, 19u, 12u, 4294967294u);
|
||||
}
|
||||
|
||||
fn taus_step_1(z: u32) -> u32 {
|
||||
return taus_step(z, 2u, 25u, 4u, 4294967288u);
|
||||
}
|
||||
|
||||
fn taus_step_2(z: u32) -> u32 {
|
||||
return taus_step(z, 3u, 11u, 17u, 4294967280u);
|
||||
}
|
||||
|
||||
fn lcg_step(z: u32) -> u32 {
|
||||
return (1664525u * z + 1013904223u);
|
||||
}
|
||||
|
||||
fn cast_u32_to_float(number: u32) -> {{ elem }} {
|
||||
let tmp = 2.3283064365387e-10 * f32(number);
|
||||
return {{ elem }}(tmp);
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
let low = args[0];
|
||||
let high = args[1];
|
||||
let scale = high - low;
|
||||
let bias = low;
|
||||
for (var i = 0u; i < n_values_per_thread; i++) {
|
||||
let write_index = write_index_base + i * n_threads_per_workgroup;
|
||||
|
||||
state[0u] = taus_step_0(state[0u]);
|
||||
state[1u] = taus_step_1(state[1u]);
|
||||
state[2u] = taus_step_2(state[2u]);
|
||||
state[3u] = lcg_step(state[3u]);
|
||||
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
|
||||
let float = cast_u32_to_float(random_u32);
|
||||
|
||||
output[write_index] = float * scale + bias;
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
let low = u32(args[0]);
|
||||
let high = u32(args[1]);
|
||||
let range = high - low;
|
||||
|
||||
let safe_range = max(range, 1u); // Ensure range is not zero to avoid division by 0 in % op
|
||||
|
||||
for (var i = 0u; i < n_values_per_thread; i++) {
|
||||
let write_index = write_index_base + i * n_threads_per_workgroup;
|
||||
|
||||
state[0u] = taus_step_0(state[0u]);
|
||||
state[1u] = taus_step_1(state[1u]);
|
||||
state[2u] = taus_step_2(state[2u]);
|
||||
state[3u] = lcg_step(state[3u]);
|
||||
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
|
||||
|
||||
// Modulus operation to fit within the range
|
||||
let mod_result: u32 = u32(random_u32 % safe_range);
|
||||
|
||||
output[write_index] = u32(mod_result + low);
|
||||
}
|
|
@ -25,7 +25,7 @@ mod tests {
|
|||
#[serial]
|
||||
fn empirical_mean_close_to_expectation() {
|
||||
TestBackend::seed(0);
|
||||
let shape = [128, 128];
|
||||
let shape = [100, 100];
|
||||
let device = Default::default();
|
||||
let mean = 10.;
|
||||
let tensor =
|
||||
|
@ -44,6 +44,7 @@ mod tests {
|
|||
let s = 1.;
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 2>::random(shape.clone(), Distribution::Normal(mu, s), &device);
|
||||
|
||||
let stats = calculate_bin_stats(
|
||||
tensor.into_data().value,
|
||||
6,
|
||||
|
|
|
@ -19,7 +19,7 @@ mod reduction {
|
|||
));
|
||||
let val_ref = tensor_ref.sum_dim(1);
|
||||
|
||||
val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
|
||||
val_ref.into_data().assert_approx_eq(&val.into_data(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -507,6 +507,26 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::BitwiseAnd(op) => wgsl::Instruction::BitwiseAnd {
|
||||
lhs: self.compile_variable(op.lhs),
|
||||
rhs: self.compile_variable(op.rhs),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::BitwiseXor(op) => wgsl::Instruction::BitwiseXor {
|
||||
lhs: self.compile_variable(op.lhs),
|
||||
rhs: self.compile_variable(op.rhs),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::ShiftLeft(op) => wgsl::Instruction::ShiftLeft {
|
||||
lhs: self.compile_variable(op.lhs),
|
||||
rhs: self.compile_variable(op.rhs),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::ShiftRight(op) => wgsl::Instruction::ShiftRight {
|
||||
lhs: self.compile_variable(op.lhs),
|
||||
rhs: self.compile_variable(op.rhs),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -185,6 +185,26 @@ pub enum Instruction {
|
|||
Loop {
|
||||
instructions: Vec<Instruction>,
|
||||
},
|
||||
BitwiseAnd {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
BitwiseXor {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
ShiftLeft {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
ShiftRight {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for Instruction {
|
||||
|
@ -423,6 +443,18 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
|||
}
|
||||
f.write_str("}\n")
|
||||
}
|
||||
Instruction::BitwiseAnd { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} & {rhs};\n"))
|
||||
}
|
||||
Instruction::BitwiseXor { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} ^ {rhs};\n"))
|
||||
}
|
||||
Instruction::ShiftLeft { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} << {rhs};\n"))
|
||||
}
|
||||
Instruction::ShiftRight { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -145,6 +145,7 @@ where
|
|||
}
|
||||
|
||||
let source = kernel.source().complete();
|
||||
println!("{}", source);
|
||||
let pipeline = self.compile_source(&source);
|
||||
self.pipelines.insert(kernel_id.clone(), pipeline.clone());
|
||||
|
||||
|
|
Loading…
Reference in New Issue