mirror of https://github.com/tracel-ai/burn.git
Add documentation for custom `cubecl` kernels, update some outdated docs (#2404)
parent fe86c10e1c
commit d5e8e3185c

@@ -1683,6 +1683,19 @@ dependencies = [
 "serde",
]

[[package]]
name = "custom-cubecl-kernel"
version = "0.15.0"
dependencies = [
 "burn",
 "burn-jit",
 "bytemuck",
 "cubecl",
 "derive-new",
 "log",
 "serde",
]

[[package]]
name = "custom-image-dataset"
version = "0.15.0"

@@ -28,6 +28,7 @@
- [Quantization (Beta)](./quantization.md)
- [Advanced](./advanced/README.md)
  - [Backend Extension](./advanced/backend-extension/README.md)
    - [Custom `cubecl` Kernel](./advanced/backend-extension/custom-cubecl-kernel.md)
    - [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)
  - [Custom Optimizer]()
  - [WebAssembly]()

@@ -15,9 +15,9 @@ impression that Burn operates at a high level over the backend layer. However, m
explicit instead of being chosen via a compilation flag was a thoughtful design decision. This
explicitness does not imply that all backends must be identical; rather, it offers a great deal of
flexibility when composing backends. The autodifferentiation backend trait (see
[autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has been
extended to enable gradient computation with backpropagation. Furthermore, this design allows you to
create your own backend extension. To achieve this, you need to design your own backend trait
[autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has
been extended to enable gradient computation with backpropagation. Furthermore, this design allows
you to create your own backend extension. To achieve this, you need to design your own backend trait
specifying which functions should be supported.

```rust, ignore
@@ -76,5 +76,7 @@ impl<E: TchElement> Backend for burn_autodiff::Autodiff<burn_tch::LibTorch<E>> {
}
```

The specificity of each implementation will be covered by the examples provided in this section.
Currently, we only have one example, but more are yet to come!
The specifics of each implementation will be covered by the examples provided in this section. The
`cubecl` compiler frontend is the recommended method of implementing custom kernels, since it
supports multiple backends, including `wgpu` and `CUDA`, and is the way first-party `burn` kernels
are written.

@@ -0,0 +1,376 @@
# Custom `cubecl` Kernel

In this section, you will learn how to create your own custom operation by writing your own kernel
with the cubecl compiler frontend. We will take the example of a common workflow in the deep
learning field, where we create a kernel to fuse multiple operations together. Note that `burn` does
this automatically, but a manual implementation might be more efficient in some cases. We will fuse
a matmul kernel followed by an addition and the ReLU activation function, which is commonly found in
various models. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-cubecl-kernel).

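Written out, the operation we fuse into a single kernel below is a batched matrix product followed
by a bias addition and a ReLU:

$$ \text{output} = \operatorname{ReLU}(\text{lhs} \cdot \text{rhs} + \text{bias}) = \max(\text{lhs} \cdot \text{rhs} + \text{bias},\ 0) $$
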
## Custom Backend Trait

First, we need to determine the type signature of our newly created operation by defining our custom
backend traits. As we will use the associated type `TensorPrimitive` of the `Backend` trait, which
encapsulates the underlying tensor implementation of the backend, we will use a type alias to avoid
the ugly disambiguation with associated types.

```rust, ignore
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}

/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}
```

In our project, we can use these traits instead of the
`burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs
typically make use of the `Tensor` struct rather than dealing directly with primitive tensor types.
Therefore, we can encapsulate our newly defined backend traits with functions that expose new
operations while maintaining a consistent API.

```rust, ignore
/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let output = B::fused_matmul_add_relu(
        lhs.into_primitive().tensor(),
        rhs.into_primitive().tensor(),
        bias.into_primitive().tensor(),
    );

    Tensor::from_primitive(TensorPrimitive::Float(output))
}

/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let x = lhs.matmul(rhs) + bias;

    activation::relu(x)
}

```

Note that we also provide a reference implementation for testing purposes, which allows us to easily
validate our new implementation. While not mandatory, having a reference implementation can be
valuable, especially in projects where creating a reference implementation solely using basic tensor
operations is feasible.

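For instance, a quick sanity check during development is to compare the two paths on random inputs,
as the example shipped with this guide does. The sketch below assumes a `device` binding and a type
parameter `B: Backend` are in scope; the tolerance argument is illustrative:

```rust, ignore
let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, &device);
let rhs = Tensor::random([32, 32, 32], Distribution::Default, &device);
let bias = Tensor::random([32, 32, 32], Distribution::Default, &device);

// The fused kernel and the reference implementation should agree within
// floating-point tolerance.
let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone())
    .into_data()
    .convert::<f32>();
let custom = matmul_add_relu_custom(lhs, rhs, bias)
    .into_data()
    .convert::<f32>();
reference.assert_approx_eq(&custom, 3);
```
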
## Forward Kernel

Now, let's proceed to write the fused kernel using the `cubecl` compiler frontend. To keep things
simple, we'll create a straightforward matmul kernel without employing any intricate techniques. We
won't delve into the details of the `cube` macro, but if you're interested in learning more, please
see the
[`cubecl` Book](https://github.com/tracel-ai/cubecl/tree/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/cubecl-book).
The actual matmul, add and relu computations are found at the end, after an extensive prelude
that serves to correctly map each compute unit to the data it is responsible for, with support for
batches.

```rust, ignore
use cubecl::{cube, prelude::*};

#[cube(launch)]
pub fn fused_matmul_add_relu_kernel<F: Float>(
    lhs: &Tensor<F>,
    rhs: &Tensor<F>,
    bias: &Tensor<F>,
    output: &mut Tensor<F>,
) {
    let row = ABSOLUTE_POS_X;
    let col = ABSOLUTE_POS_Y;
    let batch = ABSOLUTE_POS_Z;

    let n_rows = output.shape(output.rank() - 2);
    let n_cols = output.shape(output.rank() - 1);
    let dim_k = rhs.shape(rhs.rank() - 1);

    if row >= n_rows || col >= n_cols {
        return;
    }

    let offset_output = batch * n_rows * n_cols;
    let mut offset_lhs = 0;
    let mut offset_rhs = 0;

    let batch_dims = output.rank() - 2;
    for dim in 0..batch_dims {
        offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);
        offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);
    }

    let mut sum = F::new(0.0);
    for k in 0..dim_k {
        let lhs_index = row * dim_k + k;
        let rhs_index = k * n_cols + col;

        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
    }

    let out_index = row * n_cols + col;
    let index = offset_output + out_index;

    output[index] = F::max(sum + bias[index], F::new(0.0));
}
```

Now, let's move on to the next step, which involves implementing the remaining code to launch the
kernel. We'll go into implementing our custom backend trait for the generic JIT backend. This
automatically implements the trait for `burn-cuda`, `burn-wgpu` as well as fusion.

```rust, ignore
/// Implement our custom backend trait for the generic `JitBackend`.
impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F, I> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Define cube dim, hardcoded for simplicity.
        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };

        lhs.assert_is_on_same_device(&rhs);
        lhs.assert_is_on_same_device(&bias);

        // For simplicity, make sure each tensor is continuous.
        let lhs = into_contiguous(lhs);
        let rhs = into_contiguous(rhs);
        let bias = into_contiguous(bias);

        // Get the matmul relevant shapes.
        let ndims = lhs.shape.num_dims();
        let num_rows = lhs.shape.dims[ndims - 2];
        let num_cols = rhs.shape.dims[ndims - 1];

        // Compute shape of output, while tracking number of batches.
        let mut num_batches = 1;
        let mut shape_out = vec![0; ndims];
        for i in 0..ndims - 2 {
            shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]);
            num_batches *= shape_out[i];
        }
        shape_out[ndims - 2] = num_rows;
        shape_out[ndims - 1] = num_cols;
        let shape_out = Shape::from(shape_out);

        // Create a buffer for the output tensor.
        let buffer = lhs
            .client
            .empty(shape_out.num_elements() * core::mem::size_of::<F>());

        // Create the output tensor primitive.
        let output =
            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);

        // Declare the wgsl workgroup with the number of cubes in x, y and z.
        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
        let cube_count =
            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

        // Execute lazily the kernel with the launch information and the given buffers. For
        // simplicity, no vectorization is performed
        fused_matmul_add_relu_kernel::launch::<F, R>(
            &lhs.client,
            cube_count,
            cube_dim,
            lhs.as_tensor_arg(1),
            rhs.as_tensor_arg(1),
            bias.as_tensor_arg(1),
            output.as_tensor_arg(1),
        );

        // Return the output tensor.
        output
    }
}
```

In the preceding code block, we demonstrated how to launch the kernel that modifies the correct
buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the
capability to execute any mutable operation on any buffer. While this isn't a problem in the
previous scenario where we only modify the newly created output buffer, it is wise to keep this in
mind.

## Backward

Now that the custom backend trait is implemented for the JIT backend, you can use it to invoke the
`matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage.
If your use case does not extend beyond inference, there is no need to implement any of the
following code.

For the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is
actually generic over the backend. Instead of crafting our own `cubecl` kernel for the backward
pass, we will use our fused kernel only for the forward pass, and compute the gradient using basic
operations.

```rust, ignore
// Implement our custom backend trait for any backend that also implements our custom backend trait.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Create our zero-sized type that will implement the Backward trait.
        #[derive(Debug)]
        struct FusedMatmulAddReluBackward;

        // Implement the backward trait for the given backend B, the node gradient
        // with three other gradients to calculate (lhs, rhs, and bias).
        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
            // Our state that we must build during the forward pass to compute the backward pass.
            //
            // Note that we could improve the performance further by only keeping the state of
            // tensors that are tracked, improving memory management, but for simplicity, we avoid
            // that part.
            type State = (NodeID, NodeID, FloatTensor<B>, Shape);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                // Get the nodes of each variable.
                let [node_lhs, node_rhs, node_bias] = ops.parents;
                // Fetch the gradient for the current node.
                let grad = grads.consume::<B>(&ops.node);

                // Set our state.
                let (lhs_state, rhs_state, output, shape_bias) = ops.state;
                let lhs = checkpointer.retrieve_node_output(lhs_state);
                let rhs = checkpointer.retrieve_node_output(rhs_state);

                // Fetch shapes of our tensor to support broadcasting.
                let shape_lhs = B::float_shape(&lhs);
                let shape_rhs = B::float_shape(&rhs);

                // Compute the gradient of the output using the already existing `relu_backward`
                // function in the basic Burn backend trait.
                let grad_output = B::relu_backward(output, grad);

                // Compute the lhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_lhs = broadcast_shape::<B>(
                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
                    &shape_lhs,
                );
                // Compute the rhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_rhs = broadcast_shape::<B>(
                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
                    &shape_rhs,
                );
                // The add derivative is only 1, so we just need to support broadcasting to
                // compute the bias gradient.
                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);

                // Register the gradient for each variable based on whether they are marked as
                // `tracked`.
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, grad_bias);
                }
                if let Some(node) = node_lhs {
                    grads.register::<B>(node.id, grad_lhs);
                }
                if let Some(node) = node_rhs {
                    grads.register::<B>(node.id, grad_rhs);
                }
            }
        }

        // Prepare a stateful operation with each variable node and corresponding graph.
        //
        // Each node can be fetched with `ops.parents` in the same order as defined here.
        match FusedMatmulAddReluBackward
            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
            // Marks the operation as compute bound, meaning it will save its
            // state instead of recomputing itself during checkpointing
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                // When at least one node is tracked, we should register our backward step.

                // The state consists of what will be needed for this operation's backward pass.
                // Since we need the parents' outputs, we must checkpoint their ids to retrieve
                // their node output at the beginning of the backward pass. We can also save
                // utilitary data such as the bias shape. If we also need this operation's output,
                // we can either save it in the state or recompute it.
                // during the backward pass. Here we choose to save it in the state because it's a
                // compute bound operation.
                let lhs_state = prep.checkpoint(&lhs);
                let rhs_state = prep.checkpoint(&rhs);
                let bias_shape = B::float_shape(&bias.primitive);

                let output = B::fused_matmul_add_relu(
                    lhs.primitive.clone(),
                    rhs.primitive.clone(),
                    bias.primitive,
                );

                let state = (lhs_state, rhs_state, output.clone(), bias_shape);

                prep.finish(state, output)
            }
            OpsKind::UnTracked(prep) => {
                // When no node is tracked, we can just compute the original operation without
                // keeping any state.
                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
                prep.finish(output)
            }
        }
    }
}
```

The previous code is self-documented to make it clearer, but here is what it does in summary:

We define `fused_matmul_add_relu` within `Autodiff<B>`, allowing any autodiff-decorated backend to
benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be
implemented. This is achieved using a comprehensive match statement block where computation is
delegated to the inner backend, while keeping track of a state. The state comprises any information
relevant to the backward pass, such as input and output tensors, along with the bias shape. When an
operation isn't tracked (meaning there won't be a backward pass for this specific operation in the
graph), storing a state becomes unnecessary, and we simply perform the forward computation.

The backward pass uses the gradient obtained from the preceding node in the computation graph. It
calculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the
derivative is one), and `matmul` (another `matmul` with transposed inputs). This results in
gradients for both input tensors and the bias, which are registered for consumption by subsequent
operation nodes.

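In equation form, writing the forward result as $O = \operatorname{ReLU}(Z)$ with
$Z = \text{lhs} \cdot \text{rhs} + \text{bias}$ and the incoming gradient as $\nabla_O$, the
gradients computed above are (up to the broadcasting reductions performed by `broadcast_shape`):

$$ \nabla_Z = \nabla_O \odot \mathbb{1}[Z > 0], \qquad \nabla_{\text{lhs}} = \nabla_Z \cdot \text{rhs}^\top, \qquad \nabla_{\text{rhs}} = \text{lhs}^\top \cdot \nabla_Z, \qquad \nabla_{\text{bias}} = \nabla_Z $$
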
The only remaining part is to implement our autodiff-decorated backend trait for our JIT Backend.

```rust, ignore
impl<R: JitRuntime, F: FloatElement, I: IntElement> AutodiffBackend
    for Autodiff<JitBackend<R, F, I>>
{
}
```

## Conclusion

In this guide, we've implemented a fused kernel using the `cubecl` compiler frontend, enabling
execution on any GPU and any `cubecl` backend. By delving into the inner workings of both the JIT
backend and the autodiff backend, we've gained a deeper understanding of these systems.

While extending a backend may be harder than working with straightforward tensors, the benefits can
be worth it. This approach enables the crafting of custom models with greater control over
execution, which can potentially greatly enhance the performance of your models.

As we conclude this guide, we hope that you have gained insights into Burn's world of backend
extensions, and that it will help you to unleash the full potential of your projects.

@@ -2,9 +2,10 @@

In this section, you will learn how to create your own custom operation by writing your own kernel
with the WGPU backend. We will take the example of a common workflow in the deep learning field,
where we create a kernel to fuse multiple operations together. We will fuse a matmul kernel followed
by an addition and the ReLU activation function, which is commonly found in various models. All the
code can be found under the
where we create a kernel to fuse multiple operations together. Note that `burn` does this
automatically, but a manual implementation might be more efficient in some cases. We will fuse a
matmul kernel followed by an addition and the ReLU activation function, which is commonly found in
various models. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-wgpu-kernel).

## Custom Backend Trait

@@ -15,9 +16,6 @@ encapsulates the underlying tensor implementation of the backend, we will use a
the ugly disambiguation with associated types.

```rust, ignore
/// We use a type alias for better readability.
pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;

/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(

@@ -78,7 +76,7 @@ we'll create a straightforward matmul kernel without employing any intricate tec
won't delve into the details of the WGSL syntax, as it falls beyond the scope of this guide, we
still provide the implementation below for readers who are curious. The actual matmul, add and relu
computations are found at the end, after an extensive overhead whose use is to correctly map each
thread to the data it is responsible of, with support for batches.
compute unit to the data it is responsible of, with support for batches.

```wgsl, ignore
@group(0)

@@ -451,7 +449,5 @@ While extending a backend may be harder than working with straightforward tensor
be worth it. This approach enables the crafting of custom models with greater control over
execution, which can potentially greatly enhance the performance of your models.

It is worth noting that while the manual fusion of operations can be valuable, our future plans
include the development of a backend extension that will automate this process. As we conclude this
guide, we hope that you have gained insights into Burn's world of backend extensions, and that it
will help you to unleash the full potential of your projects.
As we conclude this guide, we hope that you have gained insights into Burn's world of backend
extensions, and that it will help you to unleash the full potential of your projects.

@@ -139,14 +139,14 @@ seeing what completions are suggested will take you far. If you are having troub
to do it from the docs for that backend,
[try searching github for relevant function calls](https://docs.github.com/en/search-github/github-code-search/understanding-github-code-search-syntax).

## Adding the Op to fusion, JIT and wgpu backends
## Adding the Op to fusion, JIT and cubecl backends

Adding an operator to these backends can be fairly straightforward, though due to what these
backends are for, involves a bit more indirection. Fusion and jit, like autodiff, are not target
backends as much as backends that enable certain functionality for other backends, in this case
kernel fusion or just-in-time compilation (only available for `burn-wgpu` backend at the moment).
Adding the operator won't involve doing any calculation, you'll just be describing how the generated
code should look. Most of this can be copy/pasted/adjusted from other functions.
kernel fusion or just-in-time compilation. Adding the operator won't involve doing any calculation,
you'll just be describing how the generated code should look. Most of this can be
copy/pasted/adjusted from other functions.

Here's how powf was added to `burn-fusion`:

@@ -157,41 +157,47 @@ Here's how powf was added to `burn-fusion`:
3. Added powf to the implementations of `NumericOperationDescription` enum under
   [crates/burn-fusion/src/stream/context.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/stream/context.rs#L771)

The way wgpu handles tensor-scalar operations is by transforming both into a sequence of vectorized
scalar operations. Since powf already existed in burn-wgpu, it was pretty easy to reuse the existing
implementation for the situation where both sides of the operation were tensors. The `burn-wgpu`
crate is primarily concerned with how the operation is compiled and executed by the gpu. The actual
implementation is defined in `burn-jit`.
The way `cubecl` handles tensor-scalar operations is by transforming both into a sequence of
vectorized scalar operations. Since powf already existed in `cubecl`, it was pretty easy to reuse
the existing implementation for the situation where both sides of the operation were tensors. The
`cubecl` crate is primarily concerned with how the operation is compiled and executed by the gpu.
The actual implementation is defined in `burn-jit`.

Here is where code was added for powf in `burn-jit` and `burn-wgpu`:
Here is where code was added for powf in `burn-jit` and `cubecl`:

1. to the implementation of
   [`FloatTensorOps` under `crates/burn-jit/src/ops/float_ops.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/ops/float_ops.rs#L491)
   [`FloatTensorOps` under `crates/burn-jit/src/ops/float_ops.rs`](https://github.com/tracel-ai/burn/blob/3b51c26958128502d60fb35029c43d9b686b816c/crates/burn-jit/src/ops/float_ops.rs#L410)
2. the function being called was added to
   [crates/burn-jit/src/ops/numeric.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/ops/numeric.rs#L229)
   [crates/burn-jit/src/ops/numeric.rs](https://github.com/tracel-ai/burn/blob/3b51c26958128502d60fb35029c43d9b686b816c/crates/burn-jit/src/ops/numeric.rs#L147)
3. the operator was defined in
   [`crates/burn-jit/src/codegen/dialect/gpu/operation.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/codegen/dialect/gpu/operation.rs#L37)
4. the vectorization was added to
   [`crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs#L55)
5. how the operation looks to the gpu was added to
   [`crates/burn-jit/src/fusion/tracing/builder.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/fusion/tracing/builder.rs#L279)
6. the mapping between the gpu operation and the WGSL instruction was added to
   [`crates/burn-wgpu/src/compiler/wgsl/compiler.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/compiler.rs#L455)
7. the WGSL instruction itself was added to the
   [instruction op enum in `crates/burn-wgpu/src/compiler/wgsl/instructions.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/instructions.rs#L103),
   [`cubecl-core/src/ir/operation.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-core/src/ir/operation.rs#L68)
4. how the operation looks to the gpu was added to
   [`crates/burn-jit/src/fusion/on_write/ir.rs`](https://github.com/tracel-ai/burn/blob/3b51c26958128502d60fb35029c43d9b686b816c/crates/burn-jit/src/fusion/on_write/ir.rs#L52)
5. the mappings between the gpu operation and the CPP, WGSL and SPIR-V instructions were added to
   [`cubecl-cpp/src/shared/base.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-cpp/src/shared/base.rs#L456),
   [`cubecl-wgpu/src/compiler/wgsl/compiler.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L652)
   and
   [`cubecl-spirv/src/instruction.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-spirv/src/instruction.rs#L408)
6. the instructions themselves were added for WGSL to
   [instruction op enum in `cubecl-wgpu/src/compiler/wgsl/instructions.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L124),
   and the actual
   [instruction in wgsl here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/instructions.rs#L273)
   [instruction in wgsl here](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L547-L555),
   for CPP in the enum here
   [`cubecl-cpp/src/shared/instruction.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-cpp/src/shared/instruction.rs#L127)
   and the actual instruction here
   [`cubecl-cpp/src/shared/binary.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-cpp/src/shared/binary.rs#L137)

We needed to generate some custom WGSL code for powf, primarily due to issues with proper case
handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an even
power being positive. We reused as much as the existing logic as possible, and then branched at the
last point based off the var type of the rhs.
[See here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/compiler.rs#L596).
For most operations, you shouldn't need to add to `crates/burn-wgpu/src/compiler/wgsl/extension.rs`
We needed to generate some custom WGSL code for powf in WGSL, primarily due to issues with proper
case handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an
even power being positive. We reused as much as the existing logic as possible, and then branched at
the last point based off the var type of the rhs.
[See here](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L911).
For most operations, you shouldn't need to add to `cubecl-wgpu/src/compiler/wgsl/extension.rs`
unless the operation isn't native to WGSL.

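For illustration only, the special cases described above amount to semantics along these lines (a
hand-written Rust sketch of the intended behaviour, not the actual generated WGSL or compiler code):

```rust
/// Sketch of the case handling powf needs on top of a plain `pow` intrinsic.
fn powf_semantics(lhs: f32, rhs: f32) -> f32 {
    let rhs_is_even_int = rhs == (rhs as i32) as f32 && (rhs as i32) % 2 == 0;
    if rhs == 0.0 {
        // Anything raised to the power 0 is 1, including 0^0.
        1.0
    } else if lhs < 0.0 && rhs_is_even_int {
        // A negative base raised to an even integer power is positive.
        (-lhs).powf(rhs)
    } else {
        lhs.powf(rhs)
    }
}
```
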
For functions that need a complex kernel without a direct mapping to a base instruction, it is not
as straightforward. An easier manner of implementing them is underway.
For functions that need a complex kernel without a direct mapping to a base instruction, simply use
the `cube` macro (see
[the `cubecl` book](https://github.com/tracel-ai/cubecl/tree/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/cubecl-book)).

## Adding the Op to burn-import

@@ -146,6 +146,7 @@ hashbrown = { workspace = true, features = ["serde"] } # no_std compatible
flate2 = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }

ahash = { workspace = true }
bincode = { workspace = true }
half = { workspace = true }
num-traits = { workspace = true }

@@ -154,7 +155,6 @@ rmp-serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["alloc"] } #Default enables std
spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled
thiserror = { workspace = true, optional = true }
ahash = { workspace = true }

[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }

@@ -1,5 +1,30 @@
# Burn-Cuda
# Burn CUDA Backend

This backend is still a work in progress and not ready to be used.
[Burn](https://github.com/tracel-ai/burn) CUDA backend

See #1525
[](https://crates.io/crates/burn-cuda)
[](https://github.com/tracel-ai/burn-cuda/blob/master/README.md)

This crate provides a CUDA backend for [Burn](https://github.com/tracel-ai/burn) using the
[cubecl](https://github.com/tracel-ai/cubecl.git) and [cudarc](https://github.com/coreylowman/cudarc.git)
crates.

## Usage Example

```rust
#[cfg(feature = "cuda")]
mod cuda {
    use burn_autodiff::Autodiff;
    use burn_cuda::{Cuda, CudaDevice};
    use mnist::training;

    pub fn run() {
        let device = CudaDevice::default();
        training::run::<Autodiff<Cuda<f32, i32>>>(device);
    }
}
```

## Dependencies

Requires CUDA 12.x to be installed and on the `PATH`.

@@ -0,0 +1,25 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
edition.workspace = true
license.workspace = true
name = "custom-cubecl-kernel"
publish = false
version.workspace = true

[dependencies]
burn = { path = "../../crates/burn", default-features = false, features = [
    "autodiff",
    "wgpu",
    "autotune",
    "template",
] }
burn-jit = { path = "../../crates/burn-jit" }
cubecl = { workspace = true, features = ["wgpu"] }

# Serialization
log = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }

# Wgpu internal dependencies
bytemuck = { workspace = true }
derive-new = { workspace = true }

@@ -0,0 +1,79 @@
use burn::{
    backend::wgpu::WgpuRuntime,
    tensor::{Distribution, Tensor},
};
use custom_cubecl_kernel::{
    matmul_add_relu_custom, matmul_add_relu_reference, AutodiffBackend, Backend,
};

fn inference<B: Backend>(device: &B::Device) {
    let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device);
    let rhs = Tensor::random([32, 32, 32], Distribution::Default, device);
    let bias = Tensor::random([32, 32, 32], Distribution::Default, device);

    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone())
        .into_data()
        .convert::<f32>();
    let custom = matmul_add_relu_custom(lhs, rhs, bias)
        .into_data()
        .convert::<f32>();

    reference.assert_approx_eq(&custom, 3);

    println!("Both reference and the custom fused kernel have the same output");
}

fn autodiff<B: AutodiffBackend>(device: &B::Device) {
    let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device).require_grad();
    let rhs = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();
    let bias = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();

    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone());

    let mut gradients = reference.backward();

    let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap();
    let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap();
    let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap();

    let lhs = lhs.detach();
    let rhs = rhs.detach();
    let bias = bias.detach();

    let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone());

    let mut gradients = custom.backward();

    let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap();
    let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap();
    let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap();

    lhs_grad_ref
        .into_data()
        .convert::<B::FloatElem>()
        .assert_approx_eq(&lhs_grad_custom.into_data().convert::<B::FloatElem>(), 3);

    println!("Both reference and the custom fused kernel have the same lhs gradient");

    rhs_grad_ref
        .into_data()
        .convert::<f32>()
        .assert_approx_eq(&rhs_grad_custom.into_data().convert::<B::FloatElem>(), 3);

    println!("Both reference and the custom fused kernel have the same rhs gradient");

    bias_grad_ref
        .into_data()
        .convert::<f32>()
        .assert_approx_eq(&bias_grad_custom.into_data().convert::<B::FloatElem>(), 3);

    println!("Both reference and the custom fused kernel have the same bias gradient");
}

fn main() {
    type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime, f32, i32>;
    type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;
    let device = Default::default();
    inference::<MyBackend>(&device);
    autodiff::<MyAutodiffBackend>(&device);
}

@@ -0,0 +1,137 @@
use crate::FloatTensor;

use super::{AutodiffBackend, Backend};
use burn::{
    backend::autodiff::{
        checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
        grads::Gradients,
        ops::{broadcast_shape, Backward, Ops, OpsKind},
        Autodiff, NodeID,
    },
    tensor::Shape,
};
use burn_jit::{FloatElement, IntElement, JitBackend, JitRuntime};

impl<R: JitRuntime, F: FloatElement, I: IntElement> AutodiffBackend
    for Autodiff<JitBackend<R, F, I>>
{
}

// Implement our custom backend trait for any backend that also implements our custom backend trait.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Create our zero-sized type that will implement the Backward trait.
        #[derive(Debug)]
        struct FusedMatmulAddReluBackward;

        // Implement the backward trait for the given backend B, the node gradient
        // with three other gradients to calculate (lhs, rhs, and bias).
        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
            // Our state that we must build during the forward pass to compute the backward pass.
            //
            // Note that we could improve the performance further by only keeping the state of
            // tensors that are tracked, improving memory management, but for simplicity, we avoid
            // that part.
            type State = (NodeID, NodeID, FloatTensor<B>, Shape);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                // Get the nodes of each variable.
                let [node_lhs, node_rhs, node_bias] = ops.parents;
                // Fetch the gradient for the current node.
                let grad = grads.consume::<B>(&ops.node);

                // Set our state.
                let (lhs_state, rhs_state, output, shape_bias) = ops.state;
                let lhs = checkpointer.retrieve_node_output(lhs_state);
                let rhs = checkpointer.retrieve_node_output(rhs_state);

                // Fetch shapes of our tensor to support broadcasting.
                let shape_lhs = B::float_shape(&lhs);
                let shape_rhs = B::float_shape(&rhs);

                // Compute the gradient of the output using the already existing `relu_backward`
                // function in the basic Burn backend trait.
                let grad_output = B::relu_backward(output, grad);

                // Compute the lhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_lhs = broadcast_shape::<B>(
                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
                    &shape_lhs,
                );
                // Compute the rhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_rhs = broadcast_shape::<B>(
                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
                    &shape_rhs,
                );
                // The add derivative is only 1, so we just need to support broadcasting to
                // compute the bias gradient.
                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);

                // Register the gradient for each variable based on whether they are marked as
                // `tracked`.
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, grad_bias);
                }
                if let Some(node) = node_lhs {
                    grads.register::<B>(node.id, grad_lhs);
                }
                if let Some(node) = node_rhs {
                    grads.register::<B>(node.id, grad_rhs);
                }
            }
        }

        // Prepare a stateful operation with each variable node and corresponding graph.
        //
        // Each node can be fetched with `ops.parents` in the same order as defined here.
        match FusedMatmulAddReluBackward
            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
            // Marks the operation as compute bound, meaning it will save its
            // state instead of recomputing itself during checkpointing
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                // When at least one node is tracked, we should register our backward step.

                // The state consists of what will be needed for this operation's backward pass.
                // Since we need the parents' outputs, we must checkpoint their ids to retrieve
                // their node output at the beginning of the backward pass. We can also save
                // utilitary data such as the bias shape. If we also need this operation's output,
                // we can either save it in the state or recompute it.
                // during the backward pass. Here we choose to save it in the state because it's a
                // compute bound operation.
                let lhs_state = prep.checkpoint(&lhs);
                let rhs_state = prep.checkpoint(&rhs);
                let bias_shape = B::float_shape(&bias.primitive);

                let output = B::fused_matmul_add_relu(
                    lhs.primitive.clone(),
                    rhs.primitive.clone(),
                    bias.primitive,
                );

                let state = (lhs_state, rhs_state, output.clone(), bias_shape);

                prep.finish(state, output)
            }
            OpsKind::UnTracked(prep) => {
                // When no node is tracked, we can just compute the original operation without
                // keeping any state.
                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
                prep.finish(output)
            }
        }
    }
}

@@ -0,0 +1,74 @@
use crate::{kernel::fused_matmul_add_relu_kernel, FloatTensor};

use super::Backend;
use burn::tensor::Shape;
use burn_jit::{
    kernel::into_contiguous, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime,
};
use cubecl::{CubeCount, CubeDim};

/// Implement our custom backend trait for the generic `JitBackend`.
impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F, I> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Define cube dim, hardcoded for simplicity.
        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };

        lhs.assert_is_on_same_device(&rhs);
        lhs.assert_is_on_same_device(&bias);

        // For simplicity, make sure each tensor is continuous.
        let lhs = into_contiguous(lhs);
        let rhs = into_contiguous(rhs);
        let bias = into_contiguous(bias);

        // Get the matmul relevant shapes.
        let ndims = lhs.shape.num_dims();
        let num_rows = lhs.shape.dims[ndims - 2];
        let num_cols = rhs.shape.dims[ndims - 1];

        // Compute shape of output, while tracking number of batches.
        let mut num_batches = 1;
        let mut shape_out = vec![0; ndims];
        for i in 0..ndims - 2 {
            shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]);
            num_batches *= shape_out[i];
        }
        shape_out[ndims - 2] = num_rows;
        shape_out[ndims - 1] = num_cols;
        let shape_out = Shape::from(shape_out);

        // Create a buffer for the output tensor.
        let buffer = lhs
            .client
            .empty(shape_out.num_elements() * core::mem::size_of::<F>());

        // Create the output tensor primitive.
        let output =
            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);

        // Declare the wgsl workgroup with the number of cubes in x, y and z.
        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
        let cube_count =
            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

        // Execute lazily the kernel with the launch information and the given buffers. For
        // simplicity, no vectorization is performed
        fused_matmul_add_relu_kernel::launch::<F, R>(
            &lhs.client,
            cube_count,
            cube_dim,
            lhs.as_tensor_arg(1),
            rhs.as_tensor_arg(1),
            bias.as_tensor_arg(1),
            output.as_tensor_arg(1),
        );

        // Return the output tensor.
        output
    }
}

@@ -0,0 +1,45 @@
use cubecl::{cube, prelude::*};

/// Declare a custom kernel that gets compiled to `wgpu`/`CUDA`
#[cube(launch)]
pub fn fused_matmul_add_relu_kernel<F: Float>(
    lhs: &Tensor<F>,
    rhs: &Tensor<F>,
    bias: &Tensor<F>,
    output: &mut Tensor<F>,
) {
    let row = ABSOLUTE_POS_X;
    let col = ABSOLUTE_POS_Y;
    let batch = ABSOLUTE_POS_Z;

    let n_rows = output.shape(output.rank() - 2);
    let n_cols = output.shape(output.rank() - 1);
    let dim_k = rhs.shape(rhs.rank() - 1);

    if row >= n_rows || col >= n_cols {
        return;
    }

    let offset_output = batch * n_rows * n_cols;
    let mut offset_lhs = 0;
    let mut offset_rhs = 0;

    let batch_dims = output.rank() - 2;
    for dim in 0..batch_dims {
        offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);
        offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);
    }

    let mut sum = F::new(0.0);
    for k in 0..dim_k {
        let lhs_index = row * dim_k + k;
        let rhs_index = k * n_cols + col;

        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
    }

    let out_index = row * n_cols + col;
    let index = offset_output + out_index;

    output[index] = F::max(sum + bias[index], F::new(0.0));
}

@@ -0,0 +1,43 @@
mod backward;
mod forward;
mod kernel;

use burn::tensor::{activation, ops::FloatTensor, Tensor, TensorPrimitive};

/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}

/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}

/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let output = B::fused_matmul_add_relu(
        lhs.into_primitive().tensor(),
        rhs.into_primitive().tensor(),
        bias.into_primitive().tensor(),
    );

    Tensor::from_primitive(TensorPrimitive::Float(output))
}

/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let x = lhs.matmul(rhs) + bias;

    activation::relu(x)
}

@@ -1,10 +1,7 @@
mod backward;
mod forward;

use burn::tensor::{activation, Tensor, TensorPrimitive};

/// We use a type alias for better readability.
pub type FloatTensor<B> = <B as burn::tensor::backend::Backend>::FloatTensorPrimitive;
use burn::tensor::{activation, ops::FloatTensor, Tensor, TensorPrimitive};

/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {