Fix: wgpu cat + copy ops (#477)

This commit is contained in:
Nathaniel Simard 2023-07-07 12:19:10 -04:00 committed by GitHub
parent 513b9281c2
commit ddbbe39d74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 17 deletions

View File

@@ -47,26 +47,35 @@ pub fn cat<E: WgpuElement, const D: usize>(
#[cfg(test)]
mod tests {
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Tensor};
use burn_tensor::{backend::Backend, Distribution, Tensor};
#[test]
fn cat_should_support_multiple_invokations() {
test_same_as_reference([6, 256]);
fn cat_should_support_multiple_invocations_dim0() {
test_same_as_reference([6, 256], 2, 0);
}
#[test]
fn cat_should_support_multiple_invocations_dim1() {
test_same_as_reference([6, 256], 2, 1);
}
#[test]
fn cat_should_support_uneven_launch() {
test_same_as_reference([1, 137]);
test_same_as_reference([1, 137], 2, 0);
}
fn test_same_as_reference(shape: [usize; 2]) {
let tensor1 = Tensor::<TestBackend, 2>::random(shape, Distribution::Default);
let tensor2 = Tensor::<TestBackend, 2>::random(shape, Distribution::Default);
let tensor1_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor1.to_data());
let tensor2_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor2.to_data());
fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) {
TestBackend::seed(0);
let tensors = (0..num_tensors)
.map(|_| Tensor::<TestBackend, 2>::random(shape, Distribution::Default))
.collect::<Vec<_>>();
let tensors_ref = tensors
.iter()
.map(|tensor| Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data()))
.collect::<Vec<_>>();
let tensor = Tensor::<TestBackend, 2>::cat(vec![tensor1, tensor2], 0);
let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(vec![tensor1_ref, tensor2_ref], 0);
let tensor = Tensor::<TestBackend, 2>::cat(tensors, dim);
let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(tensors_ref, dim);
tensor
.into_data()

View File

@@ -58,7 +58,6 @@ pub fn swap_dims<E: WgpuElement, const D: usize>(
dim2: usize,
) -> WgpuTensor<E, D> {
tensor.strides.swap(dim1, dim2);
tensor.shape.dims.swap(dim1, dim2);
tensor

View File

@@ -23,8 +23,9 @@ fn main(
let dim_cat = info[4u * dim + 1u];
let dim_cat_index = info[4u * dim + 2u];
var index_input: u32 = 0u;
var index_output: u32 = 0u;
var num_elems = 1u;
var index_input = 0u;
var index_output = 0u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_input = info[i];
@@ -32,8 +33,9 @@ fn main(
let shape_input = info[i + 2u * dim];
let shape_output = info[i + 3u * dim];
let num_block_output = id / stride_output % shape_output;
let num_block_output = id / stride_input % shape_input;
index_input += num_block_output * stride_input;
num_elems *= shape_input;
if i - 1u == dim_cat {
index_output += (num_block_output + dim_cat_index) * stride_output;
@@ -42,6 +44,8 @@ fn main(
}
}
if id < num_elems {
output[index_output] = input[index_input];
}
}

View File

@@ -72,7 +72,7 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
// slowdowns.
//
// The solution is just to use a simple unary compute shader.
unary!(CopyBuffer, body "output[global_id.x] = input[global_id.x];");
unary!(CopyBuffer, body "output[id] = input[id];");
unary_default::<CopyBuffer, E, D>(self.clone())
}

View File

@@ -79,6 +79,16 @@ std_func() {
# all features
echo "Running all-features checks"
build_and_test_all_features "burn-dataset"
cd burn-core || exit
echo "Test burn-core with tch backend"
cargo test --features test-tch
echo "Test burn-core with wgpu backend"
cargo test --features test-wgpu
cd .. || exit
}
# Run the checks for no_std