Fix: wgpu cat + copy ops (#477)

This commit is contained in:
Nathaniel Simard 2023-07-07 12:19:10 -04:00 committed by GitHub
parent 513b9281c2
commit ddbbe39d74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 17 deletions

View File

@ -47,26 +47,35 @@ pub fn cat<E: WgpuElement, const D: usize>(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::tests::{ReferenceBackend, TestBackend}; use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Tensor}; use burn_tensor::{backend::Backend, Distribution, Tensor};
#[test] #[test]
fn cat_should_support_multiple_invokations() { fn cat_should_support_multiple_invocations_dim0() {
test_same_as_reference([6, 256]); test_same_as_reference([6, 256], 2, 0);
}
#[test]
fn cat_should_support_multiple_invocations_dim1() {
test_same_as_reference([6, 256], 2, 1);
} }
#[test] #[test]
fn cat_should_support_uneven_launch() { fn cat_should_support_uneven_launch() {
test_same_as_reference([1, 137]); test_same_as_reference([1, 137], 2, 0);
} }
fn test_same_as_reference(shape: [usize; 2]) { fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) {
let tensor1 = Tensor::<TestBackend, 2>::random(shape, Distribution::Default); TestBackend::seed(0);
let tensor2 = Tensor::<TestBackend, 2>::random(shape, Distribution::Default); let tensors = (0..num_tensors)
let tensor1_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor1.to_data()); .map(|_| Tensor::<TestBackend, 2>::random(shape, Distribution::Default))
let tensor2_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor2.to_data()); .collect::<Vec<_>>();
let tensors_ref = tensors
.iter()
.map(|tensor| Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data()))
.collect::<Vec<_>>();
let tensor = Tensor::<TestBackend, 2>::cat(vec![tensor1, tensor2], 0); let tensor = Tensor::<TestBackend, 2>::cat(tensors, dim);
let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(vec![tensor1_ref, tensor2_ref], 0); let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(tensors_ref, dim);
tensor tensor
.into_data() .into_data()

View File

@ -58,7 +58,6 @@ pub fn swap_dims<E: WgpuElement, const D: usize>(
dim2: usize, dim2: usize,
) -> WgpuTensor<E, D> { ) -> WgpuTensor<E, D> {
tensor.strides.swap(dim1, dim2); tensor.strides.swap(dim1, dim2);
tensor.shape.dims.swap(dim1, dim2); tensor.shape.dims.swap(dim1, dim2);
tensor tensor

View File

@ -23,8 +23,9 @@ fn main(
let dim_cat = info[4u * dim + 1u]; let dim_cat = info[4u * dim + 1u];
let dim_cat_index = info[4u * dim + 2u]; let dim_cat_index = info[4u * dim + 2u];
var index_input: u32 = 0u; var num_elems = 1u;
var index_output: u32 = 0u; var index_input = 0u;
var index_output = 0u;
for (var i: u32 = 1u; i <= dim; i++) { for (var i: u32 = 1u; i <= dim; i++) {
let stride_input = info[i]; let stride_input = info[i];
@ -32,8 +33,9 @@ fn main(
let shape_input = info[i + 2u * dim]; let shape_input = info[i + 2u * dim];
let shape_output = info[i + 3u * dim]; let shape_output = info[i + 3u * dim];
let num_block_output = id / stride_output % shape_output; let num_block_output = id / stride_input % shape_input;
index_input += num_block_output * stride_input; index_input += num_block_output * stride_input;
num_elems *= shape_input;
if i - 1u == dim_cat { if i - 1u == dim_cat {
index_output += (num_block_output + dim_cat_index) * stride_output; index_output += (num_block_output + dim_cat_index) * stride_output;
@ -42,6 +44,8 @@ fn main(
} }
} }
output[index_output] = input[index_input]; if id < num_elems {
output[index_output] = input[index_input];
}
} }

View File

@ -72,7 +72,7 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
// slowdowns. // slowdowns.
// //
// The solution is just to use a simple unary compute shader. // The solution is just to use a simple unary compute shader.
unary!(CopyBuffer, body "output[global_id.x] = input[global_id.x];"); unary!(CopyBuffer, body "output[id] = input[id];");
unary_default::<CopyBuffer, E, D>(self.clone()) unary_default::<CopyBuffer, E, D>(self.clone())
} }

View File

@ -79,6 +79,16 @@ std_func() {
# all features # all features
echo "Running all-features checks" echo "Running all-features checks"
build_and_test_all_features "burn-dataset" build_and_test_all_features "burn-dataset"
cd burn-core || exit
echo "Test burn-core with tch backend"
cargo test --features test-tch
echo "Test burn-core with wgpu backend"
cargo test --features test-wgpu
cd .. || exit
} }
# Run the checks for no_std # Run the checks for no_std