mirror of https://github.com/tracel-ai/burn.git
Fix: wgpu cat + copy ops (#477)
This commit is contained in:
parent
513b9281c2
commit
ddbbe39d74
|
@ -47,26 +47,35 @@ pub fn cat<E: WgpuElement, const D: usize>(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tests::{ReferenceBackend, TestBackend};
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
use burn_tensor::{backend::Backend, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn cat_should_support_multiple_invokations() {
|
||||
test_same_as_reference([6, 256]);
|
||||
fn cat_should_support_multiple_invocations_dim0() {
|
||||
test_same_as_reference([6, 256], 2, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cat_should_support_multiple_invocations_dim1() {
|
||||
test_same_as_reference([6, 256], 2, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cat_should_support_uneven_launch() {
|
||||
test_same_as_reference([1, 137]);
|
||||
test_same_as_reference([1, 137], 2, 0);
|
||||
}
|
||||
|
||||
fn test_same_as_reference(shape: [usize; 2]) {
|
||||
let tensor1 = Tensor::<TestBackend, 2>::random(shape, Distribution::Default);
|
||||
let tensor2 = Tensor::<TestBackend, 2>::random(shape, Distribution::Default);
|
||||
let tensor1_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor1.to_data());
|
||||
let tensor2_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor2.to_data());
|
||||
fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) {
|
||||
TestBackend::seed(0);
|
||||
let tensors = (0..num_tensors)
|
||||
.map(|_| Tensor::<TestBackend, 2>::random(shape, Distribution::Default))
|
||||
.collect::<Vec<_>>();
|
||||
let tensors_ref = tensors
|
||||
.iter()
|
||||
.map(|tensor| Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let tensor = Tensor::<TestBackend, 2>::cat(vec![tensor1, tensor2], 0);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(vec![tensor1_ref, tensor2_ref], 0);
|
||||
let tensor = Tensor::<TestBackend, 2>::cat(tensors, dim);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(tensors_ref, dim);
|
||||
|
||||
tensor
|
||||
.into_data()
|
||||
|
|
|
@ -58,7 +58,6 @@ pub fn swap_dims<E: WgpuElement, const D: usize>(
|
|||
dim2: usize,
|
||||
) -> WgpuTensor<E, D> {
|
||||
tensor.strides.swap(dim1, dim2);
|
||||
|
||||
tensor.shape.dims.swap(dim1, dim2);
|
||||
|
||||
tensor
|
||||
|
|
|
@ -23,8 +23,9 @@ fn main(
|
|||
let dim_cat = info[4u * dim + 1u];
|
||||
let dim_cat_index = info[4u * dim + 2u];
|
||||
|
||||
var index_input: u32 = 0u;
|
||||
var index_output: u32 = 0u;
|
||||
var num_elems = 1u;
|
||||
var index_input = 0u;
|
||||
var index_output = 0u;
|
||||
|
||||
for (var i: u32 = 1u; i <= dim; i++) {
|
||||
let stride_input = info[i];
|
||||
|
@ -32,8 +33,9 @@ fn main(
|
|||
let shape_input = info[i + 2u * dim];
|
||||
let shape_output = info[i + 3u * dim];
|
||||
|
||||
let num_block_output = id / stride_output % shape_output;
|
||||
let num_block_output = id / stride_input % shape_input;
|
||||
index_input += num_block_output * stride_input;
|
||||
num_elems *= shape_input;
|
||||
|
||||
if i - 1u == dim_cat {
|
||||
index_output += (num_block_output + dim_cat_index) * stride_output;
|
||||
|
@ -42,6 +44,8 @@ fn main(
|
|||
}
|
||||
}
|
||||
|
||||
if id < num_elems {
|
||||
output[index_output] = input[index_input];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
|
|||
// slowdowns.
|
||||
//
|
||||
// The solution is just to use a simple unary compute shader.
|
||||
unary!(CopyBuffer, body "output[global_id.x] = input[global_id.x];");
|
||||
unary!(CopyBuffer, body "output[id] = input[id];");
|
||||
unary_default::<CopyBuffer, E, D>(self.clone())
|
||||
}
|
||||
|
||||
|
|
|
@ -79,6 +79,16 @@ std_func() {
|
|||
# all features
|
||||
echo "Running all-features checks"
|
||||
build_and_test_all_features "burn-dataset"
|
||||
|
||||
cd burn-core || exit
|
||||
|
||||
echo "Test burn-core with tch backend"
|
||||
cargo test --features test-tch
|
||||
|
||||
echo "Test burn-core with wgpu backend"
|
||||
cargo test --features test-wgpu
|
||||
|
||||
cd .. || exit
|
||||
}
|
||||
|
||||
# Run the checks for no_std
|
||||
|
|
Loading…
Reference in New Issue