Add documentation for custom `cubecl` kernels, update some outdated docs (#2404)

Genna Wingert 2024-10-25 19:22:23 +02:00 committed by GitHub
parent fe86c10e1c
commit d5e8e3185c
15 changed files with 873 additions and 54 deletions

Cargo.lock generated
View File

@ -1683,6 +1683,19 @@ dependencies = [
"serde",
]
[[package]]
name = "custom-cubecl-kernel"
version = "0.15.0"
dependencies = [
"burn",
"burn-jit",
"bytemuck",
"cubecl",
"derive-new",
"log",
"serde",
]
[[package]]
name = "custom-image-dataset"
version = "0.15.0"

View File

@ -28,6 +28,7 @@
- [Quantization (Beta)](./quantization.md)
- [Advanced](./advanced/README.md)
- [Backend Extension](./advanced/backend-extension/README.md)
- [Custom `cubecl` Kernel](./advanced/backend-extension/custom-cubecl-kernel.md)
- [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)
- [Custom Optimizer]()
- [WebAssembly]()

View File

@ -15,9 +15,9 @@ impression that Burn operates at a high level over the backend layer. However, m
explicit instead of being chosen via a compilation flag was a thoughtful design decision. This
explicitness does not imply that all backends must be identical; rather, it offers a great deal of
flexibility when composing backends. The autodifferentiation backend trait (see
[autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has been
extended to enable gradient computation with backpropagation. Furthermore, this design allows you to
create your own backend extension. To achieve this, you need to design your own backend trait
[autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has
been extended to enable gradient computation with backpropagation. Furthermore, this design allows
you to create your own backend extension. To achieve this, you need to design your own backend trait
specifying which functions should be supported.
```rust, ignore
@ -76,5 +76,7 @@ impl<E: TchElement> Backend for burn_autodiff::Autodiff<burn_tch::LibTorch<E>> {
}
```
The specificity of each implementation will be covered by the examples provided in this section.
Currently, we only have one example, but more are yet to come!
The specifics of each implementation will be covered by the examples provided in this section. The
`cubecl` compiler frontend is the recommended method of implementing custom kernels, since it
supports multiple backends, including `wgpu` and `CUDA`, and is the way first-party `burn` kernels
are written.

View File

@ -0,0 +1,376 @@
# Custom `cubecl` Kernel
In this section, you will learn how to create your own custom operation by writing your own kernel
with the cubecl compiler frontend. We will take the example of a common workflow in the deep
learning field, where we create a kernel to fuse multiple operations together. Note that `burn` does
this automatically, but a manual implementation might be more efficient in some cases. We will fuse
a matmul kernel followed by an addition and the ReLU activation function, which is commonly found in
various models. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-cubecl-kernel).
## Custom Backend Trait
First, we need to determine the type signature of our newly created operation by defining our custom
backend traits. As we will use the associated type `TensorPrimitive` of the `Backend` trait, which
encapsulates the underlying tensor implementation of the backend, we will use the `FloatTensor` type
alias from `burn::tensor::ops` to avoid the ugly disambiguation with associated types.
```rust, ignore
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
bias: FloatTensor<Self>,
) -> FloatTensor<Self>;
}
/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}
```
In our project, we can use these traits instead of the
`burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs
typically make use of the `Tensor` struct rather than dealing directly with primitive tensor types.
Therefore, we can encapsulate our newly defined backend traits with functions that expose new
operations while maintaining a consistent API.
```rust, ignore
/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
lhs: Tensor<B, 3>,
rhs: Tensor<B, 3>,
bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
let output = B::fused_matmul_add_relu(
lhs.into_primitive().tensor(),
rhs.into_primitive().tensor(),
bias.into_primitive().tensor(),
);
Tensor::from_primitive(TensorPrimitive::Float(output))
}
/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
lhs: Tensor<B, 3>,
rhs: Tensor<B, 3>,
bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
let x = lhs.matmul(rhs) + bias;
activation::relu(x)
}
```
Note that we also provide a reference implementation for testing purposes, which allows us to easily
validate our new implementation. While not mandatory, having a reference implementation can be
valuable, especially in projects where creating a reference implementation solely using basic tensor
operations is feasible.
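For instance, a quick check in the spirit of the example's `main.rs` can compare both implementations on random inputs. The sketch below assumes the functions defined above are in scope; the shapes and tolerance are illustrative.
```rust, ignore
use burn::tensor::{Distribution, Tensor};

/// Compare the fused kernel against the reference implementation on random data.
fn validate<B: Backend>(device: &B::Device) {
    let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device);
    let rhs = Tensor::random([32, 32, 32], Distribution::Default, device);
    let bias = Tensor::random([32, 32, 32], Distribution::Default, device);

    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone())
        .into_data()
        .convert::<f32>();
    let custom = matmul_add_relu_custom(lhs, rhs, bias)
        .into_data()
        .convert::<f32>();

    // Both outputs should agree up to floating-point precision.
    reference.assert_approx_eq(&custom, 3);
}
```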
## Forward Kernel
Now, let's proceed to write the fused kernel using the `cubecl` compiler frontend. To keep things
simple, we'll create a straightforward matmul kernel without employing any intricate techniques. We
won't delve into the details of the `cube` macro, but if you're interested in learning more, please see
the [`cubecl` Book](https://github.com/tracel-ai/cubecl/tree/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/cubecl-book).
The actual matmul, add and relu computations are found at the end, after an extensive prelude
that serves to correctly map each compute unit to the data it is responsible for, with support for
batches.
```rust, ignore
use cubecl::{cube, prelude::*};
#[cube(launch)]
pub fn fused_matmul_add_relu_kernel<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
bias: &Tensor<F>,
output: &mut Tensor<F>,
) {
let row = ABSOLUTE_POS_X;
let col = ABSOLUTE_POS_Y;
let batch = ABSOLUTE_POS_Z;
let n_rows = output.shape(output.rank() - 2);
let n_cols = output.shape(output.rank() - 1);
let dim_k = rhs.shape(rhs.rank() - 1);
if row >= n_rows || col >= n_cols {
return;
}
let offset_output = batch * n_rows * n_cols;
let mut offset_lhs = 0;
let mut offset_rhs = 0;
let batch_dims = output.rank() - 2;
for dim in 0..batch_dims {
offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);
offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);
}
let mut sum = F::new(0.0);
for k in 0..dim_k {
let lhs_index = row * dim_k + k;
let rhs_index = k * n_cols + col;
sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
}
let out_index = row * n_cols + col;
let index = offset_output + out_index;
output[index] = F::max(sum + bias[index], F::new(0.0));
}
```
Now, let's move on to the next step, which involves implementing the remaining code to launch the
kernel. We'll implement our custom backend trait for the generic JIT backend, which automatically
provides the implementation for `burn-cuda` and `burn-wgpu`, as well as for fusion.
```rust, ignore
/// Implement our custom backend trait for the generic `JitBackend`.
impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F, I> {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
bias: FloatTensor<Self>,
) -> FloatTensor<Self> {
// Define cube dim, hardcoded for simplicity.
let cube_dim = CubeDim { x: 16, y: 16, z: 1 };
lhs.assert_is_on_same_device(&rhs);
lhs.assert_is_on_same_device(&bias);
// For simplicity, make sure each tensor is contiguous.
let lhs = into_contiguous(lhs);
let rhs = into_contiguous(rhs);
let bias = into_contiguous(bias);
// Get the matmul relevant shapes.
let ndims = lhs.shape.num_dims();
let num_rows = lhs.shape.dims[ndims - 2];
let num_cols = rhs.shape.dims[ndims - 1];
// Compute shape of output, while tracking number of batches.
let mut num_batches = 1;
let mut shape_out = vec![0; ndims];
for i in 0..ndims - 2 {
shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]);
num_batches *= shape_out[i];
}
shape_out[ndims - 2] = num_rows;
shape_out[ndims - 1] = num_cols;
let shape_out = Shape::from(shape_out);
// Create a buffer for the output tensor.
let buffer = lhs
.client
.empty(shape_out.num_elements() * core::mem::size_of::<F>());
// Create the output tensor primitive.
let output =
JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
// Declare the launch grid with the number of cubes in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
let cube_count =
CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
// Lazily execute the kernel with the launch information and the given buffers. For
// simplicity, no vectorization is performed.
fused_matmul_add_relu_kernel::launch::<F, R>(
&lhs.client,
cube_count,
cube_dim,
lhs.as_tensor_arg(1),
rhs.as_tensor_arg(1),
bias.as_tensor_arg(1),
output.as_tensor_arg(1),
);
// Return the output tensor.
output
}
}
```
In the preceding code block, we demonstrated how to launch the kernel that modifies the correct
buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the
capability to execute any mutable operation on any buffer. While this isn't a problem in the
previous scenario where we only modify the newly created output buffer, it is wise to keep this in
mind.
## Backward
Now that the custom backend trait is implemented for the JIT backend, you can use it to invoke the
`matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage.
If your use case does not extend beyond inference, there is no need to implement any of the
following code.
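For reference, an inference-only entry point might look like the following sketch, adapted from the example's `main.rs`; the wgpu runtime and element types are just one possible choice.
```rust, ignore
use burn::backend::wgpu::WgpuRuntime;
use burn::tensor::{Distribution, Tensor};

// The JIT backend parameterized by the wgpu runtime; no autodiff decoration is needed for inference.
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime, f32, i32>;

fn main() {
    let device = Default::default();
    let lhs = Tensor::<MyBackend, 3>::random([1, 32, 32], Distribution::Default, &device);
    let rhs = Tensor::random([32, 32, 32], Distribution::Default, &device);
    let bias = Tensor::random([32, 32, 32], Distribution::Default, &device);

    // Calls directly into the fused kernel implemented on `JitBackend`.
    let output = matmul_add_relu_custom(lhs, rhs, bias);
    println!("{output}");
}
```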
For the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is
actually generic over the backend. Instead of crafting our own `cubecl` kernel for the backward
pass, we will use our fused kernel only for the forward pass, and compute the gradient using basic
operations.
```rust, ignore
// Implement our custom backend trait for any backend that also implements our custom backend trait.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
bias: FloatTensor<Self>,
) -> FloatTensor<Self> {
// Create our zero-sized type that will implement the Backward trait.
#[derive(Debug)]
struct FusedMatmulAddReluBackward;
// Implement the backward trait for the given backend B, the node gradient
// with three other gradients to calculate (lhs, rhs, and bias).
impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
// Our state that we must build during the forward pass to compute the backward pass.
//
// Note that we could improve the performance further by only keeping the state of
// tensors that are tracked, improving memory management, but for simplicity, we avoid
// that part.
type State = (NodeID, NodeID, FloatTensor<B>, Shape);
fn backward(
self,
ops: Ops<Self::State, 3>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
// Get the nodes of each variable.
let [node_lhs, node_rhs, node_bias] = ops.parents;
// Fetch the gradient for the current node.
let grad = grads.consume::<B>(&ops.node);
// Set our state.
let (lhs_state, rhs_state, output, shape_bias) = ops.state;
let lhs = checkpointer.retrieve_node_output(lhs_state);
let rhs = checkpointer.retrieve_node_output(rhs_state);
// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);
// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
let grad_output = B::relu_backward(output, grad);
// Compute the lhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_lhs = broadcast_shape::<B>(
B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
&shape_lhs,
);
// Compute the rhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_rhs = broadcast_shape::<B>(
B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
&shape_rhs,
);
// The add derivative is only 1, so we just need to support broadcasting to
// compute the bias gradient.
let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);
// Register the gradient for each variable based on whether they are marked as
// `tracked`.
if let Some(node) = node_bias {
grads.register::<B>(node.id, grad_bias);
}
if let Some(node) = node_lhs {
grads.register::<B>(node.id, grad_lhs);
}
if let Some(node) = node_rhs {
grads.register::<B>(node.id, grad_rhs);
}
}
}
// Prepare a stateful operation with each variable node and corresponding graph.
//
// Each node can be fetched with `ops.parents` in the same order as defined here.
match FusedMatmulAddReluBackward
.prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
// Marks the operation as compute bound, meaning it will save its
// state instead of recomputing itself during checkpointing
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
// When at least one node is tracked, we should register our backward step.
// The state consists of what will be needed for this operation's backward pass.
// Since we need the parents' outputs, we must checkpoint their ids to retrieve
// their node output at the beginning of the backward pass. We can also save
// auxiliary data such as the bias shape. If we also need this operation's output,
// we can either save it in the state or recompute it during the backward pass.
// Here we choose to save it in the state because it's a compute-bound operation.
let lhs_state = prep.checkpoint(&lhs);
let rhs_state = prep.checkpoint(&rhs);
let bias_shape = B::float_shape(&bias.primitive);
let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
rhs.primitive.clone(),
bias.primitive,
);
let state = (lhs_state, rhs_state, output.clone(), bias_shape);
prep.finish(state, output)
}
OpsKind::UnTracked(prep) => {
// When no node is tracked, we can just compute the original operation without
// keeping any state.
let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
prep.finish(output)
}
}
}
}
```
The previous code is heavily commented to make it clearer, but here is a summary of what it does:
We define `fused_matmul_add_relu` within `Autodiff<B>`, allowing any autodiff-decorated backend to
benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be
implemented. This is achieved using a comprehensive match statement block where computation is
delegated to the inner backend, while keeping track of a state. The state comprises any information
relevant to the backward pass, such as input and output tensors, along with the bias shape. When an
operation isn't tracked (meaning there won't be a backward pass for this specific operation in the
graph), storing a state becomes unnecessary, and we simply perform the forward computation.
The backward pass uses the gradient obtained from the preceding node in the computation graph. It
calculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the
derivative is one), and `matmul` (another `matmul` with transposed inputs). This results in
gradients for both input tensors and the bias, which are registered for consumption by subsequent
operation nodes.
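Written out, with output `Y = relu(A matmul B + C)` and incoming gradient `G`, the backward pass computes the following (standard matmul and relu derivatives, each reduced to its operand's shape to honor broadcasting; shown as a sketch, not code from this example):
```latex
G'                      = G \odot \mathbf{1}[Y > 0]   % relu_backward
\partial L / \partial A = G' B^{\top}                  % grad_lhs, reduced to the shape of A
\partial L / \partial B = A^{\top} G'                  % grad_rhs, reduced to the shape of B
\partial L / \partial C = G'                           % grad_bias, reduced to the shape of the bias
```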
The only remaining part is to implement our autodiff-decorated backend trait for our JIT Backend.
```rust, ignore
impl<R: JitRuntime, F: FloatElement, I: IntElement> AutodiffBackend
for Autodiff<JitBackend<R, F, I>>
{
}
```
## Conclusion
In this guide, we've implemented a fused kernel using the `cubecl` compiler frontend, enabling
execution on any GPU and any `cubecl` backend. By delving into the inner workings of both the JIT
backend and the autodiff backend, we've gained a deeper understanding of these systems.
While extending a backend may be harder than working with straightforward tensors, the benefits can
be worth it. This approach enables the crafting of custom models with greater control over
execution, which can potentially greatly enhance the performance of your models.
As we conclude this guide, we hope that you have gained insights into Burn's world of backend
extensions, and that it will help you to unleash the full potential of your projects.

View File

@ -2,9 +2,10 @@
In this section, you will learn how to create your own custom operation by writing your own kernel
with the WGPU backend. We will take the example of a common workflow in the deep learning field,
where we create a kernel to fuse multiple operations together. We will fuse a matmul kernel followed
by an addition and the ReLU activation function, which is commonly found in various models. All the
code can be found under the
where we create a kernel to fuse multiple operations together. Note that `burn` does this
automatically, but a manual implementation might be more efficient in some cases. We will fuse a
matmul kernel followed by an addition and the ReLU activation function, which is commonly found in
various models. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-wgpu-kernel).
## Custom Backend Trait
@ -15,9 +16,6 @@ encapsulates the underlying tensor implementation of the backend, we will use a
the ugly disambiguation with associated types.
```rust, ignore
/// We use a type alias for better readability.
pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
fn fused_matmul_add_relu(
@ -78,7 +76,7 @@ we'll create a straightforward matmul kernel without employing any intricate tec
won't delve into the details of the WGSL syntax, as it falls beyond the scope of this guide, we
still provide the implementation below for readers who are curious. The actual matmul, add and relu
computations are found at the end, after an extensive overhead whose use is to correctly map each
thread to the data it is responsible of, with support for batches.
compute unit to the data it is responsible for, with support for batches.
```wgsl, ignore
@group(0)
@ -451,7 +449,5 @@ While extending a backend may be harder than working with straightforward tensor
be worth it. This approach enables the crafting of custom models with greater control over
execution, which can potentially greatly enhance the performance of your models.
It is worth noting that while the manual fusion of operations can be valuable, our future plans
include the development of a backend extension that will automate this process. As we conclude this
guide, we hope that you have gained insights into Burn's world of backend extensions, and that it
will help you to unleash the full potential of your projects.
As we conclude this guide, we hope that you have gained insights into Burn's world of backend
extensions, and that it will help you to unleash the full potential of your projects.

View File

@ -139,14 +139,14 @@ seeing what completions are suggested will take you far. If you are having troub
to do it from the docs for that backend,
[try searching github for relevant function calls](https://docs.github.com/en/search-github/github-code-search/understanding-github-code-search-syntax).
## Adding the Op to fusion, JIT and wgpu backends
## Adding the Op to fusion, JIT and cubecl backends
Adding an operator to these backends can be fairly straightforward, though, due to what these
backends are for, it involves a bit more indirection. Fusion and jit, like autodiff, are not target
backends as much as backends that enable certain functionality for other backends, in this case
kernel fusion or just-in-time compilation (only available for `burn-wgpu` backend at the moment).
Adding the operator won't involve doing any calculation, you'll just be describing how the generated
code should look. Most of this can be copy/pasted/adjusted from other functions.
kernel fusion or just-in-time compilation. Adding the operator won't involve doing any calculation,
you'll just be describing how the generated code should look. Most of this can be
copy/pasted/adjusted from other functions.
Here's how powf was added to `burn-fusion`:
@ -157,41 +157,47 @@ Here's how powf was added to `burn-fusion`:
3. Added powf to the implementations of `NumericOperationDescription` enum under
[crates/burn-fusion/src/stream/context.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/stream/context.rs#L771)
The way wgpu handles tensor-scalar operations is by transforming both into a sequence of vectorized
scalar operations. Since powf already existed in burn-wgpu, it was pretty easy to reuse the existing
implementation for the situation where both sides of the operation were tensors. The `burn-wgpu`
crate is primarily concerned with how the operation is compiled and executed by the gpu. The actual
implementation is defined in `burn-jit`.
The way `cubecl` handles tensor-scalar operations is by transforming both into a sequence of
vectorized scalar operations. Since powf already existed in `cubecl`, it was pretty easy to reuse
the existing implementation for the situation where both sides of the operation were tensors. The
`cubecl` crate is primarily concerned with how the operation is compiled and executed by the gpu.
The actual implementation is defined in `burn-jit`.
Here is where code was added for powf in `burn-jit` and `burn-wgpu`:
Here is where code was added for powf in `burn-jit` and `cubecl`:
1. to the implementation of
[`FloatTensorOps` under `crates/burn-jit/src/ops/float_ops.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/ops/float_ops.rs#L491)
[`FloatTensorOps` under `crates/burn-jit/src/ops/float_ops.rs`](https://github.com/tracel-ai/burn/blob/3b51c26958128502d60fb35029c43d9b686b816c/crates/burn-jit/src/ops/float_ops.rs#L410)
2. the function being called was added to
[crates/burn-jit/src/ops/numeric.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/ops/numeric.rs#L229)
[crates/burn-jit/src/ops/numeric.rs](https://github.com/tracel-ai/burn/blob/3b51c26958128502d60fb35029c43d9b686b816c/crates/burn-jit/src/ops/numeric.rs#L147)
3. the operator was defined in
[`crates/burn-jit/src/codegen/dialect/gpu/operation.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/codegen/dialect/gpu/operation.rs#L37)
4. the vectorization was added to
[`crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs#L55)
5. how the operation looks to the gpu was added to
[`crates/burn-jit/src/fusion/tracing/builder.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/fusion/tracing/builder.rs#L279)
6. the mapping between the gpu operation and the WGSL instruction was added to
[`crates/burn-wgpu/src/compiler/wgsl/compiler.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/compiler.rs#L455)
7. the WGSL instruction itself was added to the
[instruction op enum in `crates/burn-wgpu/src/compiler/wgsl/instructions.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/instructions.rs#L103),
[`cubecl-core/src/ir/operation.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-core/src/ir/operation.rs#L68)
4. how the operation looks to the gpu was added to
[`crates/burn-jit/src/fusion/on_write/ir.rs`](https://github.com/tracel-ai/burn/blob/3b51c26958128502d60fb35029c43d9b686b816c/crates/burn-jit/src/fusion/on_write/ir.rs#L52)
5. the mappings between the gpu operation and the CPP, WGSL and SPIR-V instructions were added to
[`cubecl-cpp/src/shared/base.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-cpp/src/shared/base.rs#L456),
[`cubecl-wgpu/src/compiler/wgsl/compiler.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L652)
and
[`cubecl-spirv/src/instruction.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-spirv/src/instruction.rs#L408)
6. the instructions themselves were added for WGSL to
[instruction op enum in `cubecl-wgpu/src/compiler/wgsl/instructions.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L124),
and the actual
[instruction in wgsl here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/instructions.rs#L273)
[instruction in wgsl here](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L547-L555),
for CPP in the enum here
[`cubecl-cpp/src/shared/instruction.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-cpp/src/shared/instruction.rs#L127)
and the actual instruction here
[`cubecl-cpp/src/shared/binary.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-cpp/src/shared/binary.rs#L137)
We needed to generate some custom WGSL code for powf, primarily due to issues with proper case
handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an even
power being positive. We reused as much as the existing logic as possible, and then branched at the
last point based off the var type of the rhs.
[See here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/compiler.rs#L596).
For most operations, you shouldn't need to add to `crates/burn-wgpu/src/compiler/wgsl/extension.rs`
We needed to generate some custom code for powf in WGSL, primarily due to issues with proper case
handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an
even power being positive. We reused as much of the existing logic as possible, and then branched at
the last point based on the variable type of the rhs.
[See here](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L911).
For most operations, you shouldn't need to add to `cubecl-wgpu/src/compiler/wgsl/extension.rs`
unless the operation isn't native to WGSL.
For functions that need a complex kernel without a direct mapping to a base instruction, it is not
as straightforward. An easier manner of implementing them is underway.
For functions that need a complex kernel without a direct mapping to a base instruction, simply use
the `cube` macro (see
[the `cubecl` book](https://github.com/tracel-ai/cubecl/tree/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/cubecl-book)).
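To give a flavor of what that looks like, here is a minimal, hypothetical element-wise kernel written with the `cube` macro; the builtins mirror those used in the custom `cubecl` kernel example of the Burn Book, and the kernel itself is purely illustrative.
```rust, ignore
use cubecl::{cube, prelude::*};

/// Hypothetical element-wise kernel: adds two matrices and clamps the result at zero.
#[cube(launch)]
pub fn add_relu_kernel<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>, output: &mut Tensor<F>) {
    let row = ABSOLUTE_POS_X;
    let col = ABSOLUTE_POS_Y;

    let n_rows = output.shape(output.rank() - 2);
    let n_cols = output.shape(output.rank() - 1);

    // Guard against compute units that fall outside the output.
    if row >= n_rows || col >= n_cols {
        return;
    }

    let index = row * n_cols + col;
    output[index] = F::max(lhs[index] + rhs[index], F::new(0.0));
}
```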
## Adding the Op to burn-import

View File

@ -146,6 +146,7 @@ hashbrown = { workspace = true, features = ["serde"] } # no_std compatible
flate2 = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
ahash = { workspace = true }
bincode = { workspace = true }
half = { workspace = true }
num-traits = { workspace = true }
@ -154,7 +155,6 @@ rmp-serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["alloc"] } #Default enables std
spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled
thiserror = { workspace = true, optional = true }
ahash = { workspace = true }
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }

View File

@ -1,5 +1,30 @@
# Burn-Cuda
# Burn CUDA Backend
This backend is still a work in progress and not ready to be used.
[Burn](https://github.com/tracel-ai/burn) CUDA backend
See #1525
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-cuda.svg)](https://crates.io/crates/burn-cuda)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-cuda/blob/master/README.md)
This crate provides a CUDA backend for [Burn](https://github.com/tracel-ai/burn) using the
[cubecl](https://github.com/tracel-ai/cubecl.git) and [cudarc](https://github.com/coreylowman/cudarc.git)
crates.
## Usage Example
```rust
#[cfg(feature = "cuda")]
mod cuda {
use burn_autodiff::Autodiff;
use burn_cuda::{Cuda, CudaDevice};
use mnist::training;
pub fn run() {
let device = CudaDevice::default();
training::run::<Autodiff<Cuda<f32, i32>>>(device);
}
}
```
## Dependencies
Requires CUDA 12.x to be installed and on the `PATH`.

View File

@ -0,0 +1,25 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
edition.workspace = true
license.workspace = true
name = "custom-cubecl-kernel"
publish = false
version.workspace = true
[dependencies]
burn = { path = "../../crates/burn", default-features = false, features = [
"autodiff",
"wgpu",
"autotune",
"template",
] }
burn-jit = { path = "../../crates/burn-jit" }
cubecl = { workspace = true, features = ["wgpu"] }
# Serialization
log = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
# Wgpu internal dependencies
bytemuck = { workspace = true }
derive-new = { workspace = true }

View File

@ -0,0 +1,79 @@
use burn::{
backend::wgpu::WgpuRuntime,
tensor::{Distribution, Tensor},
};
use custom_cubecl_kernel::{
matmul_add_relu_custom, matmul_add_relu_reference, AutodiffBackend, Backend,
};
fn inference<B: Backend>(device: &B::Device) {
let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device);
let rhs = Tensor::random([32, 32, 32], Distribution::Default, device);
let bias = Tensor::random([32, 32, 32], Distribution::Default, device);
let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone())
.into_data()
.convert::<f32>();
let custom = matmul_add_relu_custom(lhs, rhs, bias)
.into_data()
.convert::<f32>();
reference.assert_approx_eq(&custom, 3);
println!("Both reference and the custom fused kernel have the same output");
}
fn autodiff<B: AutodiffBackend>(device: &B::Device) {
let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device).require_grad();
let rhs = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();
let bias = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();
let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone());
let mut gradients = reference.backward();
let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap();
let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap();
let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap();
let lhs = lhs.detach();
let rhs = rhs.detach();
let bias = bias.detach();
let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone());
let mut gradients = custom.backward();
let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap();
let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap();
let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap();
lhs_grad_ref
.into_data()
.convert::<B::FloatElem>()
.assert_approx_eq(&lhs_grad_custom.into_data().convert::<B::FloatElem>(), 3);
println!("Both reference and the custom fused kernel have the same lhs gradient");
rhs_grad_ref
.into_data()
.convert::<f32>()
.assert_approx_eq(&rhs_grad_custom.into_data().convert::<B::FloatElem>(), 3);
println!("Both reference and the custom fused kernel have the same rhs gradient");
bias_grad_ref
.into_data()
.convert::<f32>()
.assert_approx_eq(&bias_grad_custom.into_data().convert::<B::FloatElem>(), 3);
println!("Both reference and the custom fused kernel have the same bias gradient");
}
fn main() {
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime, f32, i32>;
type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;
let device = Default::default();
inference::<MyBackend>(&device);
autodiff::<MyAutodiffBackend>(&device);
}

View File

@ -0,0 +1,137 @@
use crate::FloatTensor;
use super::{AutodiffBackend, Backend};
use burn::{
backend::autodiff::{
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
grads::Gradients,
ops::{broadcast_shape, Backward, Ops, OpsKind},
Autodiff, NodeID,
},
tensor::Shape,
};
use burn_jit::{FloatElement, IntElement, JitBackend, JitRuntime};
impl<R: JitRuntime, F: FloatElement, I: IntElement> AutodiffBackend
for Autodiff<JitBackend<R, F, I>>
{
}
// Implement our custom backend trait for any backend that also implements our custom backend trait.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
bias: FloatTensor<Self>,
) -> FloatTensor<Self> {
// Create our zero-sized type that will implement the Backward trait.
#[derive(Debug)]
struct FusedMatmulAddReluBackward;
// Implement the backward trait for the given backend B, the node gradient
// with three other gradients to calculate (lhs, rhs, and bias).
impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
// Our state that we must build during the forward pass to compute the backward pass.
//
// Note that we could improve the performance further by only keeping the state of
// tensors that are tracked, improving memory management, but for simplicity, we avoid
// that part.
type State = (NodeID, NodeID, FloatTensor<B>, Shape);
fn backward(
self,
ops: Ops<Self::State, 3>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
// Get the nodes of each variable.
let [node_lhs, node_rhs, node_bias] = ops.parents;
// Fetch the gradient for the current node.
let grad = grads.consume::<B>(&ops.node);
// Set our state.
let (lhs_state, rhs_state, output, shape_bias) = ops.state;
let lhs = checkpointer.retrieve_node_output(lhs_state);
let rhs = checkpointer.retrieve_node_output(rhs_state);
// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);
// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
let grad_output = B::relu_backward(output, grad);
// Compute the lhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_lhs = broadcast_shape::<B>(
B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
&shape_lhs,
);
// Compute the rhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_rhs = broadcast_shape::<B>(
B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
&shape_rhs,
);
// The add derivative is only 1, so we just need to support broadcasting to
// compute the bias gradient.
let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);
// Register the gradient for each variable based on whether they are marked as
// `tracked`.
if let Some(node) = node_bias {
grads.register::<B>(node.id, grad_bias);
}
if let Some(node) = node_lhs {
grads.register::<B>(node.id, grad_lhs);
}
if let Some(node) = node_rhs {
grads.register::<B>(node.id, grad_rhs);
}
}
}
// Prepare a stateful operation with each variable node and corresponding graph.
//
// Each node can be fetched with `ops.parents` in the same order as defined here.
match FusedMatmulAddReluBackward
.prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
// Marks the operation as compute bound, meaning it will save its
// state instead of recomputing itself during checkpointing
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
// When at least one node is tracked, we should register our backward step.
// The state consists of what will be needed for this operation's backward pass.
// Since we need the parents' outputs, we must checkpoint their ids to retrieve
// their node output at the beginning of the backward pass. We can also save
// auxiliary data such as the bias shape. If we also need this operation's output,
// we can either save it in the state or recompute it during the backward pass.
// Here we choose to save it in the state because it's a compute-bound operation.
let lhs_state = prep.checkpoint(&lhs);
let rhs_state = prep.checkpoint(&rhs);
let bias_shape = B::float_shape(&bias.primitive);
let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
rhs.primitive.clone(),
bias.primitive,
);
let state = (lhs_state, rhs_state, output.clone(), bias_shape);
prep.finish(state, output)
}
OpsKind::UnTracked(prep) => {
// When no node is tracked, we can just compute the original operation without
// keeping any state.
let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
prep.finish(output)
}
}
}
}

View File

@ -0,0 +1,74 @@
use crate::{kernel::fused_matmul_add_relu_kernel, FloatTensor};
use super::Backend;
use burn::tensor::Shape;
use burn_jit::{
kernel::into_contiguous, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime,
};
use cubecl::{CubeCount, CubeDim};
/// Implement our custom backend trait for the generic `JitBackend`.
impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F, I> {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
bias: FloatTensor<Self>,
) -> FloatTensor<Self> {
// Define cube dim, hardcoded for simplicity.
let cube_dim = CubeDim { x: 16, y: 16, z: 1 };
lhs.assert_is_on_same_device(&rhs);
lhs.assert_is_on_same_device(&bias);
// For simplicity, make sure each tensor is contiguous.
let lhs = into_contiguous(lhs);
let rhs = into_contiguous(rhs);
let bias = into_contiguous(bias);
// Get the matmul relevant shapes.
let ndims = lhs.shape.num_dims();
let num_rows = lhs.shape.dims[ndims - 2];
let num_cols = rhs.shape.dims[ndims - 1];
// Compute shape of output, while tracking number of batches.
let mut num_batches = 1;
let mut shape_out = vec![0; ndims];
for i in 0..ndims - 2 {
shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]);
num_batches *= shape_out[i];
}
shape_out[ndims - 2] = num_rows;
shape_out[ndims - 1] = num_cols;
let shape_out = Shape::from(shape_out);
// Create a buffer for the output tensor.
let buffer = lhs
.client
.empty(shape_out.num_elements() * core::mem::size_of::<F>());
// Create the output tensor primitive.
let output =
JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
// Declare the launch grid with the number of cubes in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
let cube_count =
CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
// Lazily execute the kernel with the launch information and the given buffers. For
// simplicity, no vectorization is performed.
fused_matmul_add_relu_kernel::launch::<F, R>(
&lhs.client,
cube_count,
cube_dim,
lhs.as_tensor_arg(1),
rhs.as_tensor_arg(1),
bias.as_tensor_arg(1),
output.as_tensor_arg(1),
);
// Return the output tensor.
output
}
}

View File

@ -0,0 +1,45 @@
use cubecl::{cube, prelude::*};
/// Declare a custom kernel that gets compiled to `wgpu`/`CUDA`
#[cube(launch)]
pub fn fused_matmul_add_relu_kernel<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
bias: &Tensor<F>,
output: &mut Tensor<F>,
) {
let row = ABSOLUTE_POS_X;
let col = ABSOLUTE_POS_Y;
let batch = ABSOLUTE_POS_Z;
let n_rows = output.shape(output.rank() - 2);
let n_cols = output.shape(output.rank() - 1);
let dim_k = rhs.shape(rhs.rank() - 1);
if row >= n_rows || col >= n_cols {
return;
}
let offset_output = batch * n_rows * n_cols;
let mut offset_lhs = 0;
let mut offset_rhs = 0;
let batch_dims = output.rank() - 2;
for dim in 0..batch_dims {
offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);
offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);
}
let mut sum = F::new(0.0);
for k in 0..dim_k {
let lhs_index = row * dim_k + k;
let rhs_index = k * n_cols + col;
sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
}
let out_index = row * n_cols + col;
let index = offset_output + out_index;
output[index] = F::max(sum + bias[index], F::new(0.0));
}

View File

@ -0,0 +1,43 @@
mod backward;
mod forward;
mod kernel;
use burn::tensor::{activation, ops::FloatTensor, Tensor, TensorPrimitive};
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
bias: FloatTensor<Self>,
) -> FloatTensor<Self>;
}
/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}
/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
lhs: Tensor<B, 3>,
rhs: Tensor<B, 3>,
bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
let output = B::fused_matmul_add_relu(
lhs.into_primitive().tensor(),
rhs.into_primitive().tensor(),
bias.into_primitive().tensor(),
);
Tensor::from_primitive(TensorPrimitive::Float(output))
}
/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
lhs: Tensor<B, 3>,
rhs: Tensor<B, 3>,
bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
let x = lhs.matmul(rhs) + bias;
activation::relu(x)
}

View File

@ -1,10 +1,7 @@
mod backward;
mod forward;
use burn::tensor::{activation, Tensor, TensorPrimitive};
/// We use a type alias for better readability.
pub type FloatTensor<B> = <B as burn::tensor::backend::Backend>::FloatTensorPrimitive;
use burn::tensor::{activation, ops::FloatTensor, Tensor, TensorPrimitive};
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {