diff --git a/burn-wgpu/src/kernel/cat.rs b/burn-wgpu/src/kernel/cat.rs index 35c0347b4..5f0fc6bc5 100644 --- a/burn-wgpu/src/kernel/cat.rs +++ b/burn-wgpu/src/kernel/cat.rs @@ -47,26 +47,35 @@ pub fn cat( #[cfg(test)] mod tests { use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; + use burn_tensor::{backend::Backend, Distribution, Tensor}; #[test] - fn cat_should_support_multiple_invokations() { - test_same_as_reference([6, 256]); + fn cat_should_support_multiple_invocations_dim0() { + test_same_as_reference([6, 256], 2, 0); + } + + #[test] + fn cat_should_support_multiple_invocations_dim1() { + test_same_as_reference([6, 256], 2, 1); } #[test] fn cat_should_support_uneven_launch() { - test_same_as_reference([1, 137]); + test_same_as_reference([1, 137], 2, 0); } - fn test_same_as_reference(shape: [usize; 2]) { - let tensor1 = Tensor::::random(shape, Distribution::Default); - let tensor2 = Tensor::::random(shape, Distribution::Default); - let tensor1_ref = Tensor::::from_data(tensor1.to_data()); - let tensor2_ref = Tensor::::from_data(tensor2.to_data()); + fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) { + TestBackend::seed(0); + let tensors = (0..num_tensors) + .map(|_| Tensor::::random(shape, Distribution::Default)) + .collect::>(); + let tensors_ref = tensors + .iter() + .map(|tensor| Tensor::::from_data(tensor.to_data())) + .collect::>(); - let tensor = Tensor::::cat(vec![tensor1, tensor2], 0); - let tensor_ref = Tensor::::cat(vec![tensor1_ref, tensor2_ref], 0); + let tensor = Tensor::::cat(tensors, dim); + let tensor_ref = Tensor::::cat(tensors_ref, dim); tensor .into_data() diff --git a/burn-wgpu/src/ops/base.rs b/burn-wgpu/src/ops/base.rs index 05d427f2c..aab93e63d 100644 --- a/burn-wgpu/src/ops/base.rs +++ b/burn-wgpu/src/ops/base.rs @@ -58,7 +58,6 @@ pub fn swap_dims( dim2: usize, ) -> WgpuTensor { tensor.strides.swap(dim1, dim2); - tensor.shape.dims.swap(dim1, dim2); tensor diff --git a/burn-wgpu/src/template/cat.wgsl b/burn-wgpu/src/template/cat.wgsl index 523a79fa9..830fd560b 100644 --- a/burn-wgpu/src/template/cat.wgsl +++ b/burn-wgpu/src/template/cat.wgsl @@ -23,8 +23,9 @@ fn main( let dim_cat = info[4u * dim + 1u]; let dim_cat_index = info[4u * dim + 2u]; - var index_input: u32 = 0u; - var index_output: u32 = 0u; + var num_elems = 1u; + var index_input = 0u; + var index_output = 0u; for (var i: u32 = 1u; i <= dim; i++) { let stride_input = info[i]; @@ -32,8 +33,9 @@ fn main( let shape_input = info[i + 2u * dim]; let shape_output = info[i + 3u * dim]; - let num_block_output = id / stride_output % shape_output; + let num_block_output = id / stride_input % shape_input; index_input += num_block_output * stride_input; + num_elems *= shape_input; if i - 1u == dim_cat { index_output += (num_block_output + dim_cat_index) * stride_output; @@ -42,6 +44,8 @@ fn main( } } - output[index_output] = input[index_input]; + if id < num_elems { + output[index_output] = input[index_input]; + } } diff --git a/burn-wgpu/src/tensor/base.rs b/burn-wgpu/src/tensor/base.rs index 591ac8b3b..c0041d800 100644 --- a/burn-wgpu/src/tensor/base.rs +++ b/burn-wgpu/src/tensor/base.rs @@ -72,7 +72,7 @@ impl WgpuTensor { // slowdowns. // // The solution is just to use a simple unary compute shader. - unary!(CopyBuffer, body "output[global_id.x] = input[global_id.x];"); + unary!(CopyBuffer, body "output[id] = input[id];"); unary_default::(self.clone()) } diff --git a/run-checks.sh b/run-checks.sh index 105706e12..d9269f05b 100755 --- a/run-checks.sh +++ b/run-checks.sh @@ -79,6 +79,16 @@ std_func() { # all features echo "Running all-features checks" build_and_test_all_features "burn-dataset" + + cd burn-core || exit + + echo "Test burn-core with tch backend" + cargo test --features test-tch + + echo "Test burn-core with wgpu backend" + cargo test --features test-wgpu + + cd .. || exit } # Run the checks for no_std