Refactor index => slice (#466)

Nathaniel Simard 2023-07-05 16:30:11 -04:00 committed by GitHub
parent 042d2201d2
commit 65bf6c1cbb
76 changed files with 1558 additions and 1290 deletions
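Before the per-file diffs, a minimal sketch (not part of the commit) of how caller code changes under the rename, reusing the `TestADTensor` and `Data` helpers that the updated tests below rely on; old method names are noted in comments:

    use burn_tensor::Data;

    // 2x3 float tensor and a 1D index tensor, mirroring the updated select/slice tests.
    let tensor = TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
    let indices = TestADTensor::from_data(Data::from([1, 0]));

    let sliced = tensor.clone().slice([0..1, 0..2]);                  // was `index`
    let _written = tensor.clone().slice_assign([0..1, 0..2], sliced); // was `index_assign`
    let picked = tensor.clone().select(0, indices.clone());           // was `index_select`
    let _merged = tensor.select_assign(0, indices, picked);           // was `index_select_assign`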


@ -44,11 +44,11 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
B::bool_reshape(tensor, shape)
}
fn bool_index<const D1: usize, const D2: usize>(
fn bool_slice<const D1: usize, const D2: usize>(
tensor: BoolTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],
ranges: [std::ops::Range<usize>; D2],
) -> BoolTensor<B, D1> {
B::bool_index(tensor, indexes)
B::bool_slice(tensor, ranges)
}
fn bool_empty<const D: usize>(
@ -58,12 +58,12 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
B::bool_empty(shape, device)
}
fn bool_index_assign<const D1: usize, const D2: usize>(
tensor: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
indexes: [std::ops::Range<usize>; D2],
value: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1> {
B::bool_index_assign(tensor, indexes, value)
fn bool_slice_assign<const D1: usize, const D2: usize>(
tensor: BoolTensor<Self, D1>,
ranges: [std::ops::Range<usize>; D2],
value: BoolTensor<Self, D1>,
) -> BoolTensor<Self, D1> {
B::bool_slice_assign(tensor, ranges, value)
}
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D> {


@ -43,11 +43,11 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::int_reshape(tensor, shape)
}
fn int_index<const D1: usize, const D2: usize>(
fn int_slice<const D1: usize, const D2: usize>(
tensor: IntTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],
ranges: [std::ops::Range<usize>; D2],
) -> IntTensor<B, D1> {
B::int_index(tensor, indexes)
B::int_slice(tensor, ranges)
}
fn int_empty<const D: usize>(
@ -57,12 +57,12 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::int_empty(shape, device)
}
fn int_index_assign<const D1: usize, const D2: usize>(
fn int_slice_assign<const D1: usize, const D2: usize>(
tensor: IntTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],
ranges: [std::ops::Range<usize>; D2],
value: IntTensor<B, D1>,
) -> IntTensor<B, D1> {
B::int_index_assign(tensor, indexes, value)
B::int_slice_assign(tensor, ranges, value)
}
fn int_cat<const D: usize>(tensors: Vec<IntTensor<B, D>>, dim: usize) -> IntTensor<B, D> {
@ -198,35 +198,35 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn int_gather<const D: usize>(
dim: usize,
tensor: IntTensor<B, D>,
indexes: IntTensor<B, D>,
indices: IntTensor<B, D>,
) -> IntTensor<B, D> {
B::int_gather(dim, tensor, indexes)
B::int_gather(dim, tensor, indices)
}
fn int_scatter<const D: usize>(
dim: usize,
tensor: IntTensor<B, D>,
indexes: IntTensor<B, D>,
indices: IntTensor<B, D>,
value: IntTensor<B, D>,
) -> IntTensor<B, D> {
B::int_scatter(dim, tensor, indexes, value)
B::int_scatter(dim, tensor, indices, value)
}
fn int_index_select_dim<const D: usize>(
fn int_select<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
indexes: IntTensor<B, 1>,
indices: IntTensor<B, 1>,
) -> IntTensor<B, D> {
B::int_index_select_dim(tensor, dim, indexes)
B::int_select(tensor, dim, indices)
}
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: IntTensor<B, D1>,
fn int_select_assign<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
indexes: IntTensor<B, 1>,
value: IntTensor<B, D2>,
) -> IntTensor<B, D1> {
B::int_index_select_dim_assign(tensor, dim, indexes, value)
indices: IntTensor<B, 1>,
value: IntTensor<B, D>,
) -> IntTensor<B, D> {
B::int_select_assign(tensor, dim, indices, value)
}
fn int_mask_where<const D: usize>(
@ -260,11 +260,11 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
) -> B::IntTensorPrimitive<D> {
B::int_max_dim(tensor, dim)
}
fn int_max_dim_with_indexes<const D: usize>(
fn int_max_dim_with_indices<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
B::int_max_dim_with_indexes(tensor, dim)
B::int_max_dim_with_indices(tensor, dim)
}
fn int_min<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
B::int_min(tensor)
@ -275,10 +275,10 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
) -> B::IntTensorPrimitive<D> {
B::int_min_dim(tensor, dim)
}
fn int_min_dim_with_indexes<const D: usize>(
fn int_min_dim_with_indices<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
B::int_min_dim_with_indexes(tensor, dim)
B::int_min_dim_with_indices(tensor, dim)
}
}


@ -10,11 +10,11 @@ impl<B: Backend, const D: usize> Backward<B, D, 1> for MaxMinDim {
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let (indexes, shape) = ops.state;
let (indices, shape) = ops.state;
let device = B::device(&grad);
let zeros = B::zeros(shape, &device);
B::scatter(D - 1, zeros, indexes, grad)
B::scatter(D - 1, zeros, indices, grad)
});
}
}


@ -9,7 +9,7 @@ use burn_tensor::ops::*;
use super::OpsKind;
impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn embedding(weights: ADTensor<B, 2>, indexes: IntTensor<B, 2>) -> ADTensor<B, 3> {
fn embedding(weights: ADTensor<B, 2>, indices: IntTensor<B, 2>) -> ADTensor<B, 3> {
#[derive(Debug)]
struct Embedding;
@ -17,10 +17,10 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
type State = (B::TensorPrimitive<2>, IntTensor<B, 2>);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let (weights, indexes) = ops.state;
let (weights, indices) = ops.state;
unary::<B, 3, 2, _>(ops.parents, ops.node, grads, |grad| {
B::embedding_backward(weights, grad, indexes)
B::embedding_backward(weights, grad, indices)
});
}
}
@ -30,19 +30,19 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
.statefull()
{
OpsKind::Tracked(prep) => prep.finish(
(weights.primitive.clone(), indexes.clone()),
B::embedding(weights.primitive, indexes),
(weights.primitive.clone(), indices.clone()),
B::embedding(weights.primitive, indices),
),
OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indexes)),
OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)),
}
}
fn embedding_backward(
weights: ADTensor<B, 2>,
output: ADTensor<B, 3>,
indexes: IntTensor<B, 2>,
indices: IntTensor<B, 2>,
) -> ADTensor<B, 2> {
let tensor = B::embedding_backward(weights.primitive, output.primitive, indexes);
let tensor = B::embedding_backward(weights.primitive, output.primitive, indices);
ADTensor::new(tensor)
}
@ -360,9 +360,9 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool2d_with_indexes(x.primitive.clone(), kernel_size, stride, padding);
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
prep.finish(
(x.primitive, output.indexes, kernel_size, stride, padding),
(x.primitive, output.indices, kernel_size, stride, padding),
output.output,
)
}
@ -372,21 +372,21 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
}
fn max_pool2d_with_indexes(
fn max_pool2d_with_indices(
x: ADTensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<ADBackendDecorator<B>> {
) -> MaxPool2dWithIndices<ADBackendDecorator<B>> {
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool2d_with_indexes(x.primitive.clone(), kernel_size, stride, padding);
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
let output_tensor = prep.finish(
(
x.primitive,
output.indexes.clone(),
output.indices.clone(),
kernel_size,
stride,
padding,
@ -394,32 +394,32 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
output.output,
);
MaxPool2dWithIndexes::new(output_tensor, output.indexes)
MaxPool2dWithIndices::new(output_tensor, output.indices)
}
OpsKind::UnTracked(prep) => {
let output = B::max_pool2d_with_indexes(x.primitive, kernel_size, stride, padding);
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
let output_tensor = prep.finish(output.output);
MaxPool2dWithIndexes::new(output_tensor, output.indexes)
MaxPool2dWithIndices::new(output_tensor, output.indices)
}
}
}
fn max_pool2d_with_indexes_backward(
fn max_pool2d_with_indices_backward(
x: ADTensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: ADTensor<B, 4>,
indexes: IntTensor<B, 4>,
indices: IntTensor<B, 4>,
) -> MaxPool2dBackward<ADBackendDecorator<B>> {
let output = B::max_pool2d_with_indexes_backward(
let output = B::max_pool2d_with_indices_backward(
x.primitive,
kernel_size,
stride,
padding,
output_grad.primitive,
indexes,
indices,
);
MaxPool2dBackward::new(ADTensor::new(output.x_grad))
}
@ -440,11 +440,11 @@ impl<B: Backend> Backward<B, 4, 1> for MaxPool2D {
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x, indexes, kernel_size, stride, padding) = ops.state;
let (x, indices, kernel_size, stride, padding) = ops.state;
if let Some(node) = node_parent {
let grad =
B::max_pool2d_with_indexes_backward(x, kernel_size, stride, padding, grad, indexes);
B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
grads.register::<B, 4>(node, grad.x_grad);
}


@ -441,7 +441,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn gather<const D: usize>(
dim: usize,
tensor: ADTensor<B, D>,
indexes: IntTensor<B, D>,
indices: IntTensor<B, D>,
) -> ADTensor<B, D> {
#[derive(Debug)]
struct Gather;
@ -450,11 +450,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
type State = (usize, IntTensor<B, D>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let (dim, indexes, shape, device) = ops.state;
let (dim, indices, shape, device) = ops.state;
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let zeros = B::zeros(shape, &device);
B::scatter(dim, zeros, indexes, grad)
B::scatter(dim, zeros, indices, grad)
});
}
}
@ -463,20 +463,20 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indexes.clone(),
indices.clone(),
B::shape(&tensor.primitive),
B::device(&tensor.primitive),
),
B::gather(dim, tensor.primitive, indexes),
B::gather(dim, tensor.primitive, indices),
),
OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indexes)),
OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indices)),
}
}
fn scatter<const D: usize>(
dim: usize,
tensor: ADTensor<B, D>,
indexes: IntTensor<B, D>,
indices: IntTensor<B, D>,
value: ADTensor<B, D>,
) -> ADTensor<B, D> {
#[derive(Debug)]
@ -486,8 +486,8 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
type State = (usize, IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state;
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
let (dim, indices, shape_lhs, shape_rhs, device) = ops.state;
let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices));
binary::<B, D, D, D, _, _>(
ops.parents,
@ -495,11 +495,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
grads,
|grad| {
let zeros = B::zeros(shape_lhs, &device);
B::scatter(dim, grad, indexes_4lhs.unwrap(), zeros)
B::scatter(dim, grad, indices_4lhs.unwrap(), zeros)
},
|grad| {
let zeros = B::zeros(shape_rhs, &device);
B::scatter(dim, zeros, indexes_4rhs.unwrap(), grad)
B::scatter(dim, zeros, indices_4rhs.unwrap(), grad)
},
);
}
@ -512,23 +512,23 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indexes.clone(),
indices.clone(),
B::shape(&tensor.primitive),
B::shape(&value.primitive),
B::device(&value.primitive),
),
B::scatter(dim, tensor.primitive, indexes, value.primitive),
B::scatter(dim, tensor.primitive, indices, value.primitive),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::scatter(dim, tensor.primitive, indexes, value.primitive))
prep.finish(B::scatter(dim, tensor.primitive, indices, value.primitive))
}
}
}
fn index_select<const D: usize>(
fn select<const D: usize>(
tensor: ADTensor<B, D>,
dim: usize,
indexes: IntTensor<B, 1>,
indices: IntTensor<B, 1>,
) -> ADTensor<B, D> {
#[derive(Debug)]
struct IndexSelectDim;
@ -537,11 +537,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
type State = (usize, IntTensor<B, 1>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let (dim, indexes, shape, device) = ops.state;
let (dim, indices, shape, device) = ops.state;
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let zeros = B::zeros(shape, &device);
B::index_select_assign(zeros, dim, indexes, grad)
B::select_assign(zeros, dim, indices, grad)
});
}
}
@ -553,76 +553,74 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indexes.clone(),
indices.clone(),
B::shape(&tensor.primitive),
B::device(&tensor.primitive),
),
B::index_select(tensor.primitive, dim, indexes),
B::select(tensor.primitive, dim, indices),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::index_select(tensor.primitive, dim, indexes))
}
OpsKind::UnTracked(prep) => prep.finish(B::select(tensor.primitive, dim, indices)),
}
}
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: ADTensor<B, D1>,
fn select_assign<const D: usize>(
tensor: ADTensor<B, D>,
dim: usize,
indexes: IntTensor<B, 1>,
value: ADTensor<B, D2>,
) -> ADTensor<B, D1> {
indices: IntTensor<B, 1>,
value: ADTensor<B, D>,
) -> ADTensor<B, D> {
#[derive(Debug)]
struct IndexSelectDimAssign<const D2: usize>;
struct IndexSelectDimAssign<const D: usize>;
impl<B: Backend, const D1: usize, const D2: usize> Backward<B, D1, 2> for IndexSelectDimAssign<D2> {
type State = (usize, IntTensor<B, 1>, Shape<D1>, Shape<D2>, B::Device);
impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectDimAssign<D> {
type State = (usize, IntTensor<B, 1>, Shape<D>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state;
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
let (dim, indices, shape_lhs, shape_rhs, device) = ops.state;
let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices));
binary::<B, D1, D1, D2, _, _>(
binary::<B, D, D, D, _, _>(
ops.parents,
ops.node,
grads,
|grad| {
let zeros = B::zeros(shape_lhs, &device);
B::index_select_assign(grad, dim, indexes_4lhs.unwrap(), zeros)
B::select_assign(grad, dim, indices_4lhs.unwrap(), zeros)
},
|grad| {
let zeros = B::zeros(shape_rhs, &device);
B::index_select_assign(zeros, dim, indexes_4rhs.unwrap(), grad)
B::select_assign(zeros, dim, indices_4rhs.unwrap(), grad)
},
);
}
}
match IndexSelectDimAssign::<D2>
match IndexSelectDimAssign::<D>
.prepare([tensor.node, value.node], [tensor.graph, value.graph])
.statefull()
{
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indexes.clone(),
indices.clone(),
B::shape(&tensor.primitive),
B::shape(&value.primitive),
B::device(&value.primitive),
),
B::index_select_assign(tensor.primitive, dim, indexes, value.primitive),
B::select_assign(tensor.primitive, dim, indices, value.primitive),
),
OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
OpsKind::UnTracked(prep) => prep.finish(B::select_assign(
tensor.primitive,
dim,
indexes,
indices,
value.primitive,
)),
}
}
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: ADTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],
ranges: [std::ops::Range<usize>; D2],
) -> ADTensor<B, D1> {
#[derive(Debug)]
struct Index<const D2: usize>;
@ -631,11 +629,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
type State = ([std::ops::Range<usize>; D2], Shape<D1>, B::Device);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let (indexes, shape, device) = ops.state;
let (ranges, shape, device) = ops.state;
unary::<B, D1, D1, _>(ops.parents, ops.node, grads, |grad| {
let zeros = B::zeros(shape, &device);
B::index_assign(zeros, indexes, grad)
B::slice_assign(zeros, ranges, grad)
});
}
}
@ -643,19 +641,19 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
match Index.prepare([tensor.node], [tensor.graph]).statefull() {
OpsKind::Tracked(prep) => prep.finish(
(
indexes.clone(),
ranges.clone(),
B::shape(&tensor.primitive),
B::device(&tensor.primitive),
),
B::index(tensor.primitive, indexes),
B::slice(tensor.primitive, ranges),
),
OpsKind::UnTracked(prep) => prep.finish(B::index(tensor.primitive, indexes)),
OpsKind::UnTracked(prep) => prep.finish(B::slice(tensor.primitive, ranges)),
}
}
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: ADTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],
ranges: [std::ops::Range<usize>; D2],
value: ADTensor<B, D1>,
) -> ADTensor<B, D1> {
#[derive(Debug)]
@ -665,8 +663,8 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
type State = ([std::ops::Range<usize>; D2], Shape<D1>, B::Device);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let (indexes, shape_rhs, device) = ops.state;
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
let (ranges, shape_rhs, device) = ops.state;
let [ranges_4lhs, ranges_4rhs] = duplicate(&ops.parents, Some(ranges));
binary::<B, D1, D1, D1, _, _>(
ops.parents,
@ -674,9 +672,9 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
grads,
|grad| {
let zeros = B::zeros(shape_rhs, &device);
B::index_assign(grad, indexes_4lhs.unwrap(), zeros)
B::slice_assign(grad, ranges_4lhs.unwrap(), zeros)
},
|grad| B::index(grad, indexes_4rhs.unwrap()),
|grad| B::slice(grad, ranges_4rhs.unwrap()),
);
}
}
@ -687,14 +685,14 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
{
OpsKind::Tracked(prep) => prep.finish(
(
indexes.clone(),
ranges.clone(),
B::shape(&value.primitive),
B::device(&value.primitive),
),
B::index_assign(tensor.primitive, indexes, value.primitive),
B::slice_assign(tensor.primitive, ranges, value.primitive),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::index_assign(tensor.primitive, indexes, value.primitive))
prep.finish(B::slice_assign(tensor.primitive, ranges, value.primitive))
}
}
}
@ -1263,8 +1261,8 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
impl<B: Backend, const D: usize> Step for CatStep<B, D> {
fn step(self: Box<Self>, grads: &mut Gradients) {
let grad = grads.consume::<B, D>(&self.output);
let indexes: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
let ranges: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
let ranges: [std::ops::Range<usize>; D] = ranges.try_into().unwrap();
let mut current_index = 0;
@ -1273,10 +1271,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
.zip(self.dim_sizes.into_iter())
.filter_map(|(node, dim_size)| node.map(|node| (node, dim_size)))
.for_each(|(node, dim_size)| {
let mut indexes = indexes.clone();
indexes[self.dim] = current_index..dim_size + current_index;
let mut ranges = ranges.clone();
ranges[self.dim] = current_index..dim_size + current_index;
current_index += dim_size;
grads.register::<B, D>(node, B::index(grad.clone(), indexes));
grads.register::<B, D>(node, B::slice(grad.clone(), ranges));
});
}
@ -1318,26 +1316,26 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
OpsKind::Tracked(prep) => {
let shape = B::shape(&tensor.primitive);
let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim);
prep.finish((index, shape), tensor)
}
OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)),
}
}
fn max_dim_with_indexes<const D: usize>(
fn max_dim_with_indices<const D: usize>(
tensor: ADTensor<B, D>,
dim: usize,
) -> (ADTensor<B, D>, IntTensor<B, D>) {
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
OpsKind::Tracked(prep) => {
let shape = B::shape(&tensor.primitive);
let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim);
let tensor = prep.finish((index.clone(), shape), tensor);
(tensor, index)
}
OpsKind::UnTracked(prep) => {
let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim);
let tensor = prep.finish(tensor);
(tensor, index)
@ -1348,26 +1346,26 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
OpsKind::Tracked(prep) => {
let shape = B::shape(&tensor.primitive);
let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim);
prep.finish((index, shape), tensor)
}
OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)),
}
}
fn min_dim_with_indexes<const D: usize>(
fn min_dim_with_indices<const D: usize>(
tensor: ADTensor<B, D>,
dim: usize,
) -> (ADTensor<B, D>, IntTensor<B, D>) {
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
OpsKind::Tracked(prep) => {
let shape = B::shape(&tensor.primitive);
let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim);
let tensor = prep.finish((index.clone(), shape), tensor);
(tensor, index)
}
OpsKind::UnTracked(prep) => {
let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim);
let tensor = prep.finish(tensor);
(tensor, index)


@ -6,16 +6,16 @@ mod tests {
#[test]
fn test_embedding_backward() {
let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = Data::from([[0, 1], [1, 1]]);
let indices = Data::from([[0, 1], [1, 1]]);
let x = Data::from([
[[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
[[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
]);
let weights = Tensor::<TestADBackend, 2>::from_data(weights).require_grad();
let indexes = Tensor::<TestADBackend, 2, Int>::from_data(indexes);
let indices = Tensor::<TestADBackend, 2, Int>::from_data(indices);
let x = Tensor::<TestADBackend, 3>::from_data(x).require_grad();
let output = embedding(weights.clone(), indexes);
let output = embedding(weights.clone(), indices);
let output = output.matmul(x);
let grads = output.backward();


@ -18,8 +18,8 @@ mod tests {
let mut tensor_2_list = Vec::new();
for i in 0..2 {
tensor_1_list.push(tensor_1.clone().index([i..i + 1]));
tensor_2_list.push(tensor_2.clone().index([i..i + 1]));
tensor_1_list.push(tensor_1.clone().slice([i..i + 1]));
tensor_2_list.push(tensor_2.clone().slice([i..i + 1]));
}
let tensor_1_cat = TestADTensor::cat(tensor_1_list.clone(), 0);
@ -28,31 +28,31 @@ mod tests {
let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());
let grads = tensor_3_cat.backward();
let grad_1_index_1 = tensor_1.grad(&grads).unwrap().index([0..1]);
let grad_1_index_2 = tensor_1.grad(&grads).unwrap().index([1..2]);
let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]);
let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]);
let grad_2_index_1 = tensor_2.grad(&grads).unwrap().index([0..1]);
let grad_2_index_2 = tensor_2.grad(&grads).unwrap().index([1..2]);
let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]);
let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]);
grad_1
.clone()
.index([0..1])
.slice([0..1])
.to_data()
.assert_approx_eq(&grad_1_index_1.to_data(), 3);
.assert_approx_eq(&grad_1_slice_1.to_data(), 3);
grad_1
.index([1..2])
.slice([1..2])
.to_data()
.assert_approx_eq(&grad_1_index_2.to_data(), 3);
.assert_approx_eq(&grad_1_slice_2.to_data(), 3);
grad_2
.clone()
.index([0..1])
.slice([0..1])
.to_data()
.assert_approx_eq(&grad_2_index_1.to_data(), 3);
.assert_approx_eq(&grad_2_slice_1.to_data(), 3);
grad_2
.index([1..2])
.slice([1..2])
.to_data()
.assert_approx_eq(&grad_2_index_2.to_data(), 3);
.assert_approx_eq(&grad_2_slice_2.to_data(), 3);
}
#[test]


@ -7,10 +7,10 @@ mod tests {
fn test_gather_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
let indices = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().gather(1, indexes);
let tensor_3 = tensor_1.clone().gather(1, indices);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
@ -29,10 +29,10 @@ mod tests {
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let values =
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
let indices = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().scatter(1, indexes, values.clone());
let tensor_3 = tensor_1.clone().scatter(1, indices, values.clone());
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();


@ -17,8 +17,6 @@ mod erf;
mod exp;
mod gather_scatter;
mod gelu;
mod index;
mod index_select;
mod log;
mod log1p;
mod mask;
@ -31,7 +29,9 @@ mod neg;
mod pow;
mod relu;
mod reshape;
mod select;
mod sin;
mod slice;
mod softmax;
mod sqrt;
mod sub;
@ -71,9 +71,9 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_div!();
burn_autodiff::testgen_ad_erf!();
burn_autodiff::testgen_ad_exp!();
burn_autodiff::testgen_ad_index!();
burn_autodiff::testgen_ad_slice!();
burn_autodiff::testgen_ad_gather_scatter!();
burn_autodiff::testgen_ad_index_select!();
burn_autodiff::testgen_ad_select!();
burn_autodiff::testgen_ad_log!();
burn_autodiff::testgen_ad_log1p!();
burn_autodiff::testgen_ad_mask!();


@ -1,16 +1,16 @@
#[burn_tensor_testgen::testgen(ad_index_select)]
#[burn_tensor_testgen::testgen(ad_select)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn test_index_select_grad() {
fn test_select_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([1, 0]));
let indices = TestADTensor::from_data(Data::from([1, 0]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().index_select(0, indexes);
let tensor_3 = tensor_1.clone().select(0, indices);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
@ -24,17 +24,15 @@ mod tests {
}
#[test]
fn test_index_select_assign_grad() {
fn test_select_assign_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let values =
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([1, 0]));
let indices = TestADTensor::from_data(Data::from([1, 0]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1
.clone()
.index_select_assign(0, indexes, values.clone());
let tensor_3 = tensor_1.clone().select_assign(0, indices, values.clone());
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();


@ -1,17 +1,17 @@
#[burn_tensor_testgen::testgen(ad_index)]
#[burn_tensor_testgen::testgen(ad_slice)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn should_diff_matmul_with_index() {
fn should_diff_matmul_with_slice() {
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]);
let tensor_1 = TestADTensor::from_data(data_1).require_grad();
let tensor_2 = TestADTensor::from_data(data_2).require_grad();
let tensor_3 = tensor_2.clone().index([0..2, 0..2]);
let tensor_3 = tensor_2.clone().slice([0..2, 0..2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
@ -26,7 +26,7 @@ mod tests {
}
#[test]
fn should_diff_matmul_with_index_assign() {
fn should_diff_matmul_with_slice_assign() {
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
let data_assigned: Data<f32, 2> = Data::from([[9.0]]);
@ -36,7 +36,7 @@ mod tests {
let tensor_assigned = TestADTensor::from_data(data_assigned).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.index_assign([0..1, 0..1], tensor_assigned);
let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned);
let tensor_5 = tensor_4.matmul(tensor_1.clone());
let grads = tensor_5.backward();
@ -49,7 +49,7 @@ mod tests {
}
#[test]
fn should_diff_matmul_with_index_assign_complex() {
fn should_diff_matmul_with_slice_assign_complex() {
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3: Data<f32, 2> = Data::from([[9.0]]);
@ -59,9 +59,9 @@ mod tests {
let tensor_3 = TestADTensor::from_data(data_3).require_grad();
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_2.clone().index([0..1, 0..1]);
let tensor_5 = tensor_2.clone().slice([0..1, 0..1]);
let tensor_6 = tensor_5.mul(tensor_3.clone());
let tensor_7 = tensor_4.index_assign([0..1, 0..1], tensor_6);
let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6);
let tensor_8 = tensor_7.matmul(tensor_1.clone());
let grads = tensor_8.backward();


@ -14,7 +14,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 3, Int>::ones([1, 1, seq_length - (i + 1)]);
mask = mask.index_assign([0..1, i..i + 1, i + 1..seq_length], values);
mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
}
mask = mask.to_device(device).repeat(0, batch_size);
@ -68,7 +68,7 @@ pub fn generate_padding_mask<B: Backend>(
}
}
tensor = tensor.index_assign(
tensor = tensor.slice_assign(
[index..index + 1, 0..tokens.len()],
Tensor::from_data(Data::new(
tokens.into_iter().map(|e| (e as i64).elem()).collect(),


@ -366,7 +366,7 @@ mod tests {
// Create a padding mask
let mask_pad: Tensor<TestBackend, 2, Int> = Tensor::zeros([batch_size, seq_length]);
let mask_pad = mask_pad.index_assign(
let mask_pad = mask_pad.slice_assign(
[0..batch_size, seq_length - num_padded..seq_length],
Tensor::ones([batch_size, num_padded]),
);
@ -377,7 +377,7 @@ mod tests {
Distribution::Standard,
);
// Change the end of the tensor
let tensor_2 = tensor_1.clone().index_assign(
let tensor_2 = tensor_1.clone().slice_assign(
[
0..batch_size,
seq_length - num_padded..seq_length,
@ -395,12 +395,12 @@ mod tests {
// Check that the beginning of each tensor is the same
output_1
.context
.index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.into_data()
.assert_approx_eq(
&output_2
.context
.index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.into_data(),
3,
);
@ -423,9 +423,9 @@ mod tests {
let mut cache = MhaCache::autoregressive();
for i in 1..seq_length + 1 {
let tensor = tensor.clone().index([0..batch_size, 0..i, 0..d_model]);
let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
let input = MhaInput::self_attn(tensor);
let next_tok = mha.forward_cache(input, &mut cache).context.index([
let next_tok = mha.forward_cache(input, &mut cache).context.slice([
0..batch_size,
i - 1..i,
0..d_model,


@ -21,7 +21,7 @@ impl<B: Backend, const D: usize> TensorCache<B, D> {
CacheState::Value(tensor_old) => {
let [batch_size, seq_length, d_model] = tensor.dims();
let next_seq_token =
tensor.index([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
let next_seq_token = func(next_seq_token);
Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)


@ -120,14 +120,8 @@ impl<B: Backend> Gru<B> {
for t in 0..seq_length {
let indices = Tensor::arange(t..t + 1);
let input_t = batched_input
.clone()
.index_select(1, indices.clone())
.squeeze(1);
let hidden_t = hidden_state
.clone()
.index_select(1, indices.clone())
.squeeze(1);
let input_t = batched_input.clone().select(1, indices.clone()).squeeze(1);
let hidden_t = hidden_state.clone().select(1, indices.clone()).squeeze(1);
// u(pdate)g(ate) tensors
let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
@ -149,7 +143,7 @@ impl<B: Backend> Gru<B> {
.mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
+ update_values.clone().mul(hidden_t);
hidden_state = hidden_state.index_assign(
hidden_state = hidden_state.slice_assign(
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
state_vector.clone().unsqueeze(),
);
@ -268,7 +262,7 @@ mod tests {
let state = gru.forward(input, None);
let output = state.index_select(0, Tensor::arange(0..1)).squeeze(0);
let output = state.select(0, Tensor::arange(0..1)).squeeze(0);
output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3);
}


@ -141,7 +141,7 @@ impl<B: Backend> Lstm<B> {
for t in 0..seq_length {
let indices = Tensor::arange(t..t + 1);
let input_t = batched_input.clone().index_select(1, indices).squeeze(1);
let input_t = batched_input.clone().select(1, indices).squeeze(1);
// f(orget)g(ate) tensors
let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate);
let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state
@ -162,11 +162,11 @@ impl<B: Backend> Lstm<B> {
hidden_state = output_values * cell_state.clone().tanh();
// store the state for this timestep
batched_cell_state = batched_cell_state.index_assign(
batched_cell_state = batched_cell_state.slice_assign(
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
cell_state.clone().unsqueeze(),
);
batched_hidden_state = batched_hidden_state.index_assign(
batched_hidden_state = batched_hidden_state.slice_assign(
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
hidden_state.clone().unsqueeze(),
);
@ -312,11 +312,9 @@ mod tests {
let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]));
let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None);
let cell_state = cell_state_batch
.index_select(0, Tensor::arange(0..1))
.squeeze(0);
let cell_state = cell_state_batch.select(0, Tensor::arange(0..1)).squeeze(0);
let hidden_state = hidden_state_batch
.index_select(0, Tensor::arange(0..1))
.select(0, Tensor::arange(0..1))
.squeeze(0);
cell_state
.to_data()


@ -421,14 +421,14 @@ mod tests {
let mut cache = transformer.new_autoregressive_cache();
for i in 1..seq_length + 1 {
let target = target.clone().index([0..batch_size, 0..i, 0..d_model]);
let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
let input = TransformerDecoderInput::new(target.clone(), memory.clone())
.target_mask_attn(mask_attn);
let next_tok = transformer // Greedy sampling
.forward_autoregressive_inference(input, &mut cache)
.index([0..batch_size, i - 1..i, 0..d_model]);
.slice([0..batch_size, i - 1..i, 0..d_model]);
output_2.push(next_tok);
}


@ -359,11 +359,11 @@ mod tests {
let mut cache = transformer.new_autoregressive_cache();
for i in 1..seq_length + 1 {
let tensor = tensor.clone().index([0..batch_size, 0..i, 0..d_model]);
let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
let input = TransformerEncoderInput::new(tensor.clone());
let next_tok = transformer
.forward_autoregressive_inference(input, &mut cache)
.index([0..batch_size, i - 1..i, 0..d_model]);
.slice([0..batch_size, i - 1..i, 0..d_model]);
output_2.push(next_tok);
}


@ -6,7 +6,7 @@ use std::marker::PhantomData;
/// want a probability distribution that is computed lazily.
pub struct ShuffledDataset<D, I> {
dataset: D,
indexes: Vec<usize>,
indices: Vec<usize>,
input: PhantomData<I>,
}
@ -16,15 +16,15 @@ where
{
/// Creates a new shuffled dataset.
pub fn new(dataset: D, rng: &mut StdRng) -> Self {
let mut indexes = Vec::with_capacity(dataset.len());
let mut indices = Vec::with_capacity(dataset.len());
for i in 0..dataset.len() {
indexes.push(i);
indices.push(i);
}
indexes.shuffle(rng);
indices.shuffle(rng);
Self {
dataset,
indexes,
indices,
input: PhantomData,
}
}
@ -42,7 +42,7 @@ where
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
let index = match self.indexes.get(index) {
let index = match self.indices.get(index) {
Some(index) => index,
None => return None,
};


@ -28,22 +28,22 @@ impl<E> NdArrayOps<E>
where
E: Copy,
{
pub fn index<const D1: usize, const D2: usize>(
pub fn slice<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> NdArrayTensor<E, D1> {
let slices = Self::to_slice_args::<D1, D2>(indexes);
let slices = Self::to_slice_args::<D1, D2>(ranges);
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
NdArrayTensor { array }
}
pub fn index_assign<const D1: usize, const D2: usize>(
pub fn slice_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: NdArrayTensor<E, D1>,
) -> NdArrayTensor<E, D1> {
let slices = Self::to_slice_args::<D1, D2>(indexes);
let slices = Self::to_slice_args::<D1, D2>(ranges);
let mut array = tensor.array.into_owned();
array.slice_mut(slices.as_slice()).assign(&value.array);
let array = array.into_shared();
@ -77,7 +77,7 @@ where
}
fn to_slice_args<const D1: usize, const D2: usize>(
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> [SliceInfoElem; D1] {
let mut slices = [SliceInfoElem::NewAxis; D1];
for i in 0..D1 {
@ -89,8 +89,8 @@ where
}
} else {
slices[i] = SliceInfoElem::Slice {
start: indexes[i].start as isize,
end: Some(indexes[i].end as isize),
start: ranges[i].start as isize,
end: Some(ranges[i].end as isize),
step: 1,
}
}
@ -211,31 +211,31 @@ where
pub fn gather<const D: usize>(
dim: usize,
mut tensor: NdArrayTensor<E, D>,
mut indexes: NdArrayTensor<i64, D>,
mut indices: NdArrayTensor<i64, D>,
) -> NdArrayTensor<E, D> {
if dim != D - 1 {
tensor.array.swap_axes(D - 1, dim);
indexes.array.swap_axes(D - 1, dim);
indices.array.swap_axes(D - 1, dim);
}
let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape());
let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]);
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);
let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape());
let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indices.dims[D - 1]);
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
let mut output = Array2::zeros((batch_size, size_index));
for b in 0..batch_size {
let indexes = indexes.slice(s!(b, ..));
let indices = indices.slice(s!(b, ..));
for (i, index) in indexes.iter().enumerate() {
for (i, index) in indices.iter().enumerate() {
output[[b, i]] = tensor[[b, *index as usize]];
}
}
let mut output = NdArrayOps::reshape(
NdArrayTensor::<E, 2>::new(output.into_shared().into_dyn()),
shape_indexes,
shape_indices,
);
if dim != D - 1 {
@ -248,36 +248,36 @@ where
pub fn scatter<const D: usize>(
dim: usize,
mut tensor: NdArrayTensor<E, D>,
mut indexes: NdArrayTensor<i64, D>,
mut indices: NdArrayTensor<i64, D>,
mut value: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
if dim != D - 1 {
tensor.array.swap_axes(D - 1, dim);
indexes.array.swap_axes(D - 1, dim);
indices.array.swap_axes(D - 1, dim);
value.array.swap_axes(D - 1, dim);
}
let (shape_tensor, shape_indexes, shape_value) =
(tensor.shape(), indexes.shape(), value.shape());
let (shape_tensor, shape_indices, shape_value) =
(tensor.shape(), indices.shape(), value.shape());
let (size_tensor, size_index, size_value) = (
shape_tensor.dims[D - 1],
shape_indexes.dims[D - 1],
shape_indices.dims[D - 1],
shape_value.dims[D - 1],
);
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);
if shape_value != shape_indexes {
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims);
if shape_value != shape_indices {
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indices.dims, shape_value.dims);
}
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array;
let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
for b in 0..batch_size {
let indexes = indexes.slice(s!(b, ..));
let indices = indices.slice(s!(b, ..));
for (i, index) in indexes.iter().enumerate() {
for (i, index) in indices.iter().enumerate() {
let index = *index as usize;
tensor[[b, index]] += value[[b, i]];
}
@ -331,28 +331,28 @@ where
fn gather_batch_size<const D: usize>(
shape_tensor: &Shape<D>,
shape_indexes: &Shape<D>,
shape_indices: &Shape<D>,
) -> usize {
let mut batch_size = 1;
for i in 0..D - 1 {
if shape_tensor.dims[i] != shape_indexes.dims[i] {
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indexes.dims);
if shape_tensor.dims[i] != shape_indices.dims[i] {
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims);
}
batch_size *= shape_indexes.dims[i];
batch_size *= shape_indices.dims[i];
}
batch_size
}
pub fn index_select<const D: usize>(
pub fn select<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
indices: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<E, D> {
let array = tensor.array.select(
Axis(dim),
&indexes
&indices
.array
.into_iter()
.map(|i| i as usize)
@ -362,15 +362,15 @@ where
NdArrayTensor::new(array.into_shared())
}
pub fn index_select_assign<const D1: usize, const D2: usize>(
pub fn select_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
indices: NdArrayTensor<i64, 1>,
value: NdArrayTensor<E, D2>,
) -> NdArrayTensor<E, D1> {
let mut output_array = tensor.array.into_owned();
for (index_value, index) in indexes.array.into_iter().enumerate() {
for (index_value, index) in indices.array.into_iter().enumerate() {
let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
let value = value.array.index_axis(Axis(dim), index_value);


@ -57,11 +57,11 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
NdArrayOps::reshape(tensor, shape)
}
fn bool_index<const D1: usize, const D2: usize>(
fn bool_slice<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<bool, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> NdArrayTensor<bool, D1> {
NdArrayOps::index(tensor, indexes)
NdArrayOps::slice(tensor, ranges)
}
fn bool_into_int<const D: usize>(
@ -85,12 +85,12 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
NdArrayTensor::from_data(Data::new(values, shape))
}
fn bool_index_assign<const D1: usize, const D2: usize>(
fn bool_slice_assign<const D1: usize, const D2: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1> {
NdArrayOps::index_assign(tensor, indexes, value)
NdArrayOps::slice_assign(tensor, ranges, value)
}
fn bool_cat<const D: usize>(


@ -51,11 +51,11 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
NdArrayOps::reshape(tensor, shape)
}
fn int_index<const D1: usize, const D2: usize>(
fn int_slice<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<i64, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> NdArrayTensor<i64, D1> {
NdArrayOps::index(tensor, indexes)
NdArrayOps::slice(tensor, ranges)
}
fn int_device<const D: usize>(
@ -88,12 +88,12 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
NdArrayMathOps::mask_fill(tensor, mask, value)
}
fn int_index_assign<const D1: usize, const D2: usize>(
fn int_slice_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<i64, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: NdArrayTensor<i64, D1>,
) -> NdArrayTensor<i64, D1> {
NdArrayOps::index_assign(tensor, indexes, value)
NdArrayOps::slice_assign(tensor, ranges, value)
}
fn int_cat<const D: usize>(
@ -283,35 +283,35 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
fn int_gather<const D: usize>(
dim: usize,
tensor: NdArrayTensor<i64, D>,
indexes: NdArrayTensor<i64, D>,
indices: NdArrayTensor<i64, D>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::gather(dim, tensor, indexes)
NdArrayMathOps::gather(dim, tensor, indices)
}
fn int_scatter<const D: usize>(
dim: usize,
tensor: NdArrayTensor<i64, D>,
indexes: NdArrayTensor<i64, D>,
indices: NdArrayTensor<i64, D>,
value: NdArrayTensor<i64, D>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::scatter(dim, tensor, indexes, value)
NdArrayMathOps::scatter(dim, tensor, indices, value)
}
fn int_index_select_dim<const D: usize>(
fn int_select<const D: usize>(
tensor: NdArrayTensor<i64, D>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
indices: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::index_select(tensor, dim, indexes)
NdArrayMathOps::select(tensor, dim, indices)
}
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<i64, D1>,
fn int_select_assign<const D: usize>(
tensor: NdArrayTensor<i64, D>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
value: NdArrayTensor<i64, D2>,
) -> NdArrayTensor<i64, D1> {
NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
indices: NdArrayTensor<i64, 1>,
value: NdArrayTensor<i64, D>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::select_assign(tensor, dim, indices, value)
}
fn int_argmax<const D: usize>(
tensor: NdArrayTensor<i64, D>,


@ -60,7 +60,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
NdArrayTensor::new(output.into_dyn().into_shared())
}
pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
@ -78,10 +78,10 @@ pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
let x = apply_padding_4d(x, padding, inf).array;
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
let mut indexes = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));
let mut indices = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
let unsafe_shared_indexes = UnsafeSharedRef::new(&mut indexes);
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
@ -89,7 +89,7 @@ pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
let c = k % channels;
let output = unsafe_shared_out.get();
let indexes = unsafe_shared_indexes.get();
let indices = unsafe_shared_indices.get();
for oh in 0..out_height {
for ow in 0..out_width {
@ -115,16 +115,16 @@ pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
}
output[[b, c, oh, ow]] = max_val;
indexes[[b, c, oh, ow]] = index;
indices[[b, c, oh, ow]] = index;
}
}
})
});
let output = NdArrayTensor::new(output.into_dyn().into_shared());
let indexes = NdArrayTensor::new(indexes.into_dyn().into_shared());
let indices = NdArrayTensor::new(indices.into_dyn().into_shared());
(output, indexes)
(output, indices)
}
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
@ -133,13 +133,13 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
_stride: [usize; 2],
_padding: [usize; 2],
output_grad: NdArrayTensor<E, 4>,
indexes: NdArrayTensor<i64, 4>,
indices: NdArrayTensor<i64, 4>,
) -> NdArrayTensor<E, 4> {
let [_batch_size, _channels, height, width] = output_grad.shape().dims;
let [batch_size, channels, height_x, width_x] = x.shape().dims;
let output_grad = output_grad.array;
let indexes = indexes.array;
let indices = indices.array;
let mut output = Array4::zeros((batch_size, channels, height_x, width_x));
@ -154,7 +154,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
for h in 0..height {
for w in 0..width {
let index = indexes[[b, c, h, w]];
let index = indices[[b, c, h, w]];
let grad = output_grad[[b, c, h, w]];
let index_h = index as usize / width_x;


@ -1,7 +1,7 @@
use super::{
avgpool::{avg_pool2d, avg_pool2d_backward},
conv::{conv2d, conv_transpose2d},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indexes},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use burn_tensor::ops::*;
@ -53,24 +53,24 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
max_pool2d(x, kernel_size, stride, padding)
}
fn max_pool2d_with_indexes(
fn max_pool2d_with_indices(
x: NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<NdArrayBackend<E>> {
let (output, indexes) = max_pool2d_with_indexes(x, kernel_size, stride, padding);
) -> MaxPool2dWithIndices<NdArrayBackend<E>> {
let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding);
MaxPool2dWithIndexes::new(output, indexes)
MaxPool2dWithIndices::new(output, indices)
}
fn max_pool2d_with_indexes_backward(
fn max_pool2d_with_indices_backward(
x: NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: NdArrayTensor<E, 4>,
indexes: NdArrayTensor<i64, 4>,
indices: NdArrayTensor<i64, 4>,
) -> MaxPool2dBackward<NdArrayBackend<E>> {
MaxPool2dBackward::new(max_pool2d_backward(
x,
@ -78,7 +78,7 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
stride,
padding,
output_grad,
indexes,
indices,
))
}
}


@ -18,7 +18,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
);
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
x_new = NdArrayBackend::index_assign(
x_new = NdArrayBackend::slice_assign(
x_new,
[
0..batch_size,


@ -154,50 +154,50 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
fn gather<const D: usize>(
dim: usize,
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
indices: NdArrayTensor<i64, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::gather(dim, tensor, indexes)
NdArrayMathOps::gather(dim, tensor, indices)
}
fn scatter<const D: usize>(
dim: usize,
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
indices: NdArrayTensor<i64, D>,
value: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::scatter(dim, tensor, indexes, value)
NdArrayMathOps::scatter(dim, tensor, indices, value)
}
fn index_select<const D: usize>(
fn select<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
indices: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::index_select(tensor, dim, indexes)
NdArrayMathOps::select(tensor, dim, indices)
}
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
fn select_assign<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
value: NdArrayTensor<E, D2>,
) -> NdArrayTensor<E, D1> {
NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
indices: NdArrayTensor<i64, 1>,
value: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::select_assign(tensor, dim, indices, value)
}
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> NdArrayTensor<E, D1> {
NdArrayOps::index(tensor, indexes)
NdArrayOps::slice(tensor, ranges)
}
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: NdArrayTensor<E, D1>,
) -> NdArrayTensor<E, D1> {
NdArrayOps::index_assign(tensor, indexes, value)
NdArrayOps::slice_assign(tensor, ranges, value)
}
fn mask_where<const D: usize>(


@ -18,14 +18,14 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
}
pub fn index<const D1: usize, const D2: usize>(
pub fn slice<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> TchTensor<E, D1> {
let storage = tensor.storage.clone();
let mut tensor = tensor.tensor.shallow_clone();
for (i, index) in indexes.iter().enumerate().take(D2) {
for (i, index) in ranges.iter().enumerate().take(D2) {
let start = index.start as i64;
let length = (index.end - index.start) as i64;
tensor = tensor.narrow(i as i64, start, length);
@ -34,9 +34,9 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::from_existing(tensor, storage)
}
pub fn index_assign<const D1: usize, const D2: usize>(
pub fn slice_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: TchTensor<E, D1>,
) -> TchTensor<E, D1> {
let tensor_original = tensor.tensor.copy();
@ -44,7 +44,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
let mut tensor = tensor_original.view_(tch_shape.dims);
for (i, index) in indexes.into_iter().enumerate().take(D2) {
for (i, index) in ranges.into_iter().enumerate().take(D2) {
let start = index.start as i64;
let length = (index.end - index.start) as i64;
@ -59,10 +59,10 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
pub fn gather<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
indices: TchTensor<i64, D>,
) -> TchTensor<E, D> {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.gather(dim as i64, &indexes.tensor, false);
let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false);
TchTensor::from_existing(tensor, storage)
}
@ -70,13 +70,13 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
pub fn scatter<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
indices: TchTensor<i64, D>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
let storage = tensor.storage.clone();
let tensor = tensor
.tensor
.scatter_add(dim as i64, &indexes.tensor, &value.tensor);
.scatter_add(dim as i64, &indices.tensor, &value.tensor);
TchTensor::from_existing(tensor, storage)
}
@ -84,25 +84,25 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
pub fn index_select_dim<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indexes: TchTensor<i64, 1>,
indices: TchTensor<i64, 1>,
) -> TchTensor<E, D> {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.index_select(dim as i64, &indexes.tensor);
let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor);
TchTensor::from_existing(tensor, storage)
}
pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
pub fn select_assign<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indexes: TchTensor<i64, 1>,
value: TchTensor<E, D2>,
) -> TchTensor<E, D1> {
let mut indices = Vec::with_capacity(D1);
for _ in 0..D1 {
indices_tensor: TchTensor<i64, 1>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
let mut indices = Vec::with_capacity(D);
for _ in 0..D {
indices.push(None);
}
indices[dim] = Some(indexes.tensor);
indices[dim] = Some(indices_tensor.tensor);
tensor.unary_ops(
|mut tensor| tensor.index_put_(&indices, &value.tensor, true),
@ -321,41 +321,41 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
pub fn max_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
let storage = tensor.storage.clone();
let (tensor, _indexes) = tensor.tensor.max_dim(dim as i64, true);
let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true);
TchTensor::from_existing(tensor, storage)
}
pub fn max_dim_with_indexes<const D: usize>(
pub fn max_dim_with_indices<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
let storage = tensor.storage.clone();
let (tensor, indexes) = tensor.tensor.max_dim(dim as i64, true);
let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true);
let tensor = TchTensor::from_existing(tensor, storage);
let indexes = TchTensor::new(indexes);
let indices = TchTensor::new(indices);
(tensor, indexes)
(tensor, indices)
}
pub fn min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
let storage = tensor.storage.clone();
let (tensor, _indexes) = tensor.tensor.min_dim(dim as i64, true);
let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true);
TchTensor::from_existing(tensor, storage)
}
pub fn min_dim_with_indexes<const D: usize>(
pub fn min_dim_with_indices<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
let storage = tensor.storage.clone();
let (tensor, indexes) = tensor.tensor.min_dim(dim as i64, true);
let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true);
let tensor = TchTensor::from_existing(tensor, storage);
let indexes = TchTensor::new(indexes);
let indices = TchTensor::new(indices);
(tensor, indexes)
(tensor, indices)
}
}
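At the public `Tensor` level, the renamed reductions above surface as `max_dim_with_indices` and `min_dim_with_indices`, returning the values (with the reduced dimension kept at size 1, since `max_dim`/`min_dim` are called with `keepdim = true` here) together with the positions as an `Int` tensor. A small hedged sketch; the helper name is illustrative:

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

fn max_with_positions<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    dim: usize,
) -> (Tensor<B, D>, Tensor<B, D, Int>) {
    // values along `dim` plus the index of each maximum
    tensor.max_dim_with_indices(dim)
}
```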

View File

@ -60,18 +60,18 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
TchTensor::new(tensor)
}
fn bool_index<const D1: usize, const D2: usize>(
fn bool_slice<const D1: usize, const D2: usize>(
tensor: TchTensor<bool, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> TchTensor<bool, D1> {
TchOps::index(tensor, indexes)
TchOps::slice(tensor, ranges)
}
fn bool_index_assign<const D1: usize, const D2: usize>(
fn bool_slice_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<bool, D1>,
indexes: [std::ops::Range<usize>; D2],
ranges: [std::ops::Range<usize>; D2],
value: TchTensor<bool, D1>,
) -> TchTensor<bool, D1> {
TchOps::index_assign(tensor, indexes, value)
TchOps::slice_assign(tensor, ranges, value)
}
fn bool_cat<const D: usize>(

View File

@ -57,18 +57,18 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
TchTensor::new(tensor)
}
fn int_index<const D1: usize, const D2: usize>(
fn int_slice<const D1: usize, const D2: usize>(
tensor: TchTensor<i64, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> TchTensor<i64, D1> {
TchOps::index(tensor, indexes)
TchOps::slice(tensor, ranges)
}
fn int_index_assign<const D1: usize, const D2: usize>(
fn int_slice_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<i64, D1>,
indexes: [std::ops::Range<usize>; D2],
ranges: [std::ops::Range<usize>; D2],
value: TchTensor<i64, D1>,
) -> TchTensor<i64, D1> {
TchOps::index_assign(tensor, indexes, value)
TchOps::slice_assign(tensor, ranges, value)
}
fn int_cat<const D: usize>(tensors: Vec<TchTensor<i64, D>>, dim: usize) -> TchTensor<i64, D> {
@ -234,35 +234,35 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
fn int_gather<const D: usize>(
dim: usize,
tensor: TchTensor<i64, D>,
indexes: TchTensor<i64, D>,
indices: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
TchOps::gather(dim, tensor, indexes)
TchOps::gather(dim, tensor, indices)
}
fn int_scatter<const D: usize>(
dim: usize,
tensor: TchTensor<i64, D>,
indexes: TchTensor<i64, D>,
indices: TchTensor<i64, D>,
value: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
TchOps::scatter(dim, tensor, indexes, value)
TchOps::scatter(dim, tensor, indices, value)
}
fn int_index_select_dim<const D: usize>(
fn int_select<const D: usize>(
tensor: TchTensor<i64, D>,
dim: usize,
indexes: TchTensor<i64, 1>,
indices: TchTensor<i64, 1>,
) -> TchTensor<i64, D> {
TchOps::index_select_dim(tensor, dim, indexes)
TchOps::index_select_dim(tensor, dim, indices)
}
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<i64, D1>,
fn int_select_assign<const D: usize>(
tensor: TchTensor<i64, D>,
dim: usize,
indexes: TchTensor<i64, 1>,
value: TchTensor<i64, D2>,
) -> TchTensor<i64, D1> {
TchOps::index_select_dim_assign(tensor, dim, indexes, value)
indices: TchTensor<i64, 1>,
value: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
TchOps::select_assign(tensor, dim, indices, value)
}
fn int_mask_where<const D: usize>(
@ -301,21 +301,21 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
TchOps::max_dim(tensor, dim)
}
fn int_max_dim_with_indexes<const D: usize>(
fn int_max_dim_with_indices<const D: usize>(
tensor: TchTensor<i64, D>,
dim: usize,
) -> (TchTensor<i64, D>, TchTensor<i64, D>) {
TchOps::max_dim_with_indexes(tensor, dim)
TchOps::max_dim_with_indices(tensor, dim)
}
fn int_min_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::min_dim(tensor, dim)
}
fn int_min_dim_with_indexes<const D: usize>(
fn int_min_dim_with_indices<const D: usize>(
tensor: TchTensor<i64, D>,
dim: usize,
) -> (TchTensor<i64, D>, TchTensor<i64, D>) {
TchOps::min_dim_with_indexes(tensor, dim)
TchOps::min_dim_with_indices(tensor, dim)
}
}

View File

@ -1,11 +1,11 @@
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps,
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
};
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
fn embedding(weights: TchTensor<E, 2>, indexes: TchTensor<i64, 2>) -> TchTensor<E, 3> {
let tensor = tch::Tensor::embedding(&weights.tensor, &indexes.tensor, -1, false, false);
fn embedding(weights: TchTensor<E, 2>, indices: TchTensor<i64, 2>) -> TchTensor<E, 3> {
let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
TchTensor::new(tensor)
}
@ -13,12 +13,12 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
fn embedding_backward(
weights: TchTensor<E, 2>,
output: TchTensor<E, 3>,
indexes: TchTensor<i64, 2>,
indices: TchTensor<i64, 2>,
) -> TchTensor<E, 2> {
let [n_embedding, _d_model] = weights.shape().dims;
let tensor = tch::Tensor::embedding_backward(
&output.tensor,
&indexes.tensor,
&indices.tensor,
n_embedding as i64,
-1,
false,
@ -181,13 +181,13 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
TchTensor::new(tensor)
}
fn max_pool2d_with_indexes(
fn max_pool2d_with_indices(
x: TchTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<TchBackend<E>> {
let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices(
) -> MaxPool2dWithIndices<TchBackend<E>> {
let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
&x.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
@ -196,16 +196,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
false,
);
MaxPool2dWithIndexes::new(TchTensor::new(tensor), TchTensor::new(indexes))
MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
}
fn max_pool2d_with_indexes_backward(
fn max_pool2d_with_indices_backward(
x: TchTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: TchTensor<E, 4>,
indexes: TchTensor<i64, 4>,
indices: TchTensor<i64, 4>,
) -> MaxPool2dBackward<TchBackend<E>> {
let grad = tch::Tensor::max_pool2d_with_indices_backward(
&x.tensor,
@ -215,7 +215,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
[padding[0] as i64, padding[1] as i64],
[1, 1],
false,
&indexes.tensor,
&indices.tensor,
);
MaxPool2dBackward::new(TchTensor::new(grad))

View File

@ -182,50 +182,50 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
fn gather<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
indices: TchTensor<i64, D>,
) -> TchTensor<E, D> {
TchOps::gather(dim, tensor, indexes)
TchOps::gather(dim, tensor, indices)
}
fn scatter<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
indices: TchTensor<i64, D>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
TchOps::scatter(dim, tensor, indexes, value)
TchOps::scatter(dim, tensor, indices, value)
}
fn index_select<const D: usize>(
fn select<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indexes: TchTensor<i64, 1>,
indices: TchTensor<i64, 1>,
) -> TchTensor<E, D> {
TchOps::index_select_dim(tensor, dim, indexes)
TchOps::index_select_dim(tensor, dim, indices)
}
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
fn select_assign<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indexes: TchTensor<i64, 1>,
value: TchTensor<E, D2>,
) -> TchTensor<E, D1> {
TchOps::index_select_dim_assign(tensor, dim, indexes, value)
indices: TchTensor<i64, 1>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
TchOps::select_assign(tensor, dim, indices, value)
}
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> TchTensor<E, D1> {
TchOps::index(tensor, indexes)
TchOps::slice(tensor, ranges)
}
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: TchTensor<E, D1>,
) -> <TchBackend<E> as Backend>::TensorPrimitive<D1> {
TchOps::index_assign(tensor, indexes, value)
TchOps::slice_assign(tensor, ranges, value)
}
fn mask_where<const D: usize>(
@ -339,22 +339,22 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
TchOps::max_dim(tensor, dim)
}
fn max_dim_with_indexes<const D: usize>(
fn max_dim_with_indices<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
TchOps::max_dim_with_indexes(tensor, dim)
TchOps::max_dim_with_indices(tensor, dim)
}
fn min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::min_dim(tensor, dim)
}
fn min_dim_with_indexes<const D: usize>(
fn min_dim_with_indices<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
TchOps::min_dim_with_indexes(tensor, dim)
TchOps::min_dim_with_indices(tensor, dim)
}
fn exp<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {

View File

@ -268,11 +268,11 @@ mod tests {
}
#[test]
fn should_not_update_inplace_after_index() {
fn should_not_update_inplace_after_slice() {
let tensor_1 = Tensor::<TchBackend<f32>, 1>::from_floats([4.0, 4.0]);
let tensor_2 = tensor_1.clone();
let tensor_3 = tensor_2.index([0..2]).add_scalar(2.0);
let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0);
assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value);
}
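A companion sketch in the same spirit, checking that the renamed `slice_assign` also returns a new tensor instead of mutating the original through a shared handle; the test name and values are illustrative:

```rust
#[test]
fn should_not_update_inplace_after_slice_assign() {
    let tensor_1 = Tensor::<TchBackend<f32>, 1>::from_floats([4.0, 4.0]);
    let tensor_2 = tensor_1.clone();

    let values = Tensor::<TchBackend<f32>, 1>::from_floats([2.0, 2.0]);
    let _tensor_3 = tensor_2.slice_assign([0..2], values);

    assert_eq!(tensor_1.to_data().value, vec![4.0, 4.0]);
}
```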

View File

@ -211,18 +211,18 @@ where
///
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]));
/// let tensor_indexed = tensor.index([0..1, 0..3, 1..2]);
/// println!("{:?}", tensor_indexed.shape());
/// // Shape { dims: [1, 3, 2] }
/// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]);
/// println!("{:?}", tensor_slices.dims()); // [1, 3, 2]
///
/// }
/// ```
pub fn index<const D2: usize>(self, indexes: [core::ops::Range<usize>; D2]) -> Self {
check!(TensorCheck::index(&self.shape(), &indexes));
Self::new(K::index(self.primitive, indexes))
pub fn slice<const D2: usize>(self, ranges: [core::ops::Range<usize>; D2]) -> Self {
check!(TensorCheck::slice(&self.shape(), &ranges));
Self::new(K::slice(self.primitive, ranges))
}
/// Returns a copy of the current tensor with the selected elements changed to the new ones at
/// the selected indexes.
/// the selected indices.
///
/// # Panics
///
@ -238,22 +238,21 @@ where
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 3>::ones([2, 3, 3]);
/// let values = Tensor::<B, 3>::zeros([1, 1, 1]);
/// let tensor_indexed = tensor.index_assign([0..1, 0..1, 0..1], values);
/// println!("{:?}", tensor_indexed.shape());
/// // Shape { dims: [2, 3, 3] }
/// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
/// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
/// }
/// ```
pub fn index_assign<const D2: usize>(
pub fn slice_assign<const D2: usize>(
self,
indexes: [core::ops::Range<usize>; D2],
ranges: [core::ops::Range<usize>; D2],
values: Self,
) -> Self {
check!(TensorCheck::index_assign(
check!(TensorCheck::slice_assign(
&self.shape(),
&values.shape(),
&indexes
&ranges
));
Self::new(K::index_assign(self.primitive, indexes, values.primitive))
Self::new(K::slice_assign(self.primitive, ranges, values.primitive))
}
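Putting the two methods together, a short sketch in the same style as the doc examples above; the shapes follow those examples and the values are illustrative:

```rust
fn example<B: Backend>() {
    let tensor = Tensor::<B, 3>::ones([2, 3, 3]);
    let values = Tensor::<B, 3>::zeros([1, 3, 3]);

    // write a block with slice_assign, then read it back with slice
    let tensor = tensor.slice_assign([0..1, 0..3, 0..3], values);
    let block = tensor.slice([0..1, 0..3, 0..3]);
    println!("{:?}", block.dims()); // [1, 3, 3]
}
```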
/// Returns the device of the current tensor.
@ -358,7 +357,7 @@ where
multi_index[depth] = i;
let range: [core::ops::Range<usize>; D] =
core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
let elem = &self.clone().index(range).to_data().value[0];
let elem = &self.clone().slice(range).to_data().value[0];
acc.push_str(&format!("{elem:?}"));
}
} else {
@ -480,12 +479,12 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
shape: Shape<D2>,
) -> Self::Primitive<D2>;
/// Select tensor elements corresponding for the given indexes.
/// Select tensor elements corresponding to the given ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes of the elements to select.
/// * `ranges` - The ranges of the elements to select.
///
/// # Returns
///
@ -497,19 +496,19 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For selecting elements of a tensor, users should prefer the [Tensor::index](Tensor::index) function,
/// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
/// which is more high-level and designed for public use.
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
range: [Range<usize>; D2],
) -> Self::Primitive<D1>;
/// Assigns the given value to the tensor elements corresponding for the given indexes.
/// Assigns the given value to the tensor elements corresponding for the given ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes of the elements to select.
/// * `ranges` - The ranges of the elements to select.
/// * `value` - The value to assign.
///
/// # Returns
@ -522,11 +521,11 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For assigning values to elements of a tensor, users should prefer the [Tensor::index_assign](Tensor::index_assign) function,
/// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
/// which is more high-level and designed for public use.
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: Self::Primitive<D1>,
) -> Self::Primitive<D1>;
@ -712,19 +711,19 @@ impl<B: Backend> BasicOps<B> for Float {
B::reshape(tensor, shape)
}
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> Self::Primitive<D1> {
B::index(tensor, indexes)
B::slice(tensor, ranges)
}
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: Self::Primitive<D1>,
) -> Self::Primitive<D1> {
B::index_assign(tensor, indexes, value)
B::slice_assign(tensor, ranges, value)
}
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
@ -786,19 +785,19 @@ impl<B: Backend> BasicOps<B> for Int {
B::int_reshape(tensor, shape)
}
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> Self::Primitive<D1> {
B::int_index(tensor, indexes)
B::int_slice(tensor, ranges)
}
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: Self::Primitive<D1>,
) -> Self::Primitive<D1> {
B::int_index_assign(tensor, indexes, value)
B::int_slice_assign(tensor, ranges, value)
}
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
@ -860,19 +859,19 @@ impl<B: Backend> BasicOps<B> for Bool {
B::bool_reshape(tensor, shape)
}
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> Self::Primitive<D1> {
B::bool_index(tensor, indexes)
B::bool_slice(tensor, ranges)
}
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: Self::Primitive<D1>,
) -> Self::Primitive<D1> {
B::bool_index_assign(tensor, indexes, value)
B::bool_slice_assign(tensor, ranges, value)
}
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {

View File

@ -264,56 +264,56 @@ impl TensorCheck {
check
}
pub(crate) fn index<const D1: usize, const D2: usize>(
pub(crate) fn slice<const D1: usize, const D2: usize>(
shape: &Shape<D1>,
indexes: &[Range<usize>; D2],
ranges: &[Range<usize>; D2],
) -> Self {
let mut check = Self::Ok;
let n_dims_tensor = D1;
let n_dims_indexes = D2;
let n_dims_ranges = D2;
if n_dims_tensor < n_dims_indexes {
check = check.register("Index",
TensorError::new ("The provided indexes array has a higher number of dimensions than the current tensor.")
if n_dims_tensor < n_dims_ranges {
check = check.register("Slice",
TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.")
.details(
format!(
"The indexes array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {n_dims_tensor}, indexes array lenght {n_dims_indexes}."
"The ranges array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {n_dims_tensor}, ranges array length {n_dims_ranges}."
)));
}
for i in 0..usize::min(D1, D2) {
let d_tensor = shape.dims[i];
let index = indexes.get(i).unwrap();
let range = ranges.get(i).unwrap();
if index.end > d_tensor {
if range.end > d_tensor {
check = check.register(
"Index",
TensorError::new("The provided indexes array has a range that exceeds the current tensor size.")
"Slice",
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Tensor shape {:?}, provided indexes {:?}.",
index.start,
index.end,
Tensor shape {:?}, provided ranges {:?}.",
range.start,
range.end,
d_tensor,
i,
shape.dims,
indexes,
ranges,
)));
}
if index.start >= index.end {
if range.start >= range.end {
check = check.register(
"Index",
TensorError::new("The provided indexes array has a range where the start index is bigger or equal to its end.")
"Slice",
TensorError::new("The provided range array has a range where the start index is bigger or equal to its end.")
.details(format!(
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
Tensor shape {:?}, provided indexes {:?}.",
Tensor shape {:?}, provided ranges {:?}.",
i,
index.start,
index.end,
range.start,
range.end,
shape.dims,
indexes,
ranges,
)));
}
}
@ -321,75 +321,75 @@ impl TensorCheck {
check
}
pub(crate) fn index_assign<const D1: usize, const D2: usize>(
pub(crate) fn slice_assign<const D1: usize, const D2: usize>(
shape: &Shape<D1>,
shape_value: &Shape<D1>,
indexes: &[Range<usize>; D2],
ranges: &[Range<usize>; D2],
) -> Self {
let mut check = Self::Ok;
if D1 < D2 {
check = check.register("Index Assign",
TensorError::new ("The provided indexes array has a higher number of dimensions than the current tensor.")
check = check.register("Slice Assign",
TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.")
.details(
format!(
"The indexes array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {D1}, indexes array lenght {D2}."
"The ranges array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {D1}, ranges array length {D2}."
)));
}
for i in 0..usize::min(D1, D2) {
let d_tensor = shape.dims[i];
let d_tensor_value = shape_value.dims[i];
let index = indexes.get(i).unwrap();
let range = ranges.get(i).unwrap();
if index.end > d_tensor {
if range.end > d_tensor {
check = check.register(
"Index Assign",
TensorError::new("The provided indexes array has a range that exceeds the current tensor size.")
"Range Assign",
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
index.start,
index.end,
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
range.start,
range.end,
d_tensor,
i,
shape.dims,
shape_value.dims,
indexes,
ranges,
)));
}
if index.end - index.start != d_tensor_value {
if range.end - range.start != d_tensor_value {
check = check.register(
"Index Assign",
TensorError::new("The value tensor must match the amount of elements selected with the indexes array")
"Slice Assign",
TensorError::new("The value tensor must match the amount of elements selected with the ranges array")
.details(format!(
"The range ({}..{}) doesn't match the number of elements of the value tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
index.start,
index.end,
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
range.start,
range.end,
d_tensor_value,
i,
shape.dims,
shape_value.dims,
indexes,
ranges,
)));
}
if index.start >= index.end {
if range.start >= range.end {
check = check.register(
"Index Assign",
TensorError::new("The provided indexes array has a range where the start index is bigger or equal to its end.")
"Slice Assign",
TensorError::new("The provided ranges array has a range where the start index is bigger or equal to its end.")
.details(format!(
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
i,
index.start,
index.end,
range.start,
range.end,
shape.dims,
shape_value.dims,
indexes,
ranges,
)));
}
}
@ -400,31 +400,31 @@ impl TensorCheck {
pub(crate) fn gather<const D: usize>(
dim: usize,
shape: &Shape<D>,
shape_indexes: &Shape<D>,
shape_indices: &Shape<D>,
) -> Self {
Self::check_gather_scatter_indexes(Self::Ok, "Gather", dim, shape, shape_indexes)
Self::check_gather_scatter_indices(Self::Ok, "Gather", dim, shape, shape_indices)
}
pub(crate) fn scatter<const D: usize>(
dim: usize,
shape: &Shape<D>,
shape_indexes: &Shape<D>,
shape_indices: &Shape<D>,
shape_value: &Shape<D>,
) -> Self {
let ops = "Scatter";
let mut check =
Self::check_gather_scatter_indexes(Self::Ok, ops, dim, shape, shape_indexes);
Self::check_gather_scatter_indices(Self::Ok, ops, dim, shape, shape_indices);
if shape_indexes != shape_value {
if shape_indices != shape_value {
check = check.register(
ops,
TensorError::new(
"Indexes tensor shape should be the same as the value tensor shape."
"Indices tensor shape should be the same as the value tensor shape."
.to_string(),
)
.details(format!(
"The shape differs: {:?} != {:?}",
shape_indexes.dims, shape_value.dims
shape_indices.dims, shape_value.dims
)),
);
}
@ -432,15 +432,15 @@ impl TensorCheck {
check
}
pub(crate) fn index_select<const D: usize>(dim: usize) -> Self {
Self::check_index_select_basic::<D>(Self::Ok, "index_select", dim)
pub(crate) fn select<const D: usize>(dim: usize) -> Self {
Self::check_select_basic::<D>(Self::Ok, "select", dim)
}
pub(crate) fn index_select_assign<const D: usize>(dim: usize) -> Self {
Self::check_index_select_basic::<D>(Self::Ok, "index_select_assign", dim)
pub(crate) fn select_assign<const D: usize>(dim: usize) -> Self {
Self::check_select_basic::<D>(Self::Ok, "select_assign", dim)
}
fn check_index_select_basic<const D: usize>(mut check: Self, ops: &str, dim: usize) -> Self {
fn check_select_basic<const D: usize>(mut check: Self, ops: &str, dim: usize) -> Self {
if dim > D {
check = check.register(
ops,
@ -452,12 +452,12 @@ impl TensorCheck {
check
}
fn check_gather_scatter_indexes<const D: usize>(
fn check_gather_scatter_indices<const D: usize>(
mut check: Self,
ops: &str,
dim: usize,
shape: &Shape<D>,
shape_indexes: &Shape<D>,
shape_indices: &Shape<D>,
) -> Self {
if dim > D {
check = check.register(
@ -474,9 +474,9 @@ impl TensorCheck {
}
let tensor_dim_i = shape.dims[i];
let indexes_dim_i = shape_indexes.dims[i];
let indices_dim_i = shape_indices.dims[i];
if tensor_dim_i != indexes_dim_i {
if tensor_dim_i != indices_dim_i {
check = check.register(
ops,
TensorError::new(
@ -484,7 +484,7 @@ impl TensorCheck {
.to_string(),
)
.details(format!(
"The shape differs at dimension {i}: {tensor_dim_i} != {indexes_dim_i}"
"The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}"
)),
);
}
@ -669,7 +669,7 @@ mod tests {
#[test]
#[should_panic]
fn index_range_exceed_dimension() {
check!(TensorCheck::index(
check!(TensorCheck::slice(
&Shape::new([3, 5, 7]),
&[0..2, 0..4, 1..8]
));
@ -678,7 +678,7 @@ mod tests {
#[test]
#[should_panic]
fn index_range_exceed_number_of_dimensions() {
check!(TensorCheck::index(&Shape::new([3, 5]), &[0..1, 0..1, 0..1]));
check!(TensorCheck::slice(&Shape::new([3, 5]), &[0..1, 0..1, 0..1]));
}
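One more test sketch in the same style, covering the empty-range branch of the renamed check (`start >= end`); the test name is illustrative:

```rust
#[test]
#[should_panic]
fn slice_range_start_equal_to_end() {
    check!(TensorCheck::slice(&Shape::new([3, 5]), &[1..1, 0..2]));
}
```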
#[test]

View File

@ -142,7 +142,7 @@ where
let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
ranges[D - 1] = index..index + 1;
tensor.index_assign(ranges, Tensor::ones(Shape::new([1; D])))
tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D])))
}
/// Applies the transpose operation.

View File

@ -214,36 +214,36 @@ where
Self::new(K::mask_fill(self.primitive, mask, value.elem()))
}
/// Gather tensor elements corresponding to the given indexes from the specified dim.
/// Gather tensor elements corresponding to the given indices from the specified dim.
///
/// Example using a 3D tensor:
///
/// `output[i, j, k] = input[indexes[i, j, k], j, k]; // dim = 0`
/// `output[i, j, k] = input[i, indexes[i, j, k], k]; // dim = 1`
/// `output[i, j, k] = input[i, j, indexes[i, j, k]]; // dim = 2`
/// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
/// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
/// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
///
/// # Notes
///
/// The index tensor should have the same shape as the original tensor except for the dim
/// specified.
pub fn gather(self, dim: usize, indexes: Tensor<B, D, Int>) -> Self {
pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
check!(TensorCheck::gather::<D>(
dim,
&self.shape(),
&indexes.shape()
&indices.shape()
));
Self::new(K::gather(dim, self.primitive, indexes))
Self::new(K::gather(dim, self.primitive, indices))
}
/// Assign the gathered elements corresponding to the given indexes along the speficied dimension
/// Assign the gathered elements corresponding to the given indices along the specified dimension
/// from the value tensor to the original tensor using sum reduction.
///
/// Example using a 3D tensor:
///
/// `input[indexes[i, j, k], j, k] += values[i, j, k]; // dim = 0`
/// `input[i, indexes[i, j, k], k] += values[i, j, k]; // dim = 1`
/// `input[i, j, indexes[i, j, k]] += values[i, j, k]; // dim = 2`
/// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
/// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
/// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
///
/// # Notes
///
@ -251,49 +251,49 @@ where
/// dimension. The value and index tensors should have the same shape.
///
/// Other references to the input tensor will not be modified by this operation.
pub fn scatter(self, dim: usize, indexes: Tensor<B, D, Int>, values: Self) -> Self {
pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
check!(TensorCheck::scatter::<D>(
dim,
&self.shape(),
&indexes.shape(),
&indices.shape(),
&values.shape()
));
Self::new(K::scatter(dim, self.primitive, indexes, values.primitive))
Self::new(K::scatter(dim, self.primitive, indices, values.primitive))
}
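A worked sketch of the gather semantics documented above, for a hypothetical 2D case with `dim = 1`; shapes and values are illustrative only:

```rust
// input                 indices          output[i][j] = input[i][indices[i][j]]
// [[1.0, 2.0, 3.0],     [[2, 0],         [[3.0, 1.0],
//  [4.0, 5.0, 6.0]]      [1, 1]]          [5.0, 5.0]]
fn gather_columns<B: Backend>(input: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 2> {
    input.gather(1, indices)
}
```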
/// Select the tensor elements along the given dimension corresponding to the given indexes.
/// Select the tensor elements along the given dimension corresponding to the given indices.
///
/// Example using a 3D tensor:
///
/// `output[i, j, k] = input[indexes[i], j, k]; // dim = 0`
/// `output[i, j, k] = input[i, indexes[j], k]; // dim = 1`
/// `output[i, j, k] = input[i, j, indexes[k]]; // dim = 2`
pub fn index_select(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
check!(TensorCheck::index_select::<D>(dim));
Self::new(K::index_select(self.primitive, dim, indexes))
/// `output[i, j, k] = input[indices[i], j, k]; // dim = 0`
/// `output[i, j, k] = input[i, indices[j], k]; // dim = 1`
/// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2`
pub fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
check!(TensorCheck::select::<D>(dim));
Self::new(K::select(self.primitive, dim, indices))
}
/// Assign the selected elements along the given dimension corresponding to the given indexes
/// Assign the selected elements along the given dimension corresponding to the given indices
/// from the value tensor to the original tensor using sum reduction.
///
/// Example using a 3D tensor:
///
/// `input[indexes[i], j, k] += values[i, j, k]; // dim = 0`
/// `input[i, indexes[j], k] += values[i, j, k]; // dim = 1`
/// `input[i, j, indexes[k]] += values[i, j, k]; // dim = 2`
pub fn index_select_assign<const D2: usize>(
/// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
/// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
/// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
pub fn select_assign(
self,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Tensor<B, D2, K>,
indices: Tensor<B, 1, Int>,
values: Tensor<B, D, K>,
) -> Self {
check!(TensorCheck::index_select_assign::<D>(dim));
check!(TensorCheck::select_assign::<D>(dim));
Self::new(K::index_select_assign(
Self::new(K::select_assign(
self.primitive,
dim,
indexes,
indices,
values.primitive,
))
}
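A matching sketch for the selection semantics above: `select` picks whole slices along `dim` with a 1D index tensor, and `select_assign` sum-reduces the value slices back into those positions. Shapes are illustrative; note the value tensor now has the same rank as the target:

```rust
fn select_and_accumulate<B: Backend>(
    tensor: Tensor<B, 2>,     // e.g. shape [3, 4]
    ids: Tensor<B, 1, Int>,   // e.g. [2, 0] -> rows 2 and 0
    values: Tensor<B, 2>,     // shape [2, 4]
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let picked = tensor.clone().select(0, ids.clone()); // picked[i, :] = tensor[ids[i], :]
    let updated = tensor.select_assign(0, ids, values); // tensor[ids[i], :] += values[i, :]
    (picked, updated)
}
```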
@ -331,11 +331,11 @@ where
/// Find the maximum value along the given dimension.
///
/// Also returns the indexes.
pub fn max_dim_with_indexes(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
/// Also returns the indices.
pub fn max_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
check!(TensorCheck::aggregate_dim::<D>("Max", dim));
let (tensor, index) = K::max_dim_with_indexes(self.primitive, dim);
let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
let tensor = Tensor::new(tensor);
let index = Tensor::new(index);
@ -375,11 +375,11 @@ where
/// Find the minimum value along the given dimension.
///
/// Also returns the indexes.
pub fn min_dim_with_indexes(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
/// Also returns the indices.
pub fn min_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
check!(TensorCheck::aggregate_dim::<D>("Min", dim));
let (tensor, index) = K::min_dim_with_indexes(self.primitive, dim);
let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
let tensor = Tensor::new(tensor);
let index = Tensor::new(index);
@ -1008,7 +1008,7 @@ where
///
/// * `dim` - The axis along which to gather elements.
/// * `tensor` - The tensor to gather elements from.
/// * `indexes` - The indexes of the elements to gather.
/// * `indices` - The indices of the elements to gather.
///
/// # Returns
///
@ -1026,7 +1026,7 @@ where
fn gather<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
indices: Tensor<B, D, Int>,
) -> Self::Primitive<D>;
/// Scatters elements into a tensor along an axis.
@ -1035,14 +1035,14 @@ where
///
/// * `dim` - The axis along which to scatter elements.
/// * `tensor` - The tensor to scatter elements into.
/// * `indices` - The indexes of the elements to scatter.
/// * `indices` - The indices of the elements to scatter.
/// * `values` - The values to scatter into the tensor.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis,
/// except for the elements at the specified indexes, which are taken from the corresponding
/// except for the elements at the specified indices, which are taken from the corresponding
/// element of the values tensor.
///
/// # Remarks
@ -1056,17 +1056,17 @@ where
fn scatter<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
indices: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D>;
/// Select tensor elements along the given dimension corresponding for the given indexes.
/// Select tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select elements from.
/// * `dim` - The axis along which to select elements.
/// * `indexes` - The indexes of the elements to select.
/// * `indices` - The indices of the elements to select.
///
/// # Returns
///
@ -1080,28 +1080,28 @@ where
/// or use this function directly.
///
/// For selecting elements from a tensor along an axis, users should prefer the
/// [Tensor::index_select](Tensor::index_select) function, which is more high-level and designed for public use.
fn index_select<const D: usize>(
/// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use.
fn select<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
indices: Tensor<B, 1, Int>,
) -> Self::Primitive<D>;
/// Assign the selected elements along the given dimension corresponding to the given indexes
/// Assign the selected elements along the given dimension corresponding to the given indices
/// from the value tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to assign elements to.
/// * `dim` - The axis along which to assign elements.
/// * `indexes` - The indexes of the elements to assign.
/// * `indices` - The indices of the elements to assign.
/// * `values` - The values to assign to the tensor.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis,
/// except for the elements at the specified indexes, which are taken from the corresponding
/// except for the elements at the specified indices, which are taken from the corresponding
/// element of the values tensor.
///
/// # Remarks
@ -1111,20 +1111,20 @@ where
/// or use this function directly.
///
/// For assigning elements to a tensor along an axis, users should prefer the
/// [Tensor::index_select_assign](Tensor::index_select_assign) function, which is more high-level and designed for public use.
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
/// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use.
fn select_assign<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Self::Primitive<D2>,
) -> Self::Primitive<D1>;
indices: Tensor<B, 1, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D>;
/// Gets the indexes of the maximum elements of a tensor along an axis.
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `dim` - The axis along which to get the indexes of the maximum elements.
/// * `tensor` - The tensor to get the indexes of the maximum elements from.
/// * `dim` - The axis along which to get the indices of the maximum elements.
/// * `tensor` - The tensor to get the indices of the maximum elements from.
///
/// # Returns
///
@ -1137,16 +1137,16 @@ where
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the indexes of the maximum elements of a tensor along an axis, users should prefer the
/// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
/// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
fn argmax<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
/// Gets the indexes of the minimum elements of a tensor along an axis.
/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `dim` - The axis along which to get the indexes of the minimum elements.
/// * `tensor` - The tensor to get the indexes of the minimum elements from.
/// * `dim` - The axis along which to get the indices of the minimum elements.
/// * `tensor` - The tensor to get the indices of the minimum elements from.
///
/// # Returns
///
@ -1159,7 +1159,7 @@ where
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the indexes of the minimum elements of a tensor along an axis, users should prefer the
/// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
/// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
fn argmin<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
@ -1224,8 +1224,8 @@ where
/// or use this function directly.
///
/// For getting the maximum elements of a tensor along an axis, users should prefer the
/// [Tensor::max_dim_with_indexes](Tensor::max_dim_with_indexes) function, which is more high-level and designed for public use.
fn max_dim_with_indexes<const D: usize>(
/// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
fn max_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, B::IntTensorPrimitive<D>);
@ -1291,8 +1291,8 @@ where
/// or use this function directly.
///
/// For getting the minimum elements of a tensor along an axis, users should prefer the
/// [Tensor::min_dim_with_indexes](Tensor::min_dim_with_indexes) function, which is more high-level and designed for public use.
fn min_dim_with_indexes<const D: usize>(
/// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use.
fn min_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, B::IntTensorPrimitive<D>);
@ -1441,37 +1441,37 @@ impl<B: Backend> Numeric<B> for Int {
B::int_mask_fill(tensor, mask.primitive, value)
}
fn index_select<const D: usize>(
fn select<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
indices: Tensor<B, 1, Int>,
) -> Self::Primitive<D> {
B::int_index_select_dim(tensor, dim, indexes.primitive)
B::int_select(tensor, dim, indices.primitive)
}
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
fn select_assign<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Self::Primitive<D2>,
) -> Self::Primitive<D1> {
B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
indices: Tensor<B, 1, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::int_select_assign(tensor, dim, indices.primitive, values)
}
fn gather<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
indices: Tensor<B, D, Int>,
) -> Self::Primitive<D> {
B::int_gather(dim, tensor, indexes.primitive)
B::int_gather(dim, tensor, indices.primitive)
}
fn scatter<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
indices: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::int_scatter(dim, tensor, indexes.primitive, values)
B::int_scatter(dim, tensor, indices.primitive, values)
}
fn argmax<const D: usize>(
@ -1496,11 +1496,11 @@ impl<B: Backend> Numeric<B> for Int {
B::int_max_dim(tensor, dim)
}
fn max_dim_with_indexes<const D: usize>(
fn max_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
B::int_max_dim_with_indexes(tensor, dim)
B::int_max_dim_with_indices(tensor, dim)
}
fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
@ -1511,11 +1511,11 @@ impl<B: Backend> Numeric<B> for Int {
B::int_min_dim(tensor, dim)
}
fn min_dim_with_indexes<const D: usize>(
fn min_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
B::int_min_dim_with_indexes(tensor, dim)
B::int_min_dim_with_indices(tensor, dim)
}
}
@ -1662,38 +1662,38 @@ impl<B: Backend> Numeric<B> for Float {
B::mask_fill(tensor, mask.primitive, value)
}
fn index_select<const D: usize>(
fn select<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
indices: Tensor<B, 1, Int>,
) -> Self::Primitive<D> {
B::index_select(tensor, dim, indexes.primitive)
B::select(tensor, dim, indices.primitive)
}
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
fn select_assign<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Self::Primitive<D2>,
) -> Self::Primitive<D1> {
B::index_select_assign(tensor, dim, indexes.primitive, values)
indices: Tensor<B, 1, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::select_assign(tensor, dim, indices.primitive, values)
}
fn gather<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
indices: Tensor<B, D, Int>,
) -> Self::Primitive<D> {
B::gather(dim, tensor, indexes.primitive)
B::gather(dim, tensor, indices.primitive)
}
fn scatter<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
indices: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::scatter(dim, tensor, indexes.primitive, values)
B::scatter(dim, tensor, indices.primitive, values)
}
fn argmax<const D: usize>(
@ -1718,11 +1718,11 @@ impl<B: Backend> Numeric<B> for Float {
B::max_dim(tensor, dim)
}
fn max_dim_with_indexes<const D: usize>(
fn max_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
B::max_dim_with_indexes(tensor, dim)
B::max_dim_with_indices(tensor, dim)
}
fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
@ -1733,11 +1733,11 @@ impl<B: Backend> Numeric<B> for Float {
B::min_dim(tensor, dim)
}
fn min_dim_with_indexes<const D: usize>(
fn min_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
B::min_dim_with_indexes(tensor, dim)
B::min_dim_with_indices(tensor, dim)
}
}

View File

@ -5,11 +5,11 @@ use crate::{
};
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
pub fn embedding<B>(weights: Tensor<B, 2>, indexes: Tensor<B, 2, Int>) -> Tensor<B, 3>
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(B::embedding(weights.primitive, indexes.primitive))
Tensor::new(B::embedding(weights.primitive, indices.primitive))
}
/// Applies a [1D convolution](crate::ops::ModuleOps::conv2d).
@ -123,8 +123,8 @@ where
Tensor::new(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
}
/// Applies a [2D max pooling with indexes](crate::ops::ModuleOps::max_pool2d_with_indexes).
pub fn max_pool2d_with_indexes<B>(
/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
pub fn max_pool2d_with_indices<B>(
x: Tensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
@ -133,7 +133,7 @@ pub fn max_pool2d_with_indexes<B>(
where
B: Backend,
{
let output = B::max_pool2d_with_indexes(x.primitive, kernel_size, stride, padding);
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
(Tensor::new(output.output), Tensor::new(output.indexes))
(Tensor::new(output.output), Tensor::new(output.indices))
}
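A minimal sketch of calling the renamed functional API; the import path and pooling parameters are assumptions for illustration, and the input is expected in NCHW layout:

```rust
use burn_tensor::{backend::Backend, module::max_pool2d_with_indices, Int, Tensor};

fn pool<B: Backend>(x: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) {
    // returns the pooled activations and the argmax positions
    max_pool2d_with_indices(x, [2, 2], [2, 2], [0, 0])
}
```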

View File

@ -114,35 +114,35 @@ pub trait BoolTensorOps<B: Backend> {
shape: Shape<D2>,
) -> B::BoolTensorPrimitive<D2>;
/// Gets the values from the tensor for the given indexes.
/// Gets the values from the tensor for the given ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes to get the values from.
/// * `ranges` - The ranges to get the values from.
///
/// # Returns
///
/// The tensor with the values for the given indexes.
fn bool_index<const D1: usize, const D2: usize>(
/// The tensor with the values for the given ranges.
fn bool_slice<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> B::BoolTensorPrimitive<D1>;
/// Sets the values in the tensor for the given indexes.
/// Sets the values in the tensor for the given ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes to set the values for.
/// * `ranges` - The ranges to set the values for.
/// * `value` - The values to set.
///
/// # Returns
///
/// The tensor with the values set for the given indexes.
fn bool_index_assign<const D1: usize, const D2: usize>(
/// The tensor with the values set for the given ranges.
fn bool_slice_assign<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: B::BoolTensorPrimitive<D1>,
) -> B::BoolTensorPrimitive<D1>;
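Because `Bool` implements the same `BasicOps` surface, the renamed slicing is available on mask tensors through the public API as well; a tiny hedged sketch (helper name illustrative):

```rust
fn mask_corner<B: Backend>(mask: Tensor<B, 2, Bool>) -> Tensor<B, 2, Bool> {
    // keep the top-left 1x2 block of the mask
    mask.slice([0..1, 0..2])
}
```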
@ -169,7 +169,7 @@ pub trait BoolTensorOps<B: Backend> {
shape.dims[dim] = times;
let mut i = 0;
let indexes_select_all = [0; D].map(|_| {
let ranges_select_all = [0; D].map(|_| {
let start = 0;
let end = shape.dims[i];
i += 1;
@ -178,9 +178,9 @@ pub trait BoolTensorOps<B: Backend> {
let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor));
for i in 0..times {
let mut indexes = indexes_select_all.clone();
indexes[dim] = i..i + 1;
tensor_output = Self::bool_index_assign(tensor_output, indexes, tensor.clone());
let mut ranges = ranges_select_all.clone();
ranges[dim] = i..i + 1;
tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone());
}
tensor_output
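The default `repeat` above builds its output by writing the source into each `i..i + 1` slot along `dim` with the renamed `slice_assign`. The same pattern at the public `Tensor` level, sketched for a rank-2 tensor whose first dimension has size 1 (the helper name, the default-device `zeros`, and the shapes are illustrative assumptions):

```rust
fn repeat_rows<B: Backend>(row: Tensor<B, 2>, times: usize) -> Tensor<B, 2> {
    let [_, cols] = row.dims(); // `row` is assumed to have shape [1, cols]
    let mut out = Tensor::<B, 2>::zeros([times, cols]);
    for i in 0..times {
        // write the single source row into slot i..i + 1 along dim 0
        out = out.slice_assign([i..i + 1, 0..cols], row.clone());
    }
    out
}
```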

View File

@ -110,9 +110,9 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The elements at the given indices.
fn int_index<const D1: usize, const D2: usize>(
fn int_slice<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
indices: [Range<usize>; D2],
) -> B::IntTensorPrimitive<D1>;
/// Sets the element at the given indices.
@ -125,7 +125,7 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the element at the given indices set.
fn int_index_assign<const D1: usize, const D2: usize>(
fn int_slice_assign<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
indices: [Range<usize>; D2],
value: B::IntTensorPrimitive<D1>,
@ -204,15 +204,15 @@ pub trait IntTensorOps<B: Backend> {
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes.
/// * `indices` - The indices.
///
/// # Returns
///
/// The tensor with the selected elements.
fn int_index_select_dim<const D: usize>(
fn int_select<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
indices: B::IntTensorPrimitive<1>,
) -> B::IntTensorPrimitive<D>;
/// Assign the selected elements along the given dimension corresponding to the given indices
@ -222,18 +222,18 @@ pub trait IntTensorOps<B: Backend> {
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes.
/// * `indices` - The indices.
/// * `value` - The value.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
fn int_select_assign<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
value: B::IntTensorPrimitive<D2>,
) -> B::IntTensorPrimitive<D1>;
indices: B::IntTensorPrimitive<1>,
value: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Repeats the tensor along the given dimension the given number of times.
///
@ -258,7 +258,7 @@ pub trait IntTensorOps<B: Backend> {
shape.dims[dim] = times;
let mut i = 0;
let indexes_select_all = [0; D].map(|_| {
let indices_select_all = [0; D].map(|_| {
let start = 0;
let end = shape.dims[i];
i += 1;
@ -267,9 +267,9 @@ pub trait IntTensorOps<B: Backend> {
let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor));
for i in 0..times {
let mut indexes = indexes_select_all.clone();
indexes[dim] = i..i + 1;
tensor_output = Self::int_index_assign(tensor_output, indexes, tensor.clone());
let mut indices = indices_select_all.clone();
indices[dim] = i..i + 1;
tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone());
}
tensor_output
@ -725,7 +725,7 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The maximum elements and corresponding indices along the dimension.
fn int_max_dim_with_indexes<const D: usize>(
fn int_max_dim_with_indices<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
@ -780,13 +780,13 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The minimum elements and corresponding indices along the dimension.
fn int_min_dim_with_indexes<const D: usize>(
fn int_min_dim_with_indices<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
let index = B::int_argmin(tensor.clone(), dim);
let values = B::int_gather(D - 1, tensor, index.clone());
let indices = B::int_argmin(tensor.clone(), dim);
let values = B::int_gather(D - 1, tensor, indices.clone());
(values, index)
(values, indices)
}
}

View File

@ -21,14 +21,14 @@ pub struct MaxPool2dBackward<B: Backend> {
pub x_grad: B::TensorPrimitive<4>,
}
/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indexes).
/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndexes<B: Backend> {
pub struct MaxPool2dWithIndices<B: Backend> {
/// The output tensor.
pub output: B::TensorPrimitive<4>,
/// The indexes tensor.
pub indexes: B::IntTensorPrimitive<4>,
/// The indices tensor.
pub indices: B::IntTensorPrimitive<4>,
}
/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
@ -86,20 +86,20 @@ pub trait ModuleOps<B: Backend> {
/// # Arguments
///
/// * `weights` - The embedding weights.
/// * `indexes` - The indexes tensor.
/// * `indices` - The indices tensor.
///
/// # Returns
///
/// The output tensor.
fn embedding(
weights: B::TensorPrimitive<2>,
indexes: B::IntTensorPrimitive<2>,
indices: B::IntTensorPrimitive<2>,
) -> B::TensorPrimitive<3> {
let [batch_size, seq_length] = B::int_shape(&indexes).dims;
let [batch_size, seq_length] = B::int_shape(&indices).dims;
let [_, d_model] = B::shape(&weights).dims;
let indexes = B::int_reshape(indexes, Shape::new([batch_size * seq_length]));
let output = B::index_select(weights, 0, indexes);
let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
let output = B::select(weights, 0, indices);
B::reshape(output, Shape::new([batch_size, seq_length, d_model]))
}
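As the default implementation shows, embedding is a row `select` on the weight matrix followed by a reshape. The core lookup, sketched at the public `Tensor` level with already-flattened indices (names are illustrative):

```rust
fn embedding_rows<B: Backend>(weights: Tensor<B, 2>, flat_indices: Tensor<B, 1, Int>) -> Tensor<B, 2> {
    // output[i, :] = weights[flat_indices[i], :]
    weights.select(0, flat_indices)
}
```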
@ -110,7 +110,7 @@ pub trait ModuleOps<B: Backend> {
///
/// * `weights` - The embedding weights.
/// * `output_grad` - The output gradient.
/// * `indexes` - The indexes tensor.
/// * `indices` - The indices tensor.
///
/// # Returns
///
@ -118,17 +118,17 @@ pub trait ModuleOps<B: Backend> {
fn embedding_backward(
weights: B::TensorPrimitive<2>,
output_grad: B::TensorPrimitive<3>,
indexes: B::IntTensorPrimitive<2>,
indices: B::IntTensorPrimitive<2>,
) -> B::TensorPrimitive<2> {
let [batch_size, seq_length] = B::int_shape(&indexes).dims;
let [batch_size, seq_length] = B::int_shape(&indices).dims;
let [n_embeddings, d_model] = B::shape(&weights).dims;
let device = B::device(&weights);
let indexes = B::int_reshape(indexes, Shape::new([batch_size * seq_length]));
let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
let output_grad = B::reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
let grad = B::zeros(Shape::new([n_embeddings, d_model]), &device);
B::index_select_assign(grad, 0, indexes, output_grad)
B::select_assign(grad, 0, indices, output_grad)
}
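The backward default mirrors the forward with the renamed `select_assign`: gradient rows are sum-reduced into a zeroed `[n_embeddings, d_model]` buffer at the looked-up positions, so repeated indices accumulate. A hedged sketch of that single step (names illustrative):

```rust
fn accumulate_embedding_grad<B: Backend>(
    grad: Tensor<B, 2>,          // zeros of shape [n_embeddings, d_model]
    flat_indices: Tensor<B, 1, Int>,
    output_grad: Tensor<B, 2>,   // shape [batch_size * seq_length, d_model]
) -> Tensor<B, 2> {
    // grad[flat_indices[i], :] += output_grad[i, :]
    grad.select_assign(0, flat_indices, output_grad)
}
```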
/// Two dimensional convolution.
@ -263,24 +263,24 @@ pub trait ModuleOps<B: Backend> {
padding: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Two dimensional max pooling with indexes.
/// Two dimensional max pooling with indices.
///
/// # Shapes
///
/// x: [batch_size, channels, height, width],
fn max_pool2d_with_indexes(
fn max_pool2d_with_indices(
x: B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<B>;
/// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indexes) operation.
fn max_pool2d_with_indexes_backward(
) -> MaxPool2dWithIndices<B>;
/// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
fn max_pool2d_with_indices_backward(
x: B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: B::TensorPrimitive<4>,
indexes: B::IntTensorPrimitive<4>,
indices: B::IntTensorPrimitive<4>,
) -> MaxPool2dBackward<B>;
}
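To show how the renamed pieces fit together, a hedged sketch wiring the indices recorded by the forward pass into the backward call at the backend-ops level; it assumes the backend's tensor primitives are `Clone` and uses illustrative pooling parameters:

```rust
fn max_pool2d_roundtrip<B: Backend>(
    x: B::TensorPrimitive<4>,
    output_grad: B::TensorPrimitive<4>,
) -> MaxPool2dBackward<B> {
    let forward = B::max_pool2d_with_indices(x.clone(), [2, 2], [2, 2], [0, 0]);
    // the backward pass consumes the argmax positions from the forward pass
    B::max_pool2d_with_indices_backward(x, [2, 2], [2, 2], [0, 0], output_grad, forward.indices)
}
```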

View File

@ -243,8 +243,8 @@ fn conv1d_weight_grad_groups<B: Backend>(
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv1d(
x,
grad,
@ -252,7 +252,7 @@ fn conv1d_weight_grad_groups<B: Backend>(
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::index_assign(
weight_grad = B::slice_assign(
weight_grad,
[start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size],
weight_grad_tmp,
@ -280,8 +280,8 @@ fn conv2d_weight_grad_groups<B: Backend>(
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv2d(
x,
grad,
@ -289,7 +289,7 @@ fn conv2d_weight_grad_groups<B: Backend>(
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::index_assign(
weight_grad = B::slice_assign(
weight_grad,
[
start_idx_co..end_idx_co,
@ -321,7 +321,7 @@ fn conv1d_weight_grad_no_groups<B: Backend>(
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != weight_shape {
weight_grad = B::index(
weight_grad = B::slice(
weight_grad,
[
0..weight_shape.dims[0],
@ -350,7 +350,7 @@ fn conv2d_weight_grad_no_groups<B: Backend>(
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != weight_shape {
weight_grad = B::index(
weight_grad = B::slice(
weight_grad,
[
0..weight_shape.dims[0],
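The weight-gradient helpers above use the renamed ops in the same way: the grouped variants `slice` one group's channels out of the swapped tensors, compute that group's partial gradient, and `slice_assign` it back into the full gradient, while the ungrouped variants `slice` the result down to the expected weight shape. A sketch of the slice / slice_assign round trip on plain tensors (the convolution and group bookkeeping are omitted; aliases as in the test files):

use burn_tensor::{Data, Tensor};

// Hypothetical sketch: carve rows out with `slice`, write a computed block back
// with `slice_assign`.
fn per_group_sketch() {
    let full = Tensor::<TestBackend, 2>::from_data([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]);
    let block = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]]);

    // Rows 2..4 play the role of one group's output channels.
    let full = full.slice_assign([2..4, 0..2], block);
    let read_back = full.slice([2..4]);
    assert_eq!(read_back.into_data(), Data::from([[1.0, 2.0], [3.0, 4.0]]));
}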

View File

@ -181,7 +181,7 @@ pub trait TensorOps<B: Backend> {
shape.dims[dim] = times;
let mut i = 0;
let indexes_select_all = [0; D].map(|_| {
let indices_select_all = [0; D].map(|_| {
let start = 0;
let end = shape.dims[i];
i += 1;
@ -190,9 +190,9 @@ pub trait TensorOps<B: Backend> {
let mut tensor_output = B::empty(shape, &B::device(&tensor));
for i in 0..times {
let mut indexes = indexes_select_all.clone();
indexes[dim] = i..i + 1;
tensor_output = B::index_assign(tensor_output, indexes, tensor.clone());
let mut indices = indices_select_all.clone();
indices[dim] = i..i + 1;
tensor_output = B::slice_assign(tensor_output, indices, tensor.clone());
}
tensor_output
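The default `repeat` above allocates an output whose repeated dimension is `times` long and copies the source into each unit-length slot with `slice_assign`; only the range of the repeated dimension changes between iterations. A sketch with concrete shapes (aliases as in the test files):

use burn_tensor::{Data, Tensor};

// Hypothetical sketch: repeat a [2, 1] tensor 3 times along dimension 1.
fn repeat_sketch() {
    let source = Tensor::<TestBackend, 2>::from_data([[1.0], [2.0]]);
    let mut output = Tensor::<TestBackend, 2>::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);

    for i in 0..3 {
        // Same ranges every time, except for the repeated dimension.
        output = output.slice_assign([0..2, i..i + 1], source.clone());
    }

    assert_eq!(
        output.into_data(),
        Data::from([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
    );
}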
@ -380,7 +380,7 @@ pub trait TensorOps<B: Backend> {
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor to gather from.
/// * `indexes` - The indexes to gather.
/// * `indices` - The indices to gather.
///
/// # Returns
///
@ -388,7 +388,7 @@ pub trait TensorOps<B: Backend> {
fn gather<const D: usize>(
dim: usize,
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
indices: B::IntTensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Scatter elements into a tensor.
@ -397,7 +397,7 @@ pub trait TensorOps<B: Backend> {
///
/// * `dim` - The dimension to scatter into.
/// * `tensor` - The tensor to scatter into.
/// * `indexes` - The indexes to scatter into.
/// * `indices` - The indices to scatter into.
/// * `value` - The value to scatter.
///
/// # Returns
@ -406,76 +406,76 @@ pub trait TensorOps<B: Backend> {
fn scatter<const D: usize>(
dim: usize,
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
indices: B::IntTensorPrimitive<D>,
value: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
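`gather` reads one element per entry of `indices` along `dim`, so the output takes the shape of the index tensor, while `scatter` is the additive write in the opposite direction. A sketch matching the semantics exercised by the gather/scatter tests later in this diff (aliases assumed as in those tests):

use burn_tensor::{Data, Int, Tensor};

// Hypothetical sketch: gather picks elements by index, scatter adds values back
// at the positions the indices name.
fn gather_scatter_sketch() {
    let tensor = Tensor::<TestBackend, 1>::from_data([10.0, 20.0, 30.0]);
    let indices = Tensor::<TestBackend, 1, Int>::from_data([2, 2, 0]);

    let picked = tensor.clone().gather(0, indices.clone());
    assert_eq!(picked.into_data(), Data::from([30.0, 30.0, 10.0]));

    let values = Tensor::<TestBackend, 1>::from_data([1.0, 1.0, 1.0]);
    let written = tensor.scatter(0, indices, values);
    assert_eq!(written.into_data(), Data::from([11.0, 20.0, 32.0]));
}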
/// Select tensor elements along the given dimension corresponding for the given indexes.
/// Select tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes to select.
/// * `indices` - The indices to select.
///
/// # Returns
///
/// The selected elements.
fn index_select<const D: usize>(
fn select<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
indices: B::IntTensorPrimitive<1>,
) -> B::TensorPrimitive<D>;
/// Assign the selected elements along the given dimension corresponding for the given indexes
/// Assign the selected elements along the given dimension at the given indices
/// to the given value.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes to select.
/// * `indices` - The indices to select.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
fn select_assign<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
value: B::TensorPrimitive<D2>,
) -> B::TensorPrimitive<D1>;
indices: B::IntTensorPrimitive<1>,
value: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
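Unlike `gather`, the renamed `select` takes a 1-D index tensor and copies whole slices along `dim`, leaving every other dimension untouched; `select_assign` adds value slices back at those positions, as in the embedding backward above. A sketch (aliases as in the test files):

use burn_tensor::{Data, Int, Tensor};

// Hypothetical sketch: three rows are selected, so the output shape is [3, 3].
fn select_sketch() {
    let tensor = Tensor::<TestBackend, 2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    let indices = Tensor::<TestBackend, 1, Int>::from_data([1, 0, 1]);

    let rows = tensor.select(0, indices);
    assert_eq!(
        rows.into_data(),
        Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
    );
}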
/// Select tensor elements corresponding for the given indexes.
/// Select tensor elements within the given ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `indexes` - The indexes to select.
/// * `ranges` - The ranges to select.
///
/// # Returns
///
/// The selected elements in a new tensor.
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> B::TensorPrimitive<D1>;
/// Assign the selected elements corresponding for the given indexes to the given value.
/// Assign the selected elements within the given ranges to the given value.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `indexes` - The indexes to select.
/// * `ranges` - The ranges to select.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: B::TensorPrimitive<D1>,
) -> B::TensorPrimitive<D1>;
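The renamed `slice` and `slice_assign` take plain `Range<usize>` arrays; the array may cover fewer dimensions than the tensor has, in which case the trailing dimensions are kept whole. A sketch (aliases as in the test files):

use burn_tensor::{Data, Tensor};

// Hypothetical sketch: one range on a 2-D tensor keeps the second dimension whole.
fn slice_sketch() {
    let tensor = Tensor::<TestBackend, 2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

    let first_row = tensor.clone().slice([0..1]);
    assert_eq!(first_row.into_data(), Data::from([[0.0, 1.0, 2.0]]));

    let patch = Tensor::<TestBackend, 2>::from_data([[9.0, 9.0]]);
    let updated = tensor.slice_assign([1..2, 1..3], patch);
    assert_eq!(
        updated.into_data(),
        Data::from([[0.0, 1.0, 2.0], [3.0, 9.0, 9.0]])
    );
}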
@ -872,7 +872,7 @@ pub trait TensorOps<B: Backend> {
dim: usize,
) -> B::TensorPrimitive<D>;
/// Gets the indexes of the maximum elements of a tensor along an axis.
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
@ -881,13 +881,13 @@ pub trait TensorOps<B: Backend> {
///
/// # Returns
///
/// A tensor with the indexes of the maximum elements of `tensor` along `dim`.
/// A tensor with the indices of the maximum elements of `tensor` along `dim`.
fn argmax<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Gets the indexes of the minimum elements of a tensor along an axis.
/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
@ -896,7 +896,7 @@ pub trait TensorOps<B: Backend> {
///
/// # Returns
///
/// A tensor with the indexes of the minimum elements of `tensor` along `dim`.
/// A tensor with the indices of the minimum elements of `tensor` along `dim`.
fn argmin<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
@ -934,7 +934,7 @@ pub trait TensorOps<B: Backend> {
B::gather(D - 1, tensor, index)
}
/// Gets the maximum elements of a tensor along an axis and their indexes.
/// Gets the maximum elements of a tensor along an axis and their indices.
///
/// # Arguments
///
@ -943,8 +943,8 @@ pub trait TensorOps<B: Backend> {
///
/// # Returns
///
/// A tuple with the maximum elements of `tensor` along `dim` and their indexes.
fn max_dim_with_indexes<const D: usize>(
/// A tuple with the maximum elements of `tensor` along `dim` and their indices.
fn max_dim_with_indices<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
@ -986,7 +986,7 @@ pub trait TensorOps<B: Backend> {
B::gather(D - 1, tensor, index)
}
/// Gets the minimum elements of a tensor along an axis and their indexes.
/// Gets the minimum elements of a tensor along an axis and their indices.
///
/// # Arguments
///
@ -995,8 +995,8 @@ pub trait TensorOps<B: Backend> {
///
/// # Returns
///
/// A tuple with the minimum elements of `tensor` along `dim` and their indexes.
fn min_dim_with_indexes<const D: usize>(
/// A tuple with the minimum elements of `tensor` along `dim` and their indices.
fn min_dim_with_indices<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {

View File

@ -37,9 +37,9 @@ macro_rules! testgen_all {
burn_tensor::testgen_log!();
burn_tensor::testgen_sqrt!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_index!();
burn_tensor::testgen_slice!();
burn_tensor::testgen_gather_scatter!();
burn_tensor::testgen_index_select!();
burn_tensor::testgen_select!();
burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_matmul!();

View File

@ -6,11 +6,11 @@ mod tests {
#[test]
fn test_embedding_forward() {
let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = Data::from([[0, 1], [1, 1]]);
let indices = Data::from([[0, 1], [1, 1]]);
let weights = Tensor::<TestBackend, 2>::from_data(weights);
let indexes = Tensor::<TestBackend, 2, Int>::from_data(indexes);
let indices = Tensor::<TestBackend, 2, Int>::from_data(indices);
let output = embedding(weights, indexes);
let output = embedding(weights, indices);
let expected = Data::from([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(module_max_pool2d)]
mod tests {
use super::*;
use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes};
use burn_tensor::module::{max_pool2d, max_pool2d_with_indices};
use burn_tensor::{Data, Tensor};
#[test]
@ -180,7 +180,7 @@ mod tests {
}
#[test]
fn test_max_pool2d_with_indexes() {
fn test_max_pool2d_with_indices() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 2;
@ -196,7 +196,7 @@ mod tests {
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]]);
let indexes = Data::<i64, 4>::from([[[
let indices = Data::<i64, 4>::from([[[
[0, 1, 1, 3, 3],
[4, 4, 1, 7, 7],
[4, 9, 9, 7, 7],
@ -211,7 +211,7 @@ mod tests {
[0.3358, 0.3358, 0.8293, 0.8293, 0.0990],
]]]);
let (output, output_indexes) = max_pool2d_with_indexes(
let (output, output_indices) = max_pool2d_with_indices(
x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
@ -219,7 +219,7 @@ mod tests {
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indexes.value, output_indexes.into_data().value);
assert_eq!(indices.value, output_indices.into_data().value);
}
#[test]
@ -240,7 +240,7 @@ mod tests {
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
]]]);
let indexes = Data::<i64, 4>::from([[[
let indices = Data::<i64, 4>::from([[[
[5, 7, 3],
[5, 7, 3],
[5, 16, 3],
@ -256,7 +256,7 @@ mod tests {
[0.4384, 0.9963, 0.688],
[0.4384, 0.9963, 0.688],
]]]);
let (output, output_indexes) = max_pool2d_with_indexes(
let (output, output_indices) = max_pool2d_with_indices(
x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
@ -264,6 +264,6 @@ mod tests {
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indexes.value, output_indexes.into_data().value);
assert_eq!(indices.value, output_indices.into_data().value);
}
}

View File

@ -6,9 +6,9 @@ mod tests {
#[test]
fn should_gather_1d_dim0() {
let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]);
let indexes = TestTensorInt::from_ints([1, 1, 0, 1, 2]);
let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]);
let output = tensor.gather(0, indexes);
let output = tensor.gather(0, indices);
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
@ -16,9 +16,9 @@ mod tests {
#[test]
fn should_gather_2d_dim0() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]);
let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]);
let output = tensor.gather(0, indexes);
let output = tensor.gather(0, indices);
assert_eq!(
output.into_data(),
@ -29,9 +29,9 @@ mod tests {
#[test]
fn should_gather_2d_dim1() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]);
let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]);
let output = tensor.gather(1, indexes);
let output = tensor.gather(1, indices);
assert_eq!(
output.into_data(),
@ -45,9 +45,9 @@ mod tests {
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
let output = tensor.gather(1, indexes);
let output = tensor.gather(1, indices);
assert_eq!(
output.into_data(),
@ -61,9 +61,9 @@ mod tests {
#[test]
fn should_gather_2d_only_1dim() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]);
let indices = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]);
let output = tensor.gather(1, indexes);
let output = tensor.gather(1, indices);
assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
}
@ -72,9 +72,9 @@ mod tests {
fn should_scatter_1d() {
let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]);
let values = TestTensor::from_floats([5.0, 4.0, 3.0]);
let indexes = TestTensorInt::from_ints([1, 0, 2]);
let indices = TestTensorInt::from_ints([1, 0, 2]);
let output = tensor.scatter(0, indexes, values);
let output = tensor.scatter(0, indices, values);
assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
}
@ -83,9 +83,9 @@ mod tests {
fn should_scatter_2d_dim0() {
let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let indexes = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]);
let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]);
let output = tensor.scatter(0, indexes, values);
let output = tensor.scatter(0, indices, values);
assert_eq!(
output.into_data(),
@ -97,9 +97,9 @@ mod tests {
fn should_scatter_2d_dim1() {
let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let indexes = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]);
let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]);
let output = tensor.scatter(1, indexes, values);
let output = tensor.scatter(1, indices, values);
assert_eq!(
output.into_data(),
@ -117,9 +117,9 @@ mod tests {
[[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],
[[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
]);
let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
let output = tensor.scatter(1, indexes, values);
let output = tensor.scatter(1, indices, values);
assert_eq!(
output.into_data(),

View File

@ -14,10 +14,10 @@ mod tests {
}
#[test]
fn test_max_dim_with_indexes_2d() {
fn test_max_dim_with_indices_2d() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let (output_actual, index_actual) = tensor.max_dim_with_indexes(1);
let (output_actual, index_actual) = tensor.max_dim_with_indices(1);
let output_expected = Data::from([[2.], [5.]]);
let index_expected = Data::from([[2], [2]]);
@ -37,10 +37,10 @@ mod tests {
}
#[test]
fn test_min_dim_with_indexes_2d() {
fn test_min_dim_with_indices_2d() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let (output_actual, index_actual) = tensor.min_dim_with_indexes(1);
let (output_actual, index_actual) = tensor.min_dim_with_indices(1);
let output_expected = Data::from([[0.], [3.]]);
let index_expected = Data::from([[0], [0]]);

View File

@ -9,8 +9,6 @@ mod erf;
mod exp;
mod flatten;
mod gather_scatter;
mod index;
mod index_select;
mod log;
mod log1p;
mod map_comparison;
@ -22,7 +20,9 @@ mod neg;
mod powf;
mod repeat;
mod reshape;
mod select;
mod sin;
mod slice;
mod sqrt;
mod squeeze;
mod sub;

View File

@ -1,4 +1,4 @@
#[burn_tensor_testgen::testgen(index_select)]
#[burn_tensor_testgen::testgen(select)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
@ -6,9 +6,9 @@ mod tests {
#[test]
fn should_select_1d() {
let tensor = TestTensor::from_data([0.0, 1.0, 2.0]);
let indexes = TestTensorInt::from_data([1, 1, 0, 1, 2]);
let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]);
let output = tensor.index_select(0, indexes);
let output = tensor.select(0, indices);
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
@ -16,9 +16,9 @@ mod tests {
#[test]
fn should_select_2d_dim0_same_num_dim() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_data(([1, 0]));
let indices = TestTensorInt::from_data(([1, 0]));
let output = tensor.index_select(0, indexes);
let output = tensor.select(0, indices);
assert_eq!(
output.into_data(),
@ -29,9 +29,9 @@ mod tests {
#[test]
fn should_select_2d_dim0_more_num_dim() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_data([1, 0, 1, 1]);
let indices = TestTensorInt::from_data([1, 0, 1, 1]);
let output = tensor.index_select(0, indexes);
let output = tensor.select(0, indices);
assert_eq!(
output.into_data(),
@ -47,9 +47,9 @@ mod tests {
#[test]
fn should_select_2d_dim1() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_data([1, 1, 0, 1, 2]);
let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]);
let output = tensor.index_select(1, indexes);
let output = tensor.select(1, indices);
assert_eq!(
output.into_data(),
@ -61,9 +61,9 @@ mod tests {
fn should_select_assign_1d() {
let tensor = TestTensor::from_data([0.0, 1.0, 2.0]);
let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0]);
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select_assign(0, indexes, values);
let output = tensor.select_assign(0, indices, values);
assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
}
@ -72,9 +72,9 @@ mod tests {
fn should_select_assign_2d_dim0() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let indexes = TestTensorInt::from_data(Data::from([1, 0]));
let indices = TestTensorInt::from_data(Data::from([1, 0]));
let output = tensor.index_select_assign(0, indexes, values);
let output = tensor.select_assign(0, indices, values);
assert_eq!(
output.into_data(),
@ -86,9 +86,9 @@ mod tests {
fn should_select_assign_2d_dim1() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let indexes = TestTensorInt::from_data(Data::from([1, 0, 2]));
let indices = TestTensorInt::from_data(Data::from([1, 0, 2]));
let output = tensor.index_select_assign(1, indexes, values);
let output = tensor.select_assign(1, indices, values);
assert_eq!(
output.into_data(),

View File

@ -1,94 +1,94 @@
#[burn_tensor_testgen::testgen(index)]
#[burn_tensor_testgen::testgen(slice)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_support_full_indexing_1d() {
fn should_support_full_slicing_1d() {
let data = Data::from([0.0, 1.0, 2.0]);
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
let data_actual = tensor.index([0..3]).into_data();
let data_actual = tensor.slice([0..3]).into_data();
assert_eq!(data, data_actual);
}
#[test]
fn should_support_partial_indexing_1d() {
fn should_support_partial_slicing_1d() {
let data = Data::from([0.0, 1.0, 2.0]);
let tensor = Tensor::<TestBackend, 1>::from_data(data);
let data_actual = tensor.index([1..3]).into_data();
let data_actual = tensor.slice([1..3]).into_data();
let data_expected = Data::from([1.0, 2.0]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_full_indexing_2d() {
fn should_support_full_slicing_2d() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data.clone());
let data_actual_1 = tensor.clone().index([0..2]).into_data();
let data_actual_2 = tensor.index([0..2, 0..3]).into_data();
let data_actual_1 = tensor.clone().slice([0..2]).into_data();
let data_actual_2 = tensor.slice([0..2, 0..3]).into_data();
assert_eq!(data, data_actual_1);
assert_eq!(data, data_actual_2);
}
#[test]
fn should_support_partial_indexing_2d() {
fn should_support_partial_slicing_2d() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.index([0..2, 0..2]).into_data();
let data_actual = tensor.slice([0..2, 0..2]).into_data();
let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_partial_indexing_3d() {
fn should_support_partial_slicing_3d() {
let tensor = TestTensor::from_floats([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let data_actual = tensor.index([1..2, 1..2, 0..2]).into_data();
let data_actual = tensor.slice([1..2, 1..2, 0..2]).into_data();
let data_expected = Data::from([[[9.0, 10.0]]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_partial_indexing_3d_non_continuous() {
fn should_support_partial_slicing_3d_non_continuous() {
let tensor = TestTensor::from_floats([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let data_actual = tensor.transpose().index([1..2, 1..2, 0..2]).into_data();
let data_actual = tensor.transpose().slice([1..2, 1..2, 0..2]).into_data();
let data_expected = Data::from([[[7.0, 10.0]]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_indexe_assign_1d() {
fn should_support_slice_assign_1d() {
let data = Data::from([0.0, 1.0, 2.0]);
let data_assigned = Data::from([10.0, 5.0]);
let tensor = Tensor::<TestBackend, 1>::from_data(data);
let tensor_assigned = Tensor::<TestBackend, 1>::from_data(data_assigned);
let data_actual = tensor.index_assign([0..2], tensor_assigned).into_data();
let data_actual = tensor.slice_assign([0..2], tensor_assigned).into_data();
let data_expected = Data::from([10.0, 5.0, 2.0]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_indexe_assign_2d() {
fn should_support_slice_assign_2d() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let data_assigned = Data::from([[10.0, 5.0]]);
@ -96,7 +96,7 @@ mod tests {
let tensor_assigned = Tensor::<TestBackend, 2>::from_data(data_assigned);
let data_actual = tensor
.index_assign([1..2, 0..2], tensor_assigned)
.slice_assign([1..2, 0..2], tensor_assigned)
.into_data();
let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);

View File

@ -52,4 +52,4 @@ harness = false
[[bench]]
name = "matmul"
harness = false
harness = false

View File

@ -35,6 +35,41 @@ macro_rules! kernel_wgsl {
};
}
kernel_wgsl!(ContinuousRaw, "../template/continuous.wgsl");
pub(crate) fn into_continuous<E: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
if tensor.is_continuous() {
return tensor;
}
let buffer = tensor
.context
.create_buffer(tensor.shape.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);
let info = build_info(&[&tensor, &output]);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<ContinuousRaw, E, i32, 256, 1, 1>>();
tensor.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[&tensor.buffer, &output.buffer, &info_buffer],
);
output
}
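`into_continuous` is now a standalone kernel entry point: when a tensor's strides are not in standard layout (after a transpose, for instance) it copies the data into a fresh buffer, dispatching one invocation per element in 1-D workgroups of 256. A sketch of the dispatch-size arithmetic used above (the kernel body itself lives in `continuous.wgsl`):

// Hypothetical sketch mirroring the f32::ceil(num_elements / 256) expression above.
fn workgroup_count(num_elements: usize) -> u32 {
    f32::ceil(num_elements as f32 / 256_f32) as u32
}

fn main() {
    // e.g. a 6 x 256 tensor needs ceil(1536 / 256) = 6 workgroups along x.
    assert_eq!(workgroup_count(6 * 256), 6);
}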
/// Generates kernel source code by replacing some information using templating.
pub struct KernelSettings<
K: StaticKernel,

View File

@ -0,0 +1,80 @@
use crate::{
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
tensor::WgpuTensor,
};
kernel_wgsl!(Gather, "../../template/index/gather.wgsl");
pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
dim: usize,
tensor: WgpuTensor<E, D>,
indices: WgpuTensor<I, D>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let shape_output = indices.shape.clone();
let num_elems = shape_output.num_elements();
let indices = kernel::into_continuous(indices);
let buffer = tensor
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
let mut info = build_info(&[&tensor, &output]);
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[
&tensor.buffer,
&indices.buffer,
&output.buffer,
&info_buffer,
],
);
output
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{backend::Backend, Distribution, Int, Tensor};
#[test]
fn gather_should_work_with_multiple_workgroups() {
TestBackend::seed(0);
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);
let indices = Tensor::<TestBackend, 1, Int>::from_data(
Tensor::<TestBackend, 1>::random([6 * 256], Distribution::Uniform(0., 256.))
.into_data()
.convert(),
)
.reshape([6, 256]);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let indices_ref =
Tensor::<ReferenceBackend, 2, Int>::from_data(indices.to_data().convert());
let actual = Tensor::<TestBackend, 2>::from_primitive(gather(
1,
tensor.into_primitive(),
indices.into_primitive(),
));
let expected = tensor_ref.gather(1, indices_ref);
expected
.into_data()
.assert_approx_eq(&actual.into_data(), 3);
}
}

View File

@ -0,0 +1,9 @@
mod gather;
mod scatter;
mod select;
mod slice;
pub use gather::*;
pub use scatter::*;
pub use select::*;
pub use slice::*;

View File

@ -0,0 +1,127 @@
use crate::{
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
tensor::WgpuTensor,
};
kernel_wgsl!(Scatter, "../../template/index/scatter.wgsl");
pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
dim: usize,
tensor: WgpuTensor<E, D>,
indices: WgpuTensor<I, D>,
value: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let indices = kernel::into_continuous(indices);
let tensor = kernel::into_continuous(tensor);
let value = kernel::into_continuous(value);
let tensor = match tensor.can_mut() {
true => tensor,
false => tensor.copy(),
};
let mut info = build_info(&[&tensor]);
let mut strides = [0; D];
let mut current = 1;
let mut num_elems_per_workgroup = 1;
tensor
.shape
.dims
.iter()
.enumerate()
.rev()
.filter(|(index, _val)| *index != dim)
.for_each(|(index, val)| {
strides[index] = current;
current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index];
});
strides
.into_iter()
.for_each(|stride| info.push(stride as u32));
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
kernel,
&[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
);
tensor
}
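The scatter kernel launches one invocation per line of the tensor rather than per element: the stride loop above skips `dim`, so `num_elems_per_workgroup` is the product of all other dimensions, and each invocation then walks `dim` inside the shader, which keeps different invocations on disjoint output elements. A sketch of the invocation count for a concrete shape:

// Hypothetical sketch of the invocation count used above: every dimension except
// `dim` contributes, since each invocation loops over `dim` itself.
fn scatter_invocations(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .enumerate()
        .filter(|(index, _)| *index != dim)
        .map(|(_, val)| *val)
        .product()
}

fn main() {
    // For shape [6, 256, 6] scattered along dim 1, 6 * 6 = 36 invocations each
    // walk the 256 entries of dimension 1.
    assert_eq!(scatter_invocations(&[6, 256, 6], 1), 36);
}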
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{backend::Backend, Distribution, Int, Tensor};
#[test]
fn scatter_should_work_with_multiple_workgroups_2d_dim0() {
same_as_reference(0, [256, 32]);
}
#[test]
fn scatter_should_work_with_multiple_workgroups_2d_dim1() {
same_as_reference(1, [32, 256]);
}
#[test]
fn scatter_should_work_with_multiple_workgroups_3d_dim0() {
same_as_reference(0, [256, 6, 6]);
}
#[test]
fn scatter_should_work_with_multiple_workgroups_3d_dim1() {
same_as_reference(1, [6, 256, 6]);
}
#[test]
fn scatter_should_work_with_multiple_workgroups_3d_dim2() {
same_as_reference(2, [6, 6, 256]);
}
fn same_as_reference<const D: usize>(dim: usize, shape: [usize; D]) {
TestBackend::seed(0);
let tensor = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);
let value = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);
let indices = Tensor::<TestBackend, 1, Int>::from_data(
Tensor::<TestBackend, 1>::random(
[shape.iter().product()],
Distribution::Uniform(0., shape[dim] as f32),
)
.into_data()
.convert(),
)
.reshape(shape);
let tensor_ref = Tensor::<ReferenceBackend, D>::from_data(tensor.to_data());
let value_ref = Tensor::<ReferenceBackend, D>::from_data(value.to_data());
let indices_ref =
Tensor::<ReferenceBackend, D, Int>::from_data(indices.to_data().convert());
let actual = Tensor::<TestBackend, D>::from_primitive(scatter(
dim,
tensor.into_primitive(),
indices.into_primitive(),
value.into_primitive(),
));
let expected = tensor_ref.scatter(dim, indices_ref, value_ref);
expected
.into_data()
.assert_approx_eq(&actual.into_data(), 3);
}
}

View File

@ -0,0 +1,186 @@
use crate::{
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
tensor::WgpuTensor,
};
kernel_wgsl!(IndexSelect, "../../template/index/select.wgsl");
kernel_wgsl!(
SelectAssignInplace,
"../../template/index/select_assign_inplace.wgsl"
);
pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
indices: WgpuTensor<I, 1>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let mut output_shape = tensor.shape.clone();
output_shape.dims[dim] = indices.shape.dims[0];
let num_elems = output_shape.num_elements();
let buffer = tensor
.context
.create_buffer(num_elems * std::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), output_shape, buffer);
let mut info = build_info(&[&tensor, &output]);
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[
&tensor.buffer,
&indices.buffer,
&output.buffer,
&info_buffer,
],
);
output
}
pub(crate) fn select_assign<E: WgpuElement, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
indices: WgpuTensor<I, 1>,
value: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let tensor = match tensor.can_mut() {
true => tensor,
false => tensor.copy(),
};
let mut info = build_info(&[&tensor, &value]);
let mut strides = [0; D];
let mut current = 1;
let mut num_elems_per_workgroup = 1;
tensor
.shape
.dims
.iter()
.enumerate()
.rev()
.filter(|(index, _val)| *index != dim)
.for_each(|(index, val)| {
strides[index] = current;
current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index];
});
strides
.into_iter()
.for_each(|stride| info.push(stride as u32));
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<SelectAssignInplace, E, I, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
kernel,
&[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
);
tensor
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{backend::Backend, Distribution, Int, Tensor};
#[test]
fn select_should_work_with_multiple_workgroups() {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);
let indices = Tensor::<TestBackend, 1, Int>::arange(0..100);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let indices_ref =
Tensor::<ReferenceBackend, 1, Int>::from_data(indices.to_data().convert());
let actual = select(tensor.into_primitive(), 1, indices.into_primitive());
let expected = tensor_ref.select(1, indices_ref);
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
3,
);
}
#[test]
fn select_assign_should_work_with_multiple_workgroups_2d_dim0() {
select_assign_same_as_ref(0, [256, 6]);
}
#[test]
fn select_assign_should_work_with_multiple_workgroups_2d_dim1() {
select_assign_same_as_ref(1, [6, 256]);
}
#[test]
fn select_assign_should_work_with_multiple_workgroups_3d_dim0() {
select_assign_same_as_ref(0, [256, 6, 6]);
}
#[test]
fn select_assign_should_work_with_multiple_workgroups_3d_dim1() {
select_assign_same_as_ref(1, [6, 256, 6]);
}
#[test]
fn select_assign_should_work_with_multiple_workgroups_3d_dim2() {
select_assign_same_as_ref(2, [6, 6, 256]);
}
fn select_assign_same_as_ref<const D: usize>(dim: usize, shape: [usize; D]) {
TestBackend::seed(0);
let tensor = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);
let value = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);
let indices = Tensor::<TestBackend, 1, Int>::from_data(
Tensor::<TestBackend, 1>::random(
[shape[dim]],
Distribution::Uniform(0., shape[dim] as f32),
)
.into_data()
.convert(),
);
let tensor_ref = Tensor::<ReferenceBackend, D>::from_data(tensor.to_data());
let value_ref = Tensor::<ReferenceBackend, D>::from_data(value.to_data());
let indices_ref =
Tensor::<ReferenceBackend, 1, Int>::from_data(indices.to_data().convert());
let actual = Tensor::<TestBackend, D>::from_primitive(select_assign(
tensor.into_primitive(),
dim,
indices.into_primitive(),
value.into_primitive(),
));
let expected = tensor_ref.select_assign(dim, indices_ref, value_ref);
expected
.into_data()
.assert_approx_eq(&actual.into_data(), 3);
}
}

View File

@ -0,0 +1,139 @@
use crate::{
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
use std::ops::Range;
kernel_wgsl!(IndexRaw, "../../template/index/slice.wgsl");
kernel_wgsl!(
IndexAssignInplaceRaw,
"../../template/index/slice_assign_inplace.wgsl"
);
pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
tensor: WgpuTensor<E, D1>,
indices: [Range<usize>; D2],
) -> WgpuTensor<E, D1> {
const WORKGROUP: usize = 32;
let mut dims = tensor.shape.dims;
for i in 0..D2 {
dims[i] = indices[i].end - indices[i].start;
}
let shape_output = Shape::new(dims);
let num_elems = shape_output.num_elements();
let buffer = tensor
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
let mut info = build_info(&[&tensor, &output]);
for i in 0..D1 {
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
info.push(start as u32);
}
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&tensor.buffer, &output.buffer, &info_buffer],
);
output
}
pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
tensor: WgpuTensor<E, D1>,
indices: [Range<usize>; D2],
value: WgpuTensor<E, D1>,
) -> WgpuTensor<E, D1> {
const WORKGROUP: usize = 32;
let tensor = match tensor.can_mut() {
true => tensor,
false => tensor.copy(),
};
let num_elems = tensor.shape.num_elements();
let mut info = build_info(&[&tensor, &value]);
for i in 0..D1 {
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
info.push(start as u32);
}
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor.context.compile_static::<KernelSettings<
IndexAssignInplaceRaw,
E,
i32,
WORKGROUP,
WORKGROUP,
1,
>>();
tensor.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&tensor.buffer, &value.buffer, &info_buffer],
);
tensor
}
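The slice kernel above derives its output shape from the range lengths and appends each dimension's start offset to the info buffer so the shader can translate output coordinates back into input coordinates; `slice_assign` pushes the same offsets but writes into the (possibly copied) input buffer in place. A sketch of the shape computation:

// Hypothetical sketch of the output-shape rule: sliced dimensions take the range
// length, the remaining dimensions keep their original size.
fn sliced_dims<const D1: usize, const D2: usize>(
    dims: [usize; D1],
    ranges: [std::ops::Range<usize>; D2],
) -> [usize; D1] {
    let mut out = dims;
    for i in 0..D2 {
        out[i] = ranges[i].end - ranges[i].start;
    }
    out
}

fn main() {
    // Matches the slice_assign test below: [6, 256] sliced with [3..5, 45..256] -> [2, 211].
    assert_eq!(sliced_dims([6, 256], [3..5, 45..256]), [2, 211]);
}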
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Tensor};
#[test]
fn slice_should_work_with_multiple_workgroups() {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);
let indices = [3..5, 45..256];
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let actual = slice(tensor.into_primitive(), indices.clone());
let expected = tensor_ref.slice(indices);
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
3,
);
}
#[test]
fn slice_assign_should_work_with_multiple_workgroups() {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);
let value = Tensor::<TestBackend, 2>::random([2, 211], Distribution::Standard);
let indices = [3..5, 45..256];
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let value_ref = Tensor::<ReferenceBackend, 2>::from_data(value.to_data());
let actual = slice_assign(
tensor.into_primitive(),
indices.clone(),
value.into_primitive(),
);
let expected = tensor_ref.slice_assign(indices, value_ref);
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
3,
);
}
}

View File

@ -3,6 +3,7 @@ mod binary_elemwise;
mod cat;
mod comparison;
mod comparison_elem;
mod index;
mod mask;
mod matmul;
mod reduction;
@ -20,5 +21,6 @@ pub use unary_scalar::*;
pub(crate) use cat::*;
pub(crate) use comparison::*;
pub(crate) use comparison_elem::*;
pub(crate) use index::*;
pub(crate) use mask::*;
pub(crate) use reduction::*;

View File

@ -59,14 +59,14 @@ mod tests {
burn_tensor::testgen_matmul!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_transpose!();
burn_tensor::testgen_index!();
burn_tensor::testgen_slice!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_arange!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_index_select!();
burn_tensor::testgen_select!();
burn_tensor::testgen_gather_scatter!();
type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
@ -88,11 +88,11 @@ mod tests {
burn_autodiff::testgen_ad_matmul!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_index!();
burn_autodiff::testgen_ad_slice!();
burn_autodiff::testgen_ad_aggregation!();
burn_autodiff::testgen_ad_cat!();
burn_autodiff::testgen_ad_mask!();
burn_autodiff::testgen_ad_index_select!();
burn_autodiff::testgen_ad_select!();
// Once all operations are implemented.
// burn_tensor::testgen_all!();

View File

@ -1,18 +1,16 @@
use crate::{
comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
context::WorkGroup,
element::WgpuElement,
kernel::{
build_info, cat, comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
mask_fill, mask_fill_inplace, mask_where, mask_where_inplace, KernelSettings,
self, cat, comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
mask_fill, mask_fill_inplace, mask_where, mask_where_inplace,
},
kernel_wgsl,
pool::get_context,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
use burn_tensor::{backend::Backend, Data, Shape};
use std::{marker::PhantomData, mem, ops::Range};
use std::{marker::PhantomData, mem};
pub type FloatElem<B> = <B as Backend>::FloatElem;
pub type Device<B> = <B as Backend>::Device;
@ -63,7 +61,7 @@ impl<G: GraphicsApi> BaseOps<G> {
}
pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Data<E, D> {
let tensor = Self::into_continuous(tensor);
let tensor = kernel::into_continuous(tensor);
let bytes = tensor.context.read_buffer(tensor.buffer);
let values = E::from_bytes(&bytes);
@ -109,230 +107,11 @@ impl<G: GraphicsApi> BaseOps<G> {
shape: Shape<D2>,
) -> WgpuTensor<E, D2> {
// TODO: Not force standard layout all the time (improve performance).
let tensor = Self::into_continuous(tensor);
let tensor = kernel::into_continuous(tensor);
WgpuTensor::new(tensor.context, shape, tensor.buffer)
}
pub fn into_continuous<E: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
if tensor.is_continuous() {
return tensor;
}
kernel_wgsl!(ContinuousRaw, "../template/continuous.wgsl");
let buffer = tensor
.context
.create_buffer(tensor.shape.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);
let info = build_info(&[&tensor, &output]);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<ContinuousRaw, E, i32, 256, 1, 1>>();
tensor.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[&tensor.buffer, &output.buffer, &info_buffer],
);
output
}
pub fn index<E: WgpuElement, const D1: usize, const D2: usize>(
tensor: WgpuTensor<E, D1>,
indexes: [Range<usize>; D2],
) -> WgpuTensor<E, D1> {
kernel_wgsl!(IndexRaw, "../template/index/index.wgsl");
let mut dims = tensor.shape.dims;
for i in 0..D2 {
dims[i] = indexes[i].end - indexes[i].start;
}
let shape_output = Shape::new(dims);
let buffer = tensor
.context
.create_buffer(shape_output.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
let mut info = build_info(&[&tensor, &output]);
for i in 0..D1 {
let start = indexes.get(i).map(|index| index.start).unwrap_or(0);
info.push(start as u32);
}
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexRaw, E, i32, 256, 1, 1>>();
tensor.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[&tensor.buffer, &output.buffer, &info_buffer],
);
output
}
pub fn index_assign<E: WgpuElement, const D1: usize, const D2: usize>(
tensor: WgpuTensor<E, D1>,
indexes: [Range<usize>; D2],
value: WgpuTensor<E, D1>,
) -> WgpuTensor<E, D1> {
kernel_wgsl!(
IndexAssignInplaceRaw,
"../template/index/index_assign_inplace.wgsl"
);
let tensor = match tensor.can_mut() {
true => tensor,
false => tensor.copy(),
};
let mut info = build_info(&[&tensor, &value]);
for i in 0..D1 {
let start = indexes.get(i).map(|index| index.start).unwrap_or(0);
info.push(start as u32);
}
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexAssignInplaceRaw, E, i32, 256, 1, 1>>();
tensor.context.execute(
WorkGroup::new(
f32::ceil(value.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[&tensor.buffer, &value.buffer, &info_buffer],
);
tensor
}
pub fn index_select<E: WgpuElement, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
indexes: WgpuTensor<I, 1>,
) -> WgpuTensor<E, D> {
kernel_wgsl!(IndexSelect, "../template/index/index_select.wgsl");
let mut output_shape = tensor.shape.clone();
output_shape.dims[dim] = indexes.shape.dims[0];
let buffer = tensor
.context
.create_buffer(std::mem::size_of::<E>() * output_shape.num_elements());
let output = WgpuTensor::new(tensor.context.clone(), output_shape, buffer);
let mut info = build_info(&[&tensor, &output]);
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexSelect, E, I, 256, 1, 1>>();
tensor.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[
&tensor.buffer,
&indexes.buffer,
&output.buffer,
&info_buffer,
],
);
output
}
pub fn index_select_assign<E: WgpuElement, I: WgpuElement, const D: usize, const D2: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
indexes: WgpuTensor<I, 1>,
values: WgpuTensor<E, D2>,
) -> WgpuTensor<E, D> {
kernel_wgsl!(
IndexSelectAssignInplace,
"../template/index/index_select_assign_inplace.wgsl"
);
let tensor = match tensor.can_mut() {
true => tensor,
false => tensor.copy(),
};
let mut shape = tensor.shape.clone();
shape.dims[dim] = values.shape.dims[dim];
let values = WgpuTensor::new(values.context, shape, values.buffer);
let mut info = build_info(&[&tensor, &values]);
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexSelectAssignInplace, E, I, 256, 1, 1>>();
let mut shape_tmp = values.shape;
shape_tmp.dims[dim] = 1; // Just one thread for the dim.
tensor.context.execute(
WorkGroup::new(
f32::ceil(shape_tmp.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[
&tensor.buffer,
&indexes.buffer,
&values.buffer,
&info_buffer,
],
);
tensor
}
pub fn equal<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
@ -501,107 +280,4 @@ impl<G: GraphicsApi> BaseOps<G> {
) -> WgpuTensor<E, D> {
cat(tensors, dim)
}
pub fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
dim: usize,
tensor: WgpuTensor<E, D>,
indexes: WgpuTensor<I, D>,
) -> WgpuTensor<E, D> {
kernel_wgsl!(Gather, "../template/gather.wgsl");
let shape_output = indexes.shape.clone();
let indexes = Self::into_continuous(indexes);
let buffer = tensor
.context
.create_buffer(shape_output.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
let mut info = build_info(&[&tensor, &output]);
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<Gather, E, i32, 256, 1, 1>>();
tensor.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[
&tensor.buffer,
&indexes.buffer,
&output.buffer,
&info_buffer,
],
);
output
}
pub fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
dim: usize,
tensor: WgpuTensor<E, D>,
indexes: WgpuTensor<I, D>,
value: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
kernel_wgsl!(Scatter, "../template/scatter.wgsl");
const WORKGROUP: usize = 256;
let indexes = Self::into_continuous(indexes);
let tensor = Self::into_continuous(tensor);
let value = Self::into_continuous(value);
let tensor = match tensor.can_mut() {
true => tensor,
false => tensor.copy(),
};
let mut info = build_info(&[&tensor]);
let mut strides = [0; D];
let mut current = 1;
let mut num_elems_per_workgroup = 1;
tensor
.shape
.dims
.iter()
.enumerate()
.rev()
.filter(|(index, _val)| *index != dim)
.for_each(|(index, val)| {
strides[index] = current;
current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index];
});
strides
.into_iter()
.for_each(|stride| info.push(stride as u32));
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<Scatter, E, i32, WORKGROUP, 1, 1>>();
tensor.context.execute(
WorkGroup::new(
f32::ceil(num_elems_per_workgroup as f32 / WORKGROUP as f32) as u32,
1,
1,
),
kernel,
&[&tensor.buffer, &indexes.buffer, &value.buffer, &info_buffer],
);
tensor
}
}

View File

@ -1,14 +1,12 @@
use std::ops::Range;
use burn_tensor::{ops::BoolTensorOps, ops::IntTensorOps, Data, Shape};
use super::{BaseOps, BoolTensor, Device, IntTensor};
use crate::{
element::{FloatElement, IntElement},
kernel,
tensor::WgpuTensor,
GraphicsApi, WgpuBackend,
};
use super::{BaseOps, BoolTensor, Device, IntTensor};
use burn_tensor::{ops::BoolTensorOps, ops::IntTensorOps, Data, Shape};
use std::ops::Range;
impl<G, F, I> BoolTensorOps<WgpuBackend<G, F, I>> for WgpuBackend<G, F, I>
where
@ -76,19 +74,19 @@ where
BaseOps::<G>::reshape(tensor, shape)
}
fn bool_index<const D1: usize, const D2: usize>(
fn bool_slice<const D1: usize, const D2: usize>(
tensor: BoolTensor<Self, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> BoolTensor<Self, D1> {
BaseOps::<G>::index(tensor, indexes)
kernel::slice(tensor, ranges)
}
fn bool_index_assign<const D1: usize, const D2: usize>(
fn bool_slice_assign<const D1: usize, const D2: usize>(
tensor: BoolTensor<Self, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: BoolTensor<Self, D1>,
) -> BoolTensor<Self, D1> {
BaseOps::<G>::index_assign(tensor, indexes, value)
kernel::slice_assign(tensor, ranges, value)
}
fn bool_cat<const D: usize>(

View File

@ -1,7 +1,7 @@
use super::numeric::NumericOps;
use super::{BaseOps, BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use crate::kernel::{
matmul_tiling_2d_default, unary_default, unary_inplace_default, unary_scalar_default,
self, matmul_tiling_2d_default, unary_default, unary_inplace_default, unary_scalar_default,
unary_scalar_inplace_default,
};
use crate::unary_scalar_inplace;
@ -138,8 +138,8 @@ where
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
let lhs = BaseOps::<G>::into_continuous(lhs);
let rhs = BaseOps::<G>::into_continuous(rhs);
let lhs = kernel::into_continuous(lhs);
let rhs = kernel::into_continuous(rhs);
matmul_tiling_2d_default::<FloatElem<Self>, D>(lhs, rhs)
}
@ -162,50 +162,50 @@ where
fn gather<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indexes: IntTensor<Self, D>,
indices: IntTensor<Self, D>,
) -> FloatTensor<Self, D> {
BaseOps::<G>::gather(dim, tensor, indexes)
kernel::gather(dim, tensor, indices)
}
fn scatter<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indexes: IntTensor<Self, D>,
indices: IntTensor<Self, D>,
value: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
BaseOps::<G>::scatter(dim, tensor, indexes, value)
kernel::scatter(dim, tensor, indices, value)
}
fn index_select<const D: usize>(
fn select<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indexes: IntTensor<Self, 1>,
indices: IntTensor<Self, 1>,
) -> FloatTensor<Self, D> {
BaseOps::<G>::index_select(tensor, dim, indexes)
kernel::select(tensor, dim, indices)
}
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
fn select_assign<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indexes: IntTensor<Self, 1>,
value: FloatTensor<Self, D2>,
) -> FloatTensor<Self, D1> {
BaseOps::<G>::index_select_assign(tensor, dim, indexes, value)
indices: IntTensor<Self, 1>,
value: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
kernel::select_assign(tensor, dim, indices, value)
}
fn index<const D1: usize, const D2: usize>(
fn slice<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> FloatTensor<Self, D1> {
BaseOps::<G>::index(tensor, indexes)
kernel::slice(tensor, ranges)
}
fn index_assign<const D1: usize, const D2: usize>(
fn slice_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: FloatTensor<Self, D1>,
) -> FloatTensor<Self, D1> {
BaseOps::<G>::index_assign(tensor, indexes, value)
kernel::slice_assign(tensor, ranges, value)
}
fn mask_where<const D: usize>(

View File

@ -1,11 +1,9 @@
use std::ops::Range;
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
use crate::{
element::{FloatElement, IntElement},
GraphicsApi, WgpuBackend,
kernel, GraphicsApi, WgpuBackend,
};
use burn_tensor::{ops::IntTensorOps, Data, Shape};
use std::ops::Range;
use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
@ -52,19 +50,19 @@ where
BaseOps::<G>::reshape(tensor, shape)
}
fn int_index<const D1: usize, const D2: usize>(
fn int_slice<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
) -> IntTensor<Self, D1> {
BaseOps::<G>::index(tensor, indexes)
kernel::slice(tensor, ranges)
}
fn int_index_assign<const D1: usize, const D2: usize>(
fn int_slice_assign<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
indexes: [Range<usize>; D2],
ranges: [Range<usize>; D2],
value: IntTensor<Self, D1>,
) -> IntTensor<Self, D1> {
BaseOps::<G>::index_assign(tensor, indexes, value)
kernel::slice_assign(tensor, ranges, value)
}
fn int_mask_where<const D: usize>(
@ -86,35 +84,35 @@ where
fn int_gather<const D: usize>(
dim: usize,
tensor: IntTensor<Self, D>,
indexes: IntTensor<Self, D>,
indices: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
BaseOps::<G>::gather(dim, tensor, indexes)
kernel::gather(dim, tensor, indices)
}
fn int_scatter<const D: usize>(
dim: usize,
tensor: IntTensor<Self, D>,
indexes: IntTensor<Self, D>,
indices: IntTensor<Self, D>,
value: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
BaseOps::<G>::scatter(dim, tensor, indexes, value)
kernel::scatter(dim, tensor, indices, value)
}
fn int_index_select_dim<const D: usize>(
_tensor: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
_indexes: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
) -> <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
fn int_select<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
) -> IntTensor<Self, D> {
kernel::select(tensor, dim, indices)
}
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
_tensor: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
_dim: usize,
_indexes: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
_value: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D2>,
) -> <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
todo!()
fn int_select_assign<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
value: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
kernel::select_assign(tensor, dim, indices, value)
}
fn int_cat<const D: usize>(tensors: Vec<IntTensor<Self, D>>, dim: usize) -> IntTensor<Self, D> {

View File

@ -57,22 +57,22 @@ where
todo!()
}
fn max_pool2d_with_indexes(
fn max_pool2d_with_indices(
_x: <WgpuBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
) -> burn_tensor::ops::MaxPool2dWithIndexes<WgpuBackend<G, F, I>> {
) -> burn_tensor::ops::MaxPool2dWithIndices<WgpuBackend<G, F, I>> {
todo!()
}
fn max_pool2d_with_indexes_backward(
fn max_pool2d_with_indices_backward(
_x: <WgpuBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
_output_grad: <WgpuBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_indexes: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<4>,
_indices: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<4>,
) -> burn_tensor::ops::MaxPool2dBackward<WgpuBackend<G, F, I>> {
todo!()
}

View File

@ -4,7 +4,7 @@ var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> indexes: array<{{ int }}>;
var<storage, read> indices: array<{{ int }}>;
@group(0)
@binding(2)
@ -14,9 +14,15 @@ var<storage, read_write> output: array<{{ elem }}>;
@binding(3)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let rank = info[0];
let dim = info[4u * rank + 1u];
@ -31,9 +37,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
} else {
let stride_output = info[i + rank];
let shape_output = info[i + 3u * rank];
index_input += global_id.x / stride_output % shape_output * stride_input;
index_input += id / stride_output % shape_output * stride_input;
}
}
output[global_id.x] = input[index_input + u32(indexes[global_id.x]) * stride];
output[id] = input[index_input + u32(indices[id]) * stride];
}
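This shader template, like the other index templates updated below, now runs on a two-dimensional grid of workgroups instead of a single row, presumably so large tensors stay under the per-dimension workgroup-count limit. The shader flattens the invocation back to a linear id with `global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x`. A Rust sketch of that flattening (the WGSL above is the authoritative version):

// Hypothetical sketch: collapse a 2-D invocation coordinate back into the linear
// element index the rest of the shader works with.
fn flat_id(global_id_x: u32, global_id_y: u32, num_workgroups_x: u32, workgroup_size_x: u32) -> u32 {
    global_id_y * (num_workgroups_x * workgroup_size_x) + global_id_x
}

fn main() {
    // With 32-wide workgroups and 4 workgroups along x, invocation (x=5, y=2)
    // maps to 2 * 128 + 5 = 261.
    assert_eq!(flat_id(5, 2, 4, 32), 261);
}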

View File

@ -1,62 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> indexes: array<{{ int }}>;
@group(0)
@binding(2)
var<storage, read> values: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let rank = info[0];
let dim = info[4u * rank + 1u];
var index_input_offset = 0u;
var index_values_offset = 0u;
var stride_input_dim = 0u;
var stride_values_dim = 0u;
var shape_input_dim = 0u;
var shape_values_dim = 0u;
var num_elem = 1u;
for (var i = 1u; i <= rank; i++) {
let stride_input = info[i];
let stride_values = info[i + rank];
let shape_input = info[i + 2u * rank];
let shape_values = info[i + 3u * rank];
if i - 1u != dim {
index_input_offset += global_id.x / stride_input % shape_input * stride_input;
index_values_offset += global_id.x / stride_values % shape_values * stride_values;
num_elem += shape_input;
} else {
shape_input_dim = shape_input;
shape_values_dim = shape_values;
stride_input_dim = stride_input;
stride_values_dim = stride_values;
}
}
if global_id.x > num_elem {
return;
}
for (var i = 0u; i < shape_values_dim; i++) {
let index = u32(indexes[i]);
input[index_input_offset + index * stride_input_dim] += values[index_values_offset + i * stride_values_dim];
}
}

View File

@ -4,7 +4,7 @@ var<storage, read_write> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> indexes: array<{{ int }}>;
var<storage, read> indices: array<{{ int }}>;
@group(0)
@binding(2)
@ -14,11 +14,18 @@ var<storage, read> value: array<{{ elem }}>;
@binding(3)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let rank = info[0];
let dim = info[3u * rank + 1u];
let shape = info[dim + rank + 1u];
let stride = info[dim + 1u];
@ -31,17 +38,17 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let shape_input = info[i + rank];
let stride_tmp = info[i + 2u * rank];
index_offset += global_id.x / stride_tmp % shape_input * stride_input;
num_elems *= shape_input;
index_offset += id / stride_tmp % shape_input * stride_input;
}
}
if global_id.x >= num_elems {
if id >= num_elems {
return;
}
for (var i = 0u; i < shape; i++) {
let index = i * stride + index_offset;
input[index_offset + stride * u32(indexes[index])] += value[index];
input[index_offset + stride * u32(indices[index])] += value[index];
}
}
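Reduced to one dimension, the scatter kernel above amounts to the following accumulation loop (`scatter_add_1d` is a reference sketch only, ignoring the stride and bounds handling done by the shader):

fn scatter_add_1d(input: &mut [f32], indices: &[usize], value: &[f32]) {
    // `indices` has the same length as `value`; repeated indices accumulate.
    for (i, &target) in indices.iter().enumerate() {
        input[target] += value[i];
    }
}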

View File

@ -4,7 +4,7 @@ var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> indexes: array<{{ int }}>;
var<storage, read> indices: array<{{ int }}>;
@group(0)
@binding(2)
@ -14,9 +14,15 @@ var<storage, read_write> output: array<{{ elem }}>;
@binding(3)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let rank = info[0];
let dim = info[4u * rank + 1u];
var index_input = 0u;
@ -27,15 +33,15 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let shape_input = info[i + 2u * rank];
let shape_output = info[i + 3u * rank];
let index = global_id.x / stride_output % shape_output;
let index = id / stride_output % shape_output;
if i - 1u == dim {
index_input += u32(indexes[index]) * stride_input;
index_input += u32(indices[index]) * stride_input;
} else {
index_input += index * stride_input;
}
}
output[global_id.x] = input[index_input];
output[id] = input[index_input];
}
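Likewise, the gather kernel above boils down to an indexed read (`gather_1d` is a reference sketch only): every output element reads the input element named by `indices` at the same position, so the output takes the shape of `indices`.

fn gather_1d(input: &[f32], indices: &[usize]) -> Vec<f32> {
    indices.iter().map(|&i| input[i]).collect()
}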

View File

@ -0,0 +1,61 @@
@group(0)
@binding(0)
var<storage, read_write> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> indices: array<{{ int }}>;
@group(0)
@binding(2)
var<storage, read> values: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let rank = info[0];
let dim = info[5u * rank + 1u];
let dim_stride_input = info[dim + 1u];
let dim_stride_value = info[dim + rank + 1u];
let dim_shape_value = info[dim + 3u * rank + 1u];
var num_elems = 1u;
var index_input_offset = 0u;
var index_value_offset = 0u;
var num_elem = 1u;
for (var i = 1u; i <= rank; i++) {
if i - 1u != dim {
let stride_input = info[i];
let stride_value = info[i + rank];
let shape_input = info[i + 2u * rank];
let shape_value = info[i + 3u * rank];
let stride_tmp = info[i + 4u * rank];
num_elem *= shape_input;
index_input_offset += id / stride_tmp % shape_input * stride_input;
index_value_offset += id / stride_tmp % shape_value * stride_value;
}
}
if id >= num_elem {
return;
}
for (var i = 0u; i < dim_shape_value; i++) {
let index = u32(indices[i]);
input[index_input_offset + index * dim_stride_input] += values[index_value_offset + i * dim_stride_value];
}
}
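The newly added select_assign kernel, written as a plain nested loop over a matrix with dim = 0 (`select_assign_rows` is a reference sketch only): row indices[i] of input accumulates row i of values.

fn select_assign_rows(input: &mut [Vec<f32>], indices: &[usize], values: &[Vec<f32>]) {
    for (i, &row) in indices.iter().enumerate() {
        // Accumulate `values[i]` element-wise into the selected input row.
        for (dst, v) in input[row].iter_mut().zip(values[i].iter()) {
            *dst += *v;
        }
    }
}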

View File

@ -10,9 +10,15 @@ var<storage, read_write> output: array<{{ elem }}>;
@binding(2)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
var index_input: u32 = 0u;
@ -22,10 +28,10 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let shape_output = info[i + 3u * dim];
let start = info[i + 4u * dim];
let num_block = global_id.x / stride_output % shape_output + start;
let num_block = id / stride_output % shape_output + start;
index_input += num_block * stride_input;
}
output[global_id.x] = input[index_input];
output[id] = input[index_input];
}

View File

@ -10,12 +10,19 @@ var<storage, read> value: array<{{ elem }}>;
@binding(2)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
var index_input: u32 = 0u;
var index_value: u32 = 0u;
var num_elems = 0u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_input = info[i];
@ -24,7 +31,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let shape_value = info[i + 3u * dim];
let start = info[i + 4u * dim];
let num_block = global_id.x / stride_value % shape_value;
let num_block = id / stride_value % shape_value;
index_input += (num_block + start) * stride_input;
index_value += num_block * stride_value;
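The slice_assign kernel's effect, as a 1D reference (`slice_assign_1d` is a sketch only; it assumes the write at the end of this truncated hunk overwrites the target region rather than accumulating into it):

fn slice_assign_1d(input: &mut [f32], start: usize, value: &[f32]) {
    // Copy `value` into `input` beginning at `start`; the rest is untouched.
    input[start..start + value.len()].copy_from_slice(value);
}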

View File

@ -23,7 +23,7 @@ fn main(
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
// Indexes
// Indices
let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);
let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);
let batch = global_id.z;

View File

@ -19,7 +19,7 @@ var<storage, read> info: array<u32>;
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
// Indexes
// Indices
let row = global_id.x;
let col = global_id.y;
let batch = global_id.z;

View File

@ -65,7 +65,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
// Print out predictions for each sample
for (i, text) in samples.into_iter().enumerate() {
let prediction = predictions.clone().index([i..i + 1]); // Get prediction for current sample
let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample
let logits = prediction.to_data(); // Convert prediction tensor to data
let class_index = prediction.argmax(1).into_data().convert::<i32>().value[0]; // Get class index with the highest value
let class = D::class_name(class_index as usize); // Get class name

View File

@ -110,7 +110,7 @@ impl<B: Backend> TextClassificationModel<B> {
let output = self.output.forward(encoded);
let output_classification = output
.index([0..batch_size, 0..1])
.slice([0..batch_size, 0..1])
.reshape([batch_size, self.n_classes]);
let loss = CrossEntropyLoss::new(None);
@ -148,7 +148,7 @@ impl<B: Backend> TextClassificationModel<B> {
.forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad));
let output = self.output.forward(encoded);
let output = output
.index([0..batch_size, 0..1])
.slice([0..batch_size, 0..1])
.reshape([batch_size, self.n_classes]);
softmax(output, 1)

View File

@ -57,9 +57,9 @@ impl<B: Backend> Batcher<TextGenerationItem, TrainingTextGenerationBatch<B>>
let inputs = item
.tokens
.clone()
.index([0..batch_size, 0..seq_length - 1]);
let targets = item.tokens.index([0..batch_size, 1..seq_length]);
let mask_pad = item.mask_pad.index([0..batch_size, 0..seq_length - 1]);
.slice([0..batch_size, 0..seq_length - 1]);
let targets = item.tokens.slice([0..batch_size, 1..seq_length]);
let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]);
TrainingTextGenerationBatch::new(inputs, targets, mask_pad)
}